From c70f0c5da2a97715f5989f0d95ec795bdb085898 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Fri, 6 Oct 2017 22:56:01 +0200 Subject: [PATCH] Definition of platform specific socket bind --- src/conn.go | 2 +- src/conn_default.go | 2 +- src/conn_linux.go | 232 +++++++++++++++++++++++++++++++++++++------- src/uapi.go | 2 +- 4 files changed, 198 insertions(+), 40 deletions(-) diff --git a/src/conn.go b/src/conn.go index 2cf588d..60cd789 100644 --- a/src/conn.go +++ b/src/conn.go @@ -56,7 +56,7 @@ func updateUDPConn(device *Device) error { // set fwmark - err = setMark(netc.conn, netc.fwmark) + err = SetMark(netc.conn, netc.fwmark) if err != nil { return err } diff --git a/src/conn_default.go b/src/conn_default.go index e7c60a8..279643e 100644 --- a/src/conn_default.go +++ b/src/conn_default.go @@ -6,6 +6,6 @@ import ( "net" ) -func setMark(conn *net.UDPConn, value uint32) error { +func SetMark(conn *net.UDPConn, value uint32) error { return nil } diff --git a/src/conn_linux.go b/src/conn_linux.go index a349a9e..64447a5 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -14,23 +14,30 @@ import ( "unsafe" ) +import "fmt" + /* Supports source address caching - * - * It is important that the endpoint is only updated after the packet content has been authenticated. * * Currently there is no way to achieve this within the net package: * See e.g. https://github.com/golang/go/issues/17930 + * So this code is platform dependent. + * + * It is important that the endpoint is only updated after the packet content has been authenticated! */ + type Endpoint struct { // source (selected based on dst type) // (could use RawSockaddrAny and unsafe) - srcIPv6 unix.RawSockaddrInet6 - srcIPv4 unix.RawSockaddrInet4 - srcIf4 int32 + src6 unix.RawSockaddrInet6 + src4 unix.RawSockaddrInet4 + src4if int32 dst unix.RawSockaddrAny } +type IPv4Socket int +type IPv6Socket int + func zoneToUint32(zone string) (uint32, error) { if zone == "" { return 0, nil @@ -42,10 +49,115 @@ func zoneToUint32(zone string) (uint32, error) { return uint32(n), err } +func CreateIPv4Socket(port int) (IPv4Socket, error) { + + // create socket + + fd, err := unix.Socket( + unix.AF_INET, + unix.SOCK_DGRAM, + 0, + ) + + if err != nil { + return -1, err + } + + // set sockopts and bind + + if err := func() error { + + if err := unix.SetsockoptInt( + fd, + unix.SOL_SOCKET, + unix.SO_REUSEADDR, + 1, + ); err != nil { + return err + } + + if err := unix.SetsockoptInt( + fd, + unix.IPPROTO_IP, + unix.IP_PKTINFO, + 1, + ); err != nil { + return err + } + + addr := unix.SockaddrInet4{ + Port: port, + } + return unix.Bind(fd, &addr) + + }(); err != nil { + unix.Close(fd) + } + + return IPv4Socket(fd), err +} + +func CreateIPv6Socket(port int) (IPv6Socket, error) { + + // create socket + + fd, err := unix.Socket( + unix.AF_INET, + unix.SOCK_DGRAM, + 0, + ) + + if err != nil { + return -1, err + } + + // set sockopts and bind + + if err := func() error { + + if err := unix.SetsockoptInt( + fd, + unix.SOL_SOCKET, + unix.SO_REUSEADDR, + 1, + ); err != nil { + return err + } + + if err := unix.SetsockoptInt( + fd, + unix.IPPROTO_IPV6, + unix.IPV6_RECVPKTINFO, + 1, + ); err != nil { + return err + } + + if err := unix.SetsockoptInt( + fd, + unix.IPPROTO_IPV6, + unix.IPV6_V6ONLY, + 1, + ); err != nil { + return err + } + + addr := unix.SockaddrInet6{ + Port: port, + } + return unix.Bind(fd, &addr) + + }(); err != nil { + unix.Close(fd) + } + + return IPv6Socket(fd), err +} + func (end *Endpoint) ClearSrc() { - end.srcIf4 = 0 - end.srcIPv4 = unix.RawSockaddrInet4{} - end.srcIPv6 = unix.RawSockaddrInet6{} + end.src4if = 0 + end.src4 = unix.RawSockaddrInet4{} + end.src6 = unix.RawSockaddrInet6{} } func (end *Endpoint) Set(s string) error { @@ -85,8 +197,10 @@ func (end *Endpoint) Set(s string) error { } func send6(sock uintptr, end *Endpoint, buff []byte) error { - var iovec unix.Iovec + // construct message header + + var iovec unix.Iovec iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) iovec.SetLen(len(buff)) @@ -100,8 +214,8 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error { Len: unix.SizeofInet6Pktinfo, }, unix.Inet6Pktinfo{ - Addr: end.srcIPv6.Addr, - Ifindex: end.srcIPv6.Scope_id, + Addr: end.src6.Addr, + Ifindex: end.src6.Scope_id, }, } @@ -130,8 +244,10 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error { } func send4(sock uintptr, end *Endpoint, buff []byte) error { - var iovec unix.Iovec + // construct message header + + var iovec unix.Iovec iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) iovec.SetLen(len(buff)) @@ -142,11 +258,11 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error { unix.Cmsghdr{ Level: unix.IPPROTO_IP, Type: unix.IP_PKTINFO, - Len: unix.SizeofInet6Pktinfo, + Len: unix.SizeofInet4Pktinfo, }, unix.Inet4Pktinfo{ - Spec_dst: end.srcIPv4.Addr, - Ifindex: end.srcIf4, + Spec_dst: end.src4.Addr, + Ifindex: end.src4if, }, } @@ -174,7 +290,7 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error { return errno } -func send(c *net.UDPConn, end *Endpoint, buff []byte) error { +func (end *Endpoint) Send(c *net.UDPConn, buff []byte) error { // extract underlying file descriptor @@ -195,12 +311,9 @@ func send(c *net.UDPConn, end *Endpoint, buff []byte) error { return errors.New("Unknown address family of source") } -func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAddr, *net.UDPAddr) { +func (end *Endpoint) ReceiveIPv4(sock IPv4Socket, buff []byte) (int, error) { - file, err := c.File() - if err != nil { - return err, nil, nil - } + // contruct message header var iovec unix.Iovec iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) @@ -208,47 +321,92 @@ func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAdd var cmsg struct { cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo // big enough + pktinfo unix.Inet4Pktinfo + } + + var msghdr unix.Msghdr + msghdr.Iov = &iovec + msghdr.Iovlen = 1 + msghdr.Name = (*byte)(unsafe.Pointer(&end.dst)) + msghdr.Namelen = unix.SizeofSockaddrInet4 + msghdr.Control = (*byte)(unsafe.Pointer(&cmsg)) + msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) + + // recvmsg(sock, &mskhdr, 0) + + size, _, errno := unix.Syscall( + unix.SYS_RECVMSG, + uintptr(sock), + uintptr(unsafe.Pointer(&msghdr)), + 0, + ) + + if errno != 0 { + return 0, errno + } + + fmt.Println(msghdr) + fmt.Println(cmsg) + + // update source cache + + if cmsg.cmsghdr.Level == unix.IPPROTO_IP && + cmsg.cmsghdr.Type == unix.IP_PKTINFO && + cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { + end.src4.Addr = cmsg.pktinfo.Spec_dst + end.src4if = cmsg.pktinfo.Ifindex + } + + return int(size), nil +} + +func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error { + + // contruct message header + + var iovec unix.Iovec + iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) + iovec.SetLen(len(buff)) + + var cmsg struct { + cmsghdr unix.Cmsghdr + pktinfo unix.Inet6Pktinfo } var msg unix.Msghdr msg.Iov = &iovec msg.Iovlen = 1 msg.Name = (*byte)(unsafe.Pointer(&end.dst)) - msg.Namelen = uint32(unix.SizeofSockaddrAny) + msg.Namelen = uint32(unix.SizeofSockaddrInet6) msg.Control = (*byte)(unsafe.Pointer(&cmsg)) msg.SetControllen(int(unsafe.Sizeof(cmsg))) + // recvmsg(sock, &mskhdr, 0) + _, _, errno := unix.Syscall( unix.SYS_RECVMSG, - file.Fd(), + uintptr(sock), uintptr(unsafe.Pointer(&msg)), 0, ) if errno != 0 { - return errno, nil, nil + return errno } + // update source cache + if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { - + end.src6.Addr = cmsg.pktinfo.Addr + end.src6.Scope_id = cmsg.pktinfo.Ifindex } - if cmsg.cmsghdr.Level == unix.IPPROTO_IP && - cmsg.cmsghdr.Type == unix.IP_PKTINFO && - cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { - - info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&cmsg.pktinfo)) - println(info) - - } - - return nil, nil, nil + return nil } -func setMark(conn *net.UDPConn, value uint32) error { +func SetMark(conn *net.UDPConn, value uint32) error { if conn == nil { return nil } diff --git a/src/uapi.go b/src/uapi.go index 326216b..7d08e56 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -166,7 +166,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { device.net.mutex.Lock() if fwmark > 0 || device.net.fwmark > 0 { device.net.fwmark = uint32(fwmark) - err := setMark( + err := SetMark( device.net.conn, device.net.fwmark, )