From 5ba84696e29c6109e84b1f48247ae02a2bcb106e Mon Sep 17 00:00:00 2001
From: "Jason A. Donenfeld" <Jason@zx2c4.com>
Date: Fri, 20 Apr 2018 04:05:11 +0200
Subject: [PATCH] Rework sticky sockets

---
 conn_linux.go        | 380 ++++++++++++++++++++-----------------------
 syscall_linux.go     |  30 ----
 syscall_linux_386.go |  53 ------
 3 files changed, 172 insertions(+), 291 deletions(-)
 delete mode 100644 syscall_linux.go
 delete mode 100644 syscall_linux_386.go

diff --git a/conn_linux.go b/conn_linux.go
index 8b60d65..88b9ef4 100644
--- a/conn_linux.go
+++ b/conn_linux.go
@@ -1,13 +1,18 @@
-/* Copyright 2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+/* Copyright 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
  *
  * This implements userspace semantics of "sticky sockets", modeled after
- * WireGuard's kernelspace implementation.
+ * WireGuard's kernelspace implementation. This is more or less a straight port
+ * of the sticky-sockets.c example code:
+ * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
+ *
+ * Currently there is no way to achieve this within the net package:
+ * See e.g. https://github.com/golang/go/issues/17930
+ * So this code is remains platform dependent.
  */
 
 package main
 
 import (
-	"encoding/binary"
 	"errors"
 	"golang.org/x/sys/unix"
 	"net"
@@ -15,15 +20,36 @@ import (
 	"unsafe"
 )
 
-/* Supports source address caching
- *
- * Currently there is no way to achieve this within the net package:
- * See e.g. https://github.com/golang/go/issues/17930
- * So this code is remains platform dependent.
- */
+type IPv4Source struct {
+	src     [4]byte
+	ifindex int32
+}
+
+type IPv6Source struct {
+	src [16]byte
+	//ifindex belongs in dst.ZoneId
+}
+
 type NativeEndpoint struct {
-	src unix.RawSockaddrInet6
-	dst unix.RawSockaddrInet6
+	dst  [unsafe.Sizeof(unix.SockaddrInet6{})]byte
+	src  [unsafe.Sizeof(IPv6Source{})]byte
+	isV6 bool
+}
+
+func (endpoint *NativeEndpoint) src4() *IPv4Source {
+	return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
+}
+
+func (endpoint *NativeEndpoint) src6() *IPv6Source {
+	return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0]))
+}
+
+func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 {
+	return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
+}
+
+func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
+	return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
 }
 
 type NativeBind struct {
@@ -34,22 +60,6 @@ type NativeBind struct {
 var _ Endpoint = (*NativeEndpoint)(nil)
 var _ Bind = NativeBind{}
 
-type IPv4Source struct {
-	src     unix.RawSockaddrInet4
-	Ifindex int32
-}
-
-func htons(val uint16) uint16 {
-	var out [unsafe.Sizeof(val)]byte
-	binary.BigEndian.PutUint16(out[:], val)
-	return *((*uint16)(unsafe.Pointer(&out[0])))
-}
-
-func ntohs(val uint16) uint16 {
-	tmp := ((*[unsafe.Sizeof(val)]byte)(unsafe.Pointer(&val)))
-	return binary.BigEndian.Uint16((*tmp)[:])
-}
-
 func CreateEndpoint(s string) (Endpoint, error) {
 	var end NativeEndpoint
 	addr, err := parseEndpoint(s)
@@ -59,10 +69,9 @@ func CreateEndpoint(s string) (Endpoint, error) {
 
 	ipv4 := addr.IP.To4()
 	if ipv4 != nil {
-		dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
-		dst.Family = unix.AF_INET
-		dst.Port = htons(uint16(addr.Port))
-		dst.Zero = [8]byte{}
+		dst := end.dst4()
+		end.isV6 = false
+		dst.Port = addr.Port
 		copy(dst.Addr[:], ipv4)
 		end.ClearSrc()
 		return &end, nil
@@ -74,17 +83,16 @@ func CreateEndpoint(s string) (Endpoint, error) {
 		if err != nil {
 			return nil, err
 		}
-		dst := &end.dst
-		dst.Family = unix.AF_INET6
-		dst.Port = htons(uint16(addr.Port))
-		dst.Flowinfo = 0
-		dst.Scope_id = zone
+		dst := end.dst6()
+		end.isV6 = true
+		dst.Port = addr.Port
+		dst.ZoneId = zone
 		copy(dst.Addr[:], ipv6[:])
 		end.ClearSrc()
 		return &end, nil
 	}
 
-	return nil, errors.New("Failed to recognize IP address format")
+	return nil, errors.New("Invalid IP address")
 }
 
 func CreateBind(port uint16) (Bind, uint16, error) {
@@ -160,86 +168,85 @@ func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
 
 func (bind NativeBind) Send(buff []byte, end Endpoint) error {
 	nend := end.(*NativeEndpoint)
-	switch nend.dst.Family {
-	case unix.AF_INET6:
-		return send6(bind.sock6, nend, buff)
-	case unix.AF_INET:
+	if !nend.isV6 {
 		return send4(bind.sock4, nend, buff)
-	default:
-		return errors.New("Unknown address family of destination")
+	} else {
+		return send6(bind.sock6, nend, buff)
 	}
 }
 
-func sockaddrToString(addr unix.RawSockaddrInet6) string {
-	var udpAddr net.UDPAddr
-
-	switch addr.Family {
-	case unix.AF_INET6:
-		udpAddr.Port = int(ntohs(addr.Port))
-		udpAddr.IP = addr.Addr[:]
-		return udpAddr.String()
-
-	case unix.AF_INET:
-		ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
-		udpAddr.Port = int(ntohs(ptr.Port))
-		udpAddr.IP = net.IPv4(
-			ptr.Addr[0],
-			ptr.Addr[1],
-			ptr.Addr[2],
-			ptr.Addr[3],
-		)
-		return udpAddr.String()
-
-	default:
-		return "<unknown address family>"
-	}
+func rawAddrToIP4(addr *unix.SockaddrInet4) net.IP {
+	return net.IPv4(
+		addr.Addr[0],
+		addr.Addr[1],
+		addr.Addr[2],
+		addr.Addr[3],
+	)
 }
 
-func rawAddrToIP(addr unix.RawSockaddrInet6) net.IP {
-	switch addr.Family {
-	case unix.AF_INET6:
-		return addr.Addr[:]
-	case unix.AF_INET:
-		ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
-		return net.IPv4(
-			ptr.Addr[0],
-			ptr.Addr[1],
-			ptr.Addr[2],
-			ptr.Addr[3],
-		)
-	default:
-		return nil
-	}
+func rawAddrToIP6(addr *unix.SockaddrInet6) net.IP {
+	return addr.Addr[:]
 }
 
 func (end *NativeEndpoint) SrcIP() net.IP {
-	return rawAddrToIP(end.src)
+	if !end.isV6 {
+		return net.IPv4(
+			end.src4().src[0],
+			end.src4().src[1],
+			end.src4().src[2],
+			end.src4().src[3],
+		)
+	} else {
+		return end.src6().src[:]
+	}
 }
 
 func (end *NativeEndpoint) DstIP() net.IP {
-	return rawAddrToIP(end.dst)
+	if !end.isV6 {
+		return net.IPv4(
+			end.dst4().Addr[0],
+			end.dst4().Addr[1],
+			end.dst4().Addr[2],
+			end.dst4().Addr[3],
+		)
+	} else {
+		return end.dst6().Addr[:]
+	}
 }
 
 func (end *NativeEndpoint) DstToBytes() []byte {
-	ptr := unsafe.Pointer(&end.src)
-	arr := (*[unix.SizeofSockaddrInet6]byte)(ptr)
-	return arr[:]
+	if !end.isV6 {
+		return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
+	} else {
+		return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
+	}
 }
 
 func (end *NativeEndpoint) SrcToString() string {
-	return sockaddrToString(end.src)
+	return end.SrcIP().String()
 }
 
 func (end *NativeEndpoint) DstToString() string {
-	return sockaddrToString(end.dst)
+	var udpAddr net.UDPAddr
+	udpAddr.IP = end.DstIP()
+	if !end.isV6 {
+		udpAddr.Port = end.dst4().Port
+	} else {
+		udpAddr.Port = end.dst6().Port
+	}
+	return udpAddr.String()
 }
 
 func (end *NativeEndpoint) ClearDst() {
-	end.dst = unix.RawSockaddrInet6{}
+	for i := range end.dst {
+		end.dst[i] = 0
+	}
 }
 
 func (end *NativeEndpoint) ClearSrc() {
-	end.src = unix.RawSockaddrInet6{}
+	for i := range end.src {
+		end.src[i] = 0
+	}
 }
 
 func zoneToUint32(zone string) (uint32, error) {
@@ -295,6 +302,7 @@ func create4(port uint16) (int, uint16, error) {
 		return unix.Bind(fd, &addr)
 	}(); err != nil {
 		unix.Close(fd)
+		return -1, 0, err
 	}
 
 	return fd, uint16(addr.Port), err
@@ -353,71 +361,16 @@ func create6(port uint16) (int, uint16, error) {
 
 	}(); err != nil {
 		unix.Close(fd)
+		return -1, 0, err
 	}
 
 	return fd, uint16(addr.Port), err
 }
 
-func send6(sock int, end *NativeEndpoint, buff []byte) error {
-
-	// construct message header
-
-	var iovec unix.Iovec
-	iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
-	iovec.SetLen(len(buff))
-
-	cmsg := struct {
-		cmsghdr unix.Cmsghdr
-		pktinfo unix.Inet6Pktinfo
-	}{
-		unix.Cmsghdr{
-			Level: unix.IPPROTO_IPV6,
-			Type:  unix.IPV6_PKTINFO,
-			Len:   unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
-		},
-		unix.Inet6Pktinfo{
-			Addr:    end.src.Addr,
-			Ifindex: end.src.Scope_id,
-		},
-	}
-
-	msghdr := unix.Msghdr{
-		Iov:     &iovec,
-		Iovlen:  1,
-		Name:    (*byte)(unsafe.Pointer(&end.dst)),
-		Namelen: unix.SizeofSockaddrInet6,
-		Control: (*byte)(unsafe.Pointer(&cmsg)),
-	}
-
-	msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
-
-	_, _, errno := sendmsg(sock, &msghdr, 0)
-
-	if errno == 0 {
-		return nil
-	}
-
-	// clear src and retry
-
-	if errno == unix.EINVAL {
-		end.ClearSrc()
-		cmsg.pktinfo = unix.Inet6Pktinfo{}
-		_, _, errno = sendmsg(sock, &msghdr, 0)
-	}
-
-	return errno
-}
-
 func send4(sock int, end *NativeEndpoint, buff []byte) error {
 
 	// construct message header
 
-	var iovec unix.Iovec
-	iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
-	iovec.SetLen(len(buff))
-
-	src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
-
 	cmsg := struct {
 		cmsghdr unix.Cmsghdr
 		pktinfo unix.Inet4Pktinfo
@@ -428,65 +381,86 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
 			Len:   unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
 		},
 		unix.Inet4Pktinfo{
-			Spec_dst: src4.src.Addr,
-			Ifindex:  src4.Ifindex,
+			Spec_dst: end.src4().src,
+			Ifindex:  end.src4().ifindex,
 		},
 	}
 
-	msghdr := unix.Msghdr{
-		Iov:     &iovec,
-		Iovlen:  1,
-		Name:    (*byte)(unsafe.Pointer(&end.dst)),
-		Namelen: unix.SizeofSockaddrInet4,
-		Control: (*byte)(unsafe.Pointer(&cmsg)),
-		Flags:   0,
-	}
-	msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
+	_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
 
-	_, _, errno := sendmsg(sock, &msghdr, 0)
-
-	// clear source and try again
-
-	if errno == unix.EINVAL {
-		end.ClearSrc()
-		cmsg.pktinfo = unix.Inet4Pktinfo{}
-		_, _, errno = sendmsg(sock, &msghdr, 0)
-	}
-
-	// errno = 0 is still an error instance
-
-	if errno == 0 {
+	if err == nil {
 		return nil
 	}
 
-	return errno
+	// clear src and retry
+
+	if err == unix.EINVAL {
+		end.ClearSrc()
+		cmsg.pktinfo = unix.Inet4Pktinfo{}
+		_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
+	}
+
+	return err
+}
+
+func send6(sock int, end *NativeEndpoint, buff []byte) error {
+
+	// construct message header
+
+	cmsg := struct {
+		cmsghdr unix.Cmsghdr
+		pktinfo unix.Inet6Pktinfo
+	}{
+		unix.Cmsghdr{
+			Level: unix.IPPROTO_IPV6,
+			Type:  unix.IPV6_PKTINFO,
+			Len:   unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
+		},
+		unix.Inet6Pktinfo{
+			Addr:    end.src6().src,
+			Ifindex: end.dst6().ZoneId,
+		},
+	}
+
+	if cmsg.pktinfo.Addr == [16]byte{} {
+		cmsg.pktinfo.Ifindex = 0
+	}
+
+	_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
+
+	if err == nil {
+		return nil
+	}
+
+	// clear src and retry
+
+	if err == unix.EINVAL {
+		end.ClearSrc()
+		cmsg.pktinfo = unix.Inet6Pktinfo{}
+		_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
+	}
+
+	return err
 }
 
 func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
 
 	// contruct message header
 
-	var iovec unix.Iovec
-	iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
-	iovec.SetLen(len(buff))
-
 	var cmsg struct {
 		cmsghdr unix.Cmsghdr
 		pktinfo unix.Inet4Pktinfo
 	}
 
-	var msghdr unix.Msghdr
-	msghdr.Iov = &iovec
-	msghdr.Iovlen = 1
-	msghdr.Name = (*byte)(unsafe.Pointer(&end.dst))
-	msghdr.Namelen = unix.SizeofSockaddrInet4
-	msghdr.Control = (*byte)(unsafe.Pointer(&cmsg))
-	msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
+	size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
 
-	size, _, errno := recvmsg(sock, &msghdr, 0)
+	if err != nil {
+		return 0, err
+	}
+	end.isV6 = false
 
-	if errno != 0 {
-		return 0, errno
+	if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
+		*end.dst4() = *newDst4
 	}
 
 	// update source cache
@@ -494,40 +468,31 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
 	if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
 		cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
 		cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
-		src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
-		src4.src.Family = unix.AF_INET
-		src4.src.Addr = cmsg.pktinfo.Spec_dst
-		src4.Ifindex = cmsg.pktinfo.Ifindex
+		end.src4().src = cmsg.pktinfo.Spec_dst
+		end.src4().ifindex = cmsg.pktinfo.Ifindex
 	}
 
-	return int(size), nil
+	return size, nil
 }
 
 func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
 
 	// contruct message header
 
-	var iovec unix.Iovec
-	iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
-	iovec.SetLen(len(buff))
-
 	var cmsg struct {
 		cmsghdr unix.Cmsghdr
 		pktinfo unix.Inet6Pktinfo
 	}
 
-	var msg unix.Msghdr
-	msg.Iov = &iovec
-	msg.Iovlen = 1
-	msg.Name = (*byte)(unsafe.Pointer(&end.dst))
-	msg.Namelen = uint32(unix.SizeofSockaddrInet6)
-	msg.Control = (*byte)(unsafe.Pointer(&cmsg))
-	msg.SetControllen(int(unsafe.Sizeof(cmsg)))
+	size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
 
-	size, _, errno := recvmsg(sock, &msg, 0)
+	if err != nil {
+		return 0, err
+	}
+	end.isV6 = true
 
-	if errno != 0 {
-		return 0, errno
+	if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
+		*end.dst6() = *newDst6
 	}
 
 	// update source cache
@@ -535,10 +500,9 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
 	if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
 		cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
 		cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
-		end.src.Family = unix.AF_INET6
-		end.src.Addr = cmsg.pktinfo.Addr
-		end.src.Scope_id = cmsg.pktinfo.Ifindex
+		end.src6().src = cmsg.pktinfo.Addr
+		end.dst6().ZoneId = cmsg.pktinfo.Ifindex
 	}
 
-	return int(size), nil
+	return size, nil
 }
diff --git a/syscall_linux.go b/syscall_linux.go
deleted file mode 100644
index 3403544..0000000
--- a/syscall_linux.go
+++ /dev/null
@@ -1,30 +0,0 @@
-// +build linux,!386
-
-/* Copyright 2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
- */
-
-package main
-
-import (
-	"golang.org/x/sys/unix"
-	"syscall"
-	"unsafe"
-)
-
-func sendmsg(fd int, msghdr *unix.Msghdr, flags int) (uintptr, uintptr, syscall.Errno) {
-	return unix.Syscall(
-		unix.SYS_SENDMSG,
-		uintptr(fd),
-		uintptr(unsafe.Pointer(msghdr)),
-		uintptr(flags),
-	)
-}
-
-func recvmsg(fd int, msghdr *unix.Msghdr, flags int) (uintptr, uintptr, syscall.Errno) {
-	return unix.Syscall(
-		unix.SYS_RECVMSG,
-		uintptr(fd),
-		uintptr(unsafe.Pointer(msghdr)),
-		uintptr(flags),
-	)
-}
diff --git a/syscall_linux_386.go b/syscall_linux_386.go
deleted file mode 100644
index 76d7c7e..0000000
--- a/syscall_linux_386.go
+++ /dev/null
@@ -1,53 +0,0 @@
-// +build linux,386
-
-/* Copyright 2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
- */
-
-package main
-
-import (
-	"golang.org/x/sys/unix"
-	"syscall"
-	"unsafe"
-)
-
-const (
-	_SENDMSG = 16
-	_RECVMSG = 17
-)
-
-func sendmsg(fd int, msghdr *unix.Msghdr, flags int) (uintptr, uintptr, syscall.Errno) {
-	args := struct {
-		fd     uintptr
-		msghdr uintptr
-		flags  uintptr
-	}{
-		uintptr(fd),
-		uintptr(unsafe.Pointer(msghdr)),
-		uintptr(flags),
-	}
-	return unix.Syscall(
-		unix.SYS_SOCKETCALL,
-		_SENDMSG,
-		uintptr(unsafe.Pointer(&args)),
-		0,
-	)
-}
-
-func recvmsg(fd int, msghdr *unix.Msghdr, flags int) (uintptr, uintptr, syscall.Errno) {
-	args := struct {
-		fd     uintptr
-		msghdr uintptr
-		flags  uintptr
-	}{
-		uintptr(fd),
-		uintptr(unsafe.Pointer(msghdr)),
-		uintptr(flags),
-	}
-	return unix.Syscall(
-		unix.SYS_SOCKETCALL,
-		_RECVMSG,
-		uintptr(unsafe.Pointer(&args)),
-		0,
-	)
-}