294d3bedf9
Until we depend on Go 1.16 (which isn't released yet), alias our own variable to the private member of the net package. This will allow an easy find replace to make this go away when we eventually switch to 1.16. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
568 lines
10 KiB
Go
568 lines
10 KiB
Go
// +build !android
|
|
|
|
/* SPDX-License-Identifier: MIT
|
|
*
|
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
|
*/
|
|
|
|
package conn
|
|
|
|
import (
|
|
"errors"
|
|
"net"
|
|
"strconv"
|
|
"sync"
|
|
"syscall"
|
|
"unsafe"
|
|
|
|
"golang.org/x/sys/unix"
|
|
)
|
|
|
|
// IPv4Source is the cached local source information for an IPv4 endpoint:
// the source address and the index of the interface it belongs to. It is
// overlaid on NativeEndpoint.src when the endpoint is not IPv6.
type IPv4Source struct {
	// Src is the local IPv4 source address.
	Src [4]byte
	// Ifindex is the interface index associated with Src.
	Ifindex int32
}
|
|
|
|
// IPv6Source is the cached local source address for an IPv6 endpoint. It is
// overlaid on NativeEndpoint.src when the endpoint is IPv6.
type IPv6Source struct {
	// src is the local IPv6 source address.
	src [16]byte
	// There is no interface index field here: ifindex belongs in dst.ZoneId.
}
|
|
|
|
type NativeEndpoint struct {
|
|
sync.Mutex
|
|
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
|
|
src [unsafe.Sizeof(IPv6Source{})]byte
|
|
isV6 bool
|
|
}
|
|
|
|
// Src4 exposes the endpoint's source storage viewed as an IPv4Source.
// The returned pointer aliases the endpoint's own memory.
func (end *NativeEndpoint) Src4() *IPv4Source {
	return end.src4()
}
|
|
// Dst4 exposes the endpoint's destination storage viewed as a
// unix.SockaddrInet4. The returned pointer aliases the endpoint's own memory.
func (end *NativeEndpoint) Dst4() *unix.SockaddrInet4 {
	return end.dst4()
}
|
|
// IsV6 reports whether the endpoint currently holds an IPv6 destination.
func (end *NativeEndpoint) IsV6() bool {
	return end.isV6
}
|
|
|
|
func (endpoint *NativeEndpoint) src4() *IPv4Source {
|
|
return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
|
|
}
|
|
|
|
func (endpoint *NativeEndpoint) src6() *IPv6Source {
|
|
return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0]))
|
|
}
|
|
|
|
func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 {
|
|
return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
|
|
}
|
|
|
|
func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
|
|
return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
|
|
}
|
|
|
|
// nativeBind owns the pair of UDP sockets (one per address family) used to
// send and receive packets.
type nativeBind struct {
	// sock4 and sock6 are raw socket descriptors; -1 means "not open".
	sock4, sock6 int

	// lastMark is the value most recently applied successfully by SetMark.
	lastMark uint32

	// closing is read-locked by socket users (SetMark, Receive*, Send) and
	// write-locked by Close while the descriptors are closed and reset.
	closing sync.RWMutex
}
|
|
|
|
var _ Endpoint = (*NativeEndpoint)(nil)
|
|
var _ Bind = (*nativeBind)(nil)
|
|
|
|
func CreateEndpoint(s string) (Endpoint, error) {
|
|
var end NativeEndpoint
|
|
addr, err := parseEndpoint(s)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ipv4 := addr.IP.To4()
|
|
if ipv4 != nil {
|
|
dst := end.dst4()
|
|
end.isV6 = false
|
|
dst.Port = addr.Port
|
|
copy(dst.Addr[:], ipv4)
|
|
end.ClearSrc()
|
|
return &end, nil
|
|
}
|
|
|
|
ipv6 := addr.IP.To16()
|
|
if ipv6 != nil {
|
|
zone, err := zoneToUint32(addr.Zone)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
dst := end.dst6()
|
|
end.isV6 = true
|
|
dst.Port = addr.Port
|
|
dst.ZoneId = zone
|
|
copy(dst.Addr[:], ipv6[:])
|
|
end.ClearSrc()
|
|
return &end, nil
|
|
}
|
|
|
|
return nil, errors.New("Invalid IP address")
|
|
}
|
|
|
|
func createBind(port uint16) (Bind, uint16, error) {
|
|
var err error
|
|
var bind nativeBind
|
|
var newPort uint16
|
|
|
|
// Attempt ipv6 bind, update port if successful.
|
|
bind.sock6, newPort, err = create6(port)
|
|
if err != nil {
|
|
if err != syscall.EAFNOSUPPORT {
|
|
return nil, 0, err
|
|
}
|
|
} else {
|
|
port = newPort
|
|
}
|
|
|
|
// Attempt ipv4 bind, update port if successful.
|
|
bind.sock4, newPort, err = create4(port)
|
|
if err != nil {
|
|
if err != syscall.EAFNOSUPPORT {
|
|
unix.Close(bind.sock6)
|
|
return nil, 0, err
|
|
}
|
|
} else {
|
|
port = newPort
|
|
}
|
|
|
|
if bind.sock4 == -1 && bind.sock6 == -1 {
|
|
return nil, 0, errors.New("ipv4 and ipv6 not supported")
|
|
}
|
|
|
|
return &bind, port, nil
|
|
}
|
|
|
|
func (bind *nativeBind) LastMark() uint32 {
|
|
return bind.lastMark
|
|
}
|
|
|
|
func (bind *nativeBind) SetMark(value uint32) error {
|
|
bind.closing.RLock()
|
|
defer bind.closing.RUnlock()
|
|
|
|
if bind.sock6 != -1 {
|
|
err := unix.SetsockoptInt(
|
|
bind.sock6,
|
|
unix.SOL_SOCKET,
|
|
unix.SO_MARK,
|
|
int(value),
|
|
)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if bind.sock4 != -1 {
|
|
err := unix.SetsockoptInt(
|
|
bind.sock4,
|
|
unix.SOL_SOCKET,
|
|
unix.SO_MARK,
|
|
int(value),
|
|
)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
bind.lastMark = value
|
|
return nil
|
|
}
|
|
|
|
func (bind *nativeBind) Close() error {
|
|
var err1, err2 error
|
|
bind.closing.RLock()
|
|
if bind.sock6 != -1 {
|
|
unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
|
|
}
|
|
if bind.sock4 != -1 {
|
|
unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
|
|
}
|
|
bind.closing.RUnlock()
|
|
bind.closing.Lock()
|
|
if bind.sock6 != -1 {
|
|
err1 = unix.Close(bind.sock6)
|
|
bind.sock6 = -1
|
|
}
|
|
if bind.sock4 != -1 {
|
|
err2 = unix.Close(bind.sock4)
|
|
bind.sock4 = -1
|
|
}
|
|
bind.closing.Unlock()
|
|
|
|
if err1 != nil {
|
|
return err1
|
|
}
|
|
return err2
|
|
}
|
|
|
|
func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
|
bind.closing.RLock()
|
|
defer bind.closing.RUnlock()
|
|
|
|
var end NativeEndpoint
|
|
if bind.sock6 == -1 {
|
|
return 0, nil, NetErrClosed
|
|
}
|
|
n, err := receive6(
|
|
bind.sock6,
|
|
buff,
|
|
&end,
|
|
)
|
|
return n, &end, err
|
|
}
|
|
|
|
func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
|
bind.closing.RLock()
|
|
defer bind.closing.RUnlock()
|
|
|
|
var end NativeEndpoint
|
|
if bind.sock4 == -1 {
|
|
return 0, nil, NetErrClosed
|
|
}
|
|
n, err := receive4(
|
|
bind.sock4,
|
|
buff,
|
|
&end,
|
|
)
|
|
return n, &end, err
|
|
}
|
|
|
|
func (bind *nativeBind) Send(buff []byte, end Endpoint) error {
|
|
bind.closing.RLock()
|
|
defer bind.closing.RUnlock()
|
|
|
|
nend := end.(*NativeEndpoint)
|
|
if !nend.isV6 {
|
|
if bind.sock4 == -1 {
|
|
return NetErrClosed
|
|
}
|
|
return send4(bind.sock4, nend, buff)
|
|
} else {
|
|
if bind.sock6 == -1 {
|
|
return NetErrClosed
|
|
}
|
|
return send6(bind.sock6, nend, buff)
|
|
}
|
|
}
|
|
|
|
func (end *NativeEndpoint) SrcIP() net.IP {
|
|
if !end.isV6 {
|
|
return net.IPv4(
|
|
end.src4().Src[0],
|
|
end.src4().Src[1],
|
|
end.src4().Src[2],
|
|
end.src4().Src[3],
|
|
)
|
|
} else {
|
|
return end.src6().src[:]
|
|
}
|
|
}
|
|
|
|
func (end *NativeEndpoint) DstIP() net.IP {
|
|
if !end.isV6 {
|
|
return net.IPv4(
|
|
end.dst4().Addr[0],
|
|
end.dst4().Addr[1],
|
|
end.dst4().Addr[2],
|
|
end.dst4().Addr[3],
|
|
)
|
|
} else {
|
|
return end.dst6().Addr[:]
|
|
}
|
|
}
|
|
|
|
func (end *NativeEndpoint) DstToBytes() []byte {
|
|
if !end.isV6 {
|
|
return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
|
|
} else {
|
|
return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
|
|
}
|
|
}
|
|
|
|
func (end *NativeEndpoint) SrcToString() string {
|
|
return end.SrcIP().String()
|
|
}
|
|
|
|
func (end *NativeEndpoint) DstToString() string {
|
|
var udpAddr net.UDPAddr
|
|
udpAddr.IP = end.DstIP()
|
|
if !end.isV6 {
|
|
udpAddr.Port = end.dst4().Port
|
|
} else {
|
|
udpAddr.Port = end.dst6().Port
|
|
}
|
|
return udpAddr.String()
|
|
}
|
|
|
|
func (end *NativeEndpoint) ClearDst() {
|
|
for i := range end.dst {
|
|
end.dst[i] = 0
|
|
}
|
|
}
|
|
|
|
func (end *NativeEndpoint) ClearSrc() {
|
|
for i := range end.src {
|
|
end.src[i] = 0
|
|
}
|
|
}
|
|
|
|
// zoneToUint32 converts an IPv6 zone string into a numeric zone id.
//
// An empty zone maps to 0. Otherwise the zone is first tried as a network
// interface name, yielding that interface's index; if no such interface can
// be looked up, it is parsed as a base-10 unsigned 32-bit integer and any
// parse error is returned.
func zoneToUint32(zone string) (uint32, error) {
	if len(zone) == 0 {
		return 0, nil
	}

	iface, lookupErr := net.InterfaceByName(zone)
	if lookupErr == nil {
		return uint32(iface.Index), nil
	}

	// Not a known interface name: fall back to a numeric zone.
	id, err := strconv.ParseUint(zone, 10, 32)
	return uint32(id), err
}
|
|
|
|
func create4(port uint16) (int, uint16, error) {
|
|
|
|
// create socket
|
|
|
|
fd, err := unix.Socket(
|
|
unix.AF_INET,
|
|
unix.SOCK_DGRAM,
|
|
0,
|
|
)
|
|
|
|
if err != nil {
|
|
return -1, 0, err
|
|
}
|
|
|
|
addr := unix.SockaddrInet4{
|
|
Port: int(port),
|
|
}
|
|
|
|
// set sockopts and bind
|
|
|
|
if err := func() error {
|
|
if err := unix.SetsockoptInt(
|
|
fd,
|
|
unix.IPPROTO_IP,
|
|
unix.IP_PKTINFO,
|
|
1,
|
|
); err != nil {
|
|
return err
|
|
}
|
|
|
|
return unix.Bind(fd, &addr)
|
|
}(); err != nil {
|
|
unix.Close(fd)
|
|
return -1, 0, err
|
|
}
|
|
|
|
sa, err := unix.Getsockname(fd)
|
|
if err == nil {
|
|
addr.Port = sa.(*unix.SockaddrInet4).Port
|
|
}
|
|
|
|
return fd, uint16(addr.Port), err
|
|
}
|
|
|
|
func create6(port uint16) (int, uint16, error) {
|
|
|
|
// create socket
|
|
|
|
fd, err := unix.Socket(
|
|
unix.AF_INET6,
|
|
unix.SOCK_DGRAM,
|
|
0,
|
|
)
|
|
|
|
if err != nil {
|
|
return -1, 0, err
|
|
}
|
|
|
|
// set sockopts and bind
|
|
|
|
addr := unix.SockaddrInet6{
|
|
Port: int(port),
|
|
}
|
|
|
|
if err := func() error {
|
|
if err := unix.SetsockoptInt(
|
|
fd,
|
|
unix.IPPROTO_IPV6,
|
|
unix.IPV6_RECVPKTINFO,
|
|
1,
|
|
); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := unix.SetsockoptInt(
|
|
fd,
|
|
unix.IPPROTO_IPV6,
|
|
unix.IPV6_V6ONLY,
|
|
1,
|
|
); err != nil {
|
|
return err
|
|
}
|
|
|
|
return unix.Bind(fd, &addr)
|
|
|
|
}(); err != nil {
|
|
unix.Close(fd)
|
|
return -1, 0, err
|
|
}
|
|
|
|
sa, err := unix.Getsockname(fd)
|
|
if err == nil {
|
|
addr.Port = sa.(*unix.SockaddrInet6).Port
|
|
}
|
|
|
|
return fd, uint16(addr.Port), err
|
|
}
|
|
|
|
func send4(sock int, end *NativeEndpoint, buff []byte) error {
|
|
|
|
// construct message header
|
|
|
|
cmsg := struct {
|
|
cmsghdr unix.Cmsghdr
|
|
pktinfo unix.Inet4Pktinfo
|
|
}{
|
|
unix.Cmsghdr{
|
|
Level: unix.IPPROTO_IP,
|
|
Type: unix.IP_PKTINFO,
|
|
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
|
|
},
|
|
unix.Inet4Pktinfo{
|
|
Spec_dst: end.src4().Src,
|
|
Ifindex: end.src4().Ifindex,
|
|
},
|
|
}
|
|
|
|
end.Lock()
|
|
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
|
end.Unlock()
|
|
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
// clear src and retry
|
|
|
|
if err == unix.EINVAL {
|
|
end.ClearSrc()
|
|
cmsg.pktinfo = unix.Inet4Pktinfo{}
|
|
end.Lock()
|
|
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
|
end.Unlock()
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func send6(sock int, end *NativeEndpoint, buff []byte) error {
|
|
|
|
// construct message header
|
|
|
|
cmsg := struct {
|
|
cmsghdr unix.Cmsghdr
|
|
pktinfo unix.Inet6Pktinfo
|
|
}{
|
|
unix.Cmsghdr{
|
|
Level: unix.IPPROTO_IPV6,
|
|
Type: unix.IPV6_PKTINFO,
|
|
Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
|
|
},
|
|
unix.Inet6Pktinfo{
|
|
Addr: end.src6().src,
|
|
Ifindex: end.dst6().ZoneId,
|
|
},
|
|
}
|
|
|
|
if cmsg.pktinfo.Addr == [16]byte{} {
|
|
cmsg.pktinfo.Ifindex = 0
|
|
}
|
|
|
|
end.Lock()
|
|
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
|
end.Unlock()
|
|
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
// clear src and retry
|
|
|
|
if err == unix.EINVAL {
|
|
end.ClearSrc()
|
|
cmsg.pktinfo = unix.Inet6Pktinfo{}
|
|
end.Lock()
|
|
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
|
end.Unlock()
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
|
|
|
// construct message header
|
|
|
|
var cmsg struct {
|
|
cmsghdr unix.Cmsghdr
|
|
pktinfo unix.Inet4Pktinfo
|
|
}
|
|
|
|
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
|
|
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
end.isV6 = false
|
|
|
|
if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
|
|
*end.dst4() = *newDst4
|
|
}
|
|
|
|
// update source cache
|
|
|
|
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
|
|
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
|
|
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
|
|
end.src4().Src = cmsg.pktinfo.Spec_dst
|
|
end.src4().Ifindex = cmsg.pktinfo.Ifindex
|
|
}
|
|
|
|
return size, nil
|
|
}
|
|
|
|
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
|
|
|
// construct message header
|
|
|
|
var cmsg struct {
|
|
cmsghdr unix.Cmsghdr
|
|
pktinfo unix.Inet6Pktinfo
|
|
}
|
|
|
|
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
|
|
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
end.isV6 = true
|
|
|
|
if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
|
|
*end.dst6() = *newDst6
|
|
}
|
|
|
|
// update source cache
|
|
|
|
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
|
|
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
|
|
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
|
|
end.src6().src = cmsg.pktinfo.Addr
|
|
end.dst6().ZoneId = cmsg.pktinfo.Ifindex
|
|
}
|
|
|
|
return size, nil
|
|
}
|