From c70f0c5da2a97715f5989f0d95ec795bdb085898 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Fri, 6 Oct 2017 22:56:01 +0200 Subject: [PATCH 01/15] Definition of platform specific socket bind --- src/conn.go | 2 +- src/conn_default.go | 2 +- src/conn_linux.go | 232 +++++++++++++++++++++++++++++++++++++------- src/uapi.go | 2 +- 4 files changed, 198 insertions(+), 40 deletions(-) diff --git a/src/conn.go b/src/conn.go index 2cf588d..60cd789 100644 --- a/src/conn.go +++ b/src/conn.go @@ -56,7 +56,7 @@ func updateUDPConn(device *Device) error { // set fwmark - err = setMark(netc.conn, netc.fwmark) + err = SetMark(netc.conn, netc.fwmark) if err != nil { return err } diff --git a/src/conn_default.go b/src/conn_default.go index e7c60a8..279643e 100644 --- a/src/conn_default.go +++ b/src/conn_default.go @@ -6,6 +6,6 @@ import ( "net" ) -func setMark(conn *net.UDPConn, value uint32) error { +func SetMark(conn *net.UDPConn, value uint32) error { return nil } diff --git a/src/conn_linux.go b/src/conn_linux.go index a349a9e..64447a5 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -14,23 +14,30 @@ import ( "unsafe" ) +import "fmt" + /* Supports source address caching - * - * It is important that the endpoint is only updated after the packet content has been authenticated. * * Currently there is no way to achieve this within the net package: * See e.g. https://github.com/golang/go/issues/17930 + * So this code is platform dependent. + * + * It is important that the endpoint is only updated after the packet content has been authenticated! */ + type Endpoint struct { // source (selected based on dst type) // (could use RawSockaddrAny and unsafe) - srcIPv6 unix.RawSockaddrInet6 - srcIPv4 unix.RawSockaddrInet4 - srcIf4 int32 + src6 unix.RawSockaddrInet6 + src4 unix.RawSockaddrInet4 + src4if int32 dst unix.RawSockaddrAny } +type IPv4Socket int +type IPv6Socket int + func zoneToUint32(zone string) (uint32, error) { if zone == "" { return 0, nil @@ -42,10 +49,115 @@ func zoneToUint32(zone string) (uint32, error) { return uint32(n), err } +func CreateIPv4Socket(port int) (IPv4Socket, error) { + + // create socket + + fd, err := unix.Socket( + unix.AF_INET, + unix.SOCK_DGRAM, + 0, + ) + + if err != nil { + return -1, err + } + + // set sockopts and bind + + if err := func() error { + + if err := unix.SetsockoptInt( + fd, + unix.SOL_SOCKET, + unix.SO_REUSEADDR, + 1, + ); err != nil { + return err + } + + if err := unix.SetsockoptInt( + fd, + unix.IPPROTO_IP, + unix.IP_PKTINFO, + 1, + ); err != nil { + return err + } + + addr := unix.SockaddrInet4{ + Port: port, + } + return unix.Bind(fd, &addr) + + }(); err != nil { + unix.Close(fd) + } + + return IPv4Socket(fd), err +} + +func CreateIPv6Socket(port int) (IPv6Socket, error) { + + // create socket + + fd, err := unix.Socket( + unix.AF_INET, + unix.SOCK_DGRAM, + 0, + ) + + if err != nil { + return -1, err + } + + // set sockopts and bind + + if err := func() error { + + if err := unix.SetsockoptInt( + fd, + unix.SOL_SOCKET, + unix.SO_REUSEADDR, + 1, + ); err != nil { + return err + } + + if err := unix.SetsockoptInt( + fd, + unix.IPPROTO_IPV6, + unix.IPV6_RECVPKTINFO, + 1, + ); err != nil { + return err + } + + if err := unix.SetsockoptInt( + fd, + unix.IPPROTO_IPV6, + unix.IPV6_V6ONLY, + 1, + ); err != nil { + return err + } + + addr := unix.SockaddrInet6{ + Port: port, + } + return unix.Bind(fd, &addr) + + }(); err != nil { + unix.Close(fd) + } + + return IPv6Socket(fd), err +} + func (end *Endpoint) ClearSrc() { - end.srcIf4 = 0 - end.srcIPv4 = unix.RawSockaddrInet4{} - end.srcIPv6 = unix.RawSockaddrInet6{} + end.src4if = 0 + end.src4 = unix.RawSockaddrInet4{} + end.src6 = unix.RawSockaddrInet6{} } func (end *Endpoint) Set(s string) error { @@ -85,8 +197,10 @@ func (end *Endpoint) Set(s string) error { } func send6(sock uintptr, end *Endpoint, buff []byte) error { - var iovec unix.Iovec + // construct message header + + var iovec unix.Iovec iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) iovec.SetLen(len(buff)) @@ -100,8 +214,8 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error { Len: unix.SizeofInet6Pktinfo, }, unix.Inet6Pktinfo{ - Addr: end.srcIPv6.Addr, - Ifindex: end.srcIPv6.Scope_id, + Addr: end.src6.Addr, + Ifindex: end.src6.Scope_id, }, } @@ -130,8 +244,10 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error { } func send4(sock uintptr, end *Endpoint, buff []byte) error { - var iovec unix.Iovec + // construct message header + + var iovec unix.Iovec iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) iovec.SetLen(len(buff)) @@ -142,11 +258,11 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error { unix.Cmsghdr{ Level: unix.IPPROTO_IP, Type: unix.IP_PKTINFO, - Len: unix.SizeofInet6Pktinfo, + Len: unix.SizeofInet4Pktinfo, }, unix.Inet4Pktinfo{ - Spec_dst: end.srcIPv4.Addr, - Ifindex: end.srcIf4, + Spec_dst: end.src4.Addr, + Ifindex: end.src4if, }, } @@ -174,7 +290,7 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error { return errno } -func send(c *net.UDPConn, end *Endpoint, buff []byte) error { +func (end *Endpoint) Send(c *net.UDPConn, buff []byte) error { // extract underlying file descriptor @@ -195,12 +311,9 @@ func send(c *net.UDPConn, end *Endpoint, buff []byte) error { return errors.New("Unknown address family of source") } -func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAddr, *net.UDPAddr) { +func (end *Endpoint) ReceiveIPv4(sock IPv4Socket, buff []byte) (int, error) { - file, err := c.File() - if err != nil { - return err, nil, nil - } + // contruct message header var iovec unix.Iovec iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) @@ -208,47 +321,92 @@ func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAdd var cmsg struct { cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo // big enough + pktinfo unix.Inet4Pktinfo + } + + var msghdr unix.Msghdr + msghdr.Iov = &iovec + msghdr.Iovlen = 1 + msghdr.Name = (*byte)(unsafe.Pointer(&end.dst)) + msghdr.Namelen = unix.SizeofSockaddrInet4 + msghdr.Control = (*byte)(unsafe.Pointer(&cmsg)) + msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) + + // recvmsg(sock, &mskhdr, 0) + + size, _, errno := unix.Syscall( + unix.SYS_RECVMSG, + uintptr(sock), + uintptr(unsafe.Pointer(&msghdr)), + 0, + ) + + if errno != 0 { + return 0, errno + } + + fmt.Println(msghdr) + fmt.Println(cmsg) + + // update source cache + + if cmsg.cmsghdr.Level == unix.IPPROTO_IP && + cmsg.cmsghdr.Type == unix.IP_PKTINFO && + cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { + end.src4.Addr = cmsg.pktinfo.Spec_dst + end.src4if = cmsg.pktinfo.Ifindex + } + + return int(size), nil +} + +func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error { + + // contruct message header + + var iovec unix.Iovec + iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) + iovec.SetLen(len(buff)) + + var cmsg struct { + cmsghdr unix.Cmsghdr + pktinfo unix.Inet6Pktinfo } var msg unix.Msghdr msg.Iov = &iovec msg.Iovlen = 1 msg.Name = (*byte)(unsafe.Pointer(&end.dst)) - msg.Namelen = uint32(unix.SizeofSockaddrAny) + msg.Namelen = uint32(unix.SizeofSockaddrInet6) msg.Control = (*byte)(unsafe.Pointer(&cmsg)) msg.SetControllen(int(unsafe.Sizeof(cmsg))) + // recvmsg(sock, &mskhdr, 0) + _, _, errno := unix.Syscall( unix.SYS_RECVMSG, - file.Fd(), + uintptr(sock), uintptr(unsafe.Pointer(&msg)), 0, ) if errno != 0 { - return errno, nil, nil + return errno } + // update source cache + if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { - + end.src6.Addr = cmsg.pktinfo.Addr + end.src6.Scope_id = cmsg.pktinfo.Ifindex } - if cmsg.cmsghdr.Level == unix.IPPROTO_IP && - cmsg.cmsghdr.Type == unix.IP_PKTINFO && - cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { - - info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&cmsg.pktinfo)) - println(info) - - } - - return nil, nil, nil + return nil } -func setMark(conn *net.UDPConn, value uint32) error { +func SetMark(conn *net.UDPConn, value uint32) error { if conn == nil { return nil } diff --git a/src/uapi.go b/src/uapi.go index 326216b..7d08e56 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -166,7 +166,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { device.net.mutex.Lock() if fwmark > 0 || device.net.fwmark > 0 { device.net.fwmark = uint32(fwmark) - err := setMark( + err := SetMark( device.net.conn, device.net.fwmark, ) From 2d856045a0dbfc15d38d738e2a9d159ba2a49a47 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sat, 7 Oct 2017 22:35:23 +0200 Subject: [PATCH 02/15] Begin incorporating new src cache into receive --- src/conn.go | 104 ++++++++++++++++++++++++++++++---------------- src/conn_linux.go | 70 +++++++++++++++++-------------- src/device.go | 33 +++++++++------ src/main.go | 1 + src/receive.go | 53 +++++++++++++++-------- 5 files changed, 164 insertions(+), 97 deletions(-) diff --git a/src/conn.go b/src/conn.go index 60cd789..61be3bf 100644 --- a/src/conn.go +++ b/src/conn.go @@ -3,7 +3,6 @@ package main import ( "errors" "net" - "time" ) func parseEndpoint(s string) (*net.UDPAddr, error) { @@ -27,63 +26,96 @@ func parseEndpoint(s string) (*net.UDPAddr, error) { return addr, err } -func updateUDPConn(device *Device) error { +func ListenerClose(l *Listener) (err error) { + if l.active { + err = CloseIPv4Socket(l.sock) + l.active = false + } + return +} + +func (l *Listener) Init() { + l.update = make(chan struct{}, 1) + ListenerClose(l) +} + +func ListeningUpdate(device *Device) error { netc := &device.net netc.mutex.Lock() defer netc.mutex.Unlock() - // close existing connection + // close existing sockets - if netc.conn != nil { - netc.conn.Close() - netc.conn = nil - - // We need for that fd to be closed in all other go routines, which - // means we have to wait. TODO: find less horrible way of doing this. - time.Sleep(time.Second / 2) + if err := ListenerClose(&netc.ipv4); err != nil { + return err } - // open new connection + if err := ListenerClose(&netc.ipv6); err != nil { + return err + } + + // open new sockets if device.tun.isUp.Get() { - // listen on new address + // listen on IPv4 - conn, err := net.ListenUDP("udp", netc.addr) - if err != nil { - return err + { + list := &netc.ipv6 + sock, port, err := CreateIPv4Socket(netc.port) + if err != nil { + return err + } + netc.port = port + list.sock = sock + list.active = true + + if err := SetMark(list.sock, netc.fwmark); err != nil { + ListenerClose(list) + return err + } + signalSend(list.update) } - // set fwmark + // listen on IPv6 - err = SetMark(netc.conn, netc.fwmark) - if err != nil { - return err + { + list := &netc.ipv6 + sock, port, err := CreateIPv6Socket(netc.port) + if err != nil { + return err + } + netc.port = port + list.sock = sock + list.active = true + + if err := SetMark(list.sock, netc.fwmark); err != nil { + ListenerClose(list) + return err + } + signalSend(list.update) } - // retrieve port (may have been chosen by kernel) - - addr := conn.LocalAddr() - netc.conn = conn - netc.addr, _ = net.ResolveUDPAddr( - addr.Network(), - addr.String(), - ) - - // notify goroutines - - signalSend(device.signal.newUDPConn) + // TODO: clear endpoint caches } return nil } -func closeUDPConn(device *Device) { +func ListeningClose(device *Device) error { netc := &device.net netc.mutex.Lock() - if netc.conn != nil { - netc.conn.Close() + defer netc.mutex.Unlock() + + if err := ListenerClose(&netc.ipv4); err != nil { + return err } - netc.mutex.Unlock() - signalSend(device.signal.newUDPConn) + signalSend(netc.ipv4.update) + + if err := ListenerClose(&netc.ipv6); err != nil { + return err + } + signalSend(netc.ipv6.update) + + return nil } diff --git a/src/conn_linux.go b/src/conn_linux.go index 64447a5..034fb8b 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -28,6 +28,7 @@ import "fmt" type Endpoint struct { // source (selected based on dst type) // (could use RawSockaddrAny and unsafe) + // TODO: Merge src6 unix.RawSockaddrInet6 src4 unix.RawSockaddrInet4 src4if int32 @@ -35,8 +36,14 @@ type Endpoint struct { dst unix.RawSockaddrAny } -type IPv4Socket int -type IPv6Socket int +type Socket int + +/* Returns a byte representation of the source field(s) + * for use in "under load" cookie computations. + */ +func (endpoint *Endpoint) Source() []byte { + return nil +} func zoneToUint32(zone string) (uint32, error) { if zone == "" { @@ -49,7 +56,7 @@ func zoneToUint32(zone string) (uint32, error) { return uint32(n), err } -func CreateIPv4Socket(port int) (IPv4Socket, error) { +func CreateIPv4Socket(port uint16) (Socket, uint16, error) { // create socket @@ -60,13 +67,16 @@ func CreateIPv4Socket(port int) (IPv4Socket, error) { ) if err != nil { - return -1, err + return -1, 0, err + } + + addr := unix.SockaddrInet4{ + Port: int(port), } // set sockopts and bind if err := func() error { - if err := unix.SetsockoptInt( fd, unix.SOL_SOCKET, @@ -85,19 +95,23 @@ func CreateIPv4Socket(port int) (IPv4Socket, error) { return err } - addr := unix.SockaddrInet4{ - Port: port, - } return unix.Bind(fd, &addr) - }(); err != nil { unix.Close(fd) } - return IPv4Socket(fd), err + return Socket(fd), uint16(addr.Port), err } -func CreateIPv6Socket(port int) (IPv6Socket, error) { +func CloseIPv4Socket(sock Socket) error { + return unix.Close(int(sock)) +} + +func CloseIPv6Socket(sock Socket) error { + return unix.Close(int(sock)) +} + +func CreateIPv6Socket(port uint16) (Socket, uint16, error) { // create socket @@ -108,11 +122,15 @@ func CreateIPv6Socket(port int) (IPv6Socket, error) { ) if err != nil { - return -1, err + return -1, 0, err } // set sockopts and bind + addr := unix.SockaddrInet6{ + Port: int(port), + } + if err := func() error { if err := unix.SetsockoptInt( @@ -142,16 +160,13 @@ func CreateIPv6Socket(port int) (IPv6Socket, error) { return err } - addr := unix.SockaddrInet6{ - Port: port, - } return unix.Bind(fd, &addr) }(); err != nil { unix.Close(fd) } - return IPv6Socket(fd), err + return Socket(fd), uint16(addr.Port), err } func (end *Endpoint) ClearSrc() { @@ -311,7 +326,7 @@ func (end *Endpoint) Send(c *net.UDPConn, buff []byte) error { return errors.New("Unknown address family of source") } -func (end *Endpoint) ReceiveIPv4(sock IPv4Socket, buff []byte) (int, error) { +func (end *Endpoint) ReceiveIPv4(sock Socket, buff []byte) (int, error) { // contruct message header @@ -360,7 +375,7 @@ func (end *Endpoint) ReceiveIPv4(sock IPv4Socket, buff []byte) (int, error) { return int(size), nil } -func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error { +func (end *Endpoint) ReceiveIPv6(sock Socket, buff []byte) (int, error) { // contruct message header @@ -383,7 +398,7 @@ func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error { // recvmsg(sock, &mskhdr, 0) - _, _, errno := unix.Syscall( + size, _, errno := unix.Syscall( unix.SYS_RECVMSG, uintptr(sock), uintptr(unsafe.Pointer(&msg)), @@ -391,7 +406,7 @@ func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error { ) if errno != 0 { - return errno + return 0, errno } // update source cache @@ -403,21 +418,12 @@ func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error { end.src6.Scope_id = cmsg.pktinfo.Ifindex } - return nil + return int(size), nil } -func SetMark(conn *net.UDPConn, value uint32) error { - if conn == nil { - return nil - } - - file, err := conn.File() - if err != nil { - return err - } - +func SetMark(sock Socket, value uint32) error { return unix.SetsockoptInt( - int(file.Fd()), + int(sock), unix.SOL_SOCKET, unix.SO_MARK, int(value), diff --git a/src/device.go b/src/device.go index 61c87bc..509e6a7 100644 --- a/src/device.go +++ b/src/device.go @@ -1,13 +1,18 @@ package main import ( - "net" "runtime" "sync" "sync/atomic" "time" ) +type Listener struct { + sock Socket + active bool + update chan struct{} +} + type Device struct { log *Logger // collection of loggers for levels idCounter uint // for assigning debug ids to peers @@ -22,8 +27,9 @@ type Device struct { } net struct { mutex sync.RWMutex - addr *net.UDPAddr // UDP source address - conn *net.UDPConn // UDP "connection" + ipv4 Listener + ipv6 Listener + port uint16 fwmark uint32 } mutex sync.RWMutex @@ -37,8 +43,9 @@ type Device struct { handshake chan QueueHandshakeElement } signal struct { - stop chan struct{} // halts all go routines - newUDPConn chan struct{} // a net.conn was set (consumed by the receiver routine) + stop chan struct{} // halts all go routines + updateIPv4Socket chan struct{} // a net.conn was set (consumed by the receiver routine) + updateIPv6Socket chan struct{} // a net.conn was set (consumed by the receiver routine) } underLoadUntil atomic.Value ratelimiter Ratelimiter @@ -137,12 +144,16 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { device.log = NewLogger(logLevel, "("+tun.Name()+") ") device.peers = make(map[NoisePublicKey]*Peer) device.tun.device = tun + device.indices.Init() + device.net.ipv4.Init() + device.net.ipv6.Init() device.ratelimiter.Init() + device.routingTable.Reset() device.underLoadUntil.Store(time.Time{}) - // setup pools + // setup buffer pool device.pool.messageBuffers = sync.Pool{ New: func() interface{} { @@ -159,7 +170,6 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { // prepare signals device.signal.stop = make(chan struct{}) - device.signal.newUDPConn = make(chan struct{}, 1) // start workers @@ -168,12 +178,11 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { go device.RoutineDecryption() go device.RoutineHandshake() } - + go device.RoutineReadFromTUN() go device.RoutineTUNEventReader() go device.ratelimiter.RoutineGarbageCollector(device.signal.stop) - go device.RoutineReadFromTUN() - go device.RoutineReceiveIncomming() - + go device.RoutineReceiveIncomming(&device.net.ipv4) + go device.RoutineReceiveIncomming(&device.net.ipv6) return device } @@ -204,7 +213,7 @@ func (device *Device) RemoveAllPeers() { func (device *Device) Close() { device.RemoveAllPeers() close(device.signal.stop) - closeUDPConn(device) + ListeningClose(device) } func (device *Device) WaitChannel() chan struct{} { diff --git a/src/main.go b/src/main.go index 196a4c6..a05dbba 100644 --- a/src/main.go +++ b/src/main.go @@ -14,6 +14,7 @@ func printUsage() { } func main() { + test() // parse arguments diff --git a/src/receive.go b/src/receive.go index 52c2718..60c0f2c 100644 --- a/src/receive.go +++ b/src/receive.go @@ -13,10 +13,10 @@ import ( ) type QueueHandshakeElement struct { - msgType uint32 - packet []byte - buffer *[MaxMessageSize]byte - source *net.UDPAddr + msgType uint32 + packet []byte + endpoint Endpoint + buffer *[MaxMessageSize]byte } type QueueInboundElement struct { @@ -92,11 +92,22 @@ func (device *Device) addToHandshakeQueue( } } -func (device *Device) RoutineReceiveIncomming() { +func (device *Device) RoutineReceiveIncomming(IPVersion int) { logDebug := device.log.Debug logDebug.Println("Routine, receive incomming, started") + var listener *Listener + + switch IPVersion { + case ipv4.Version: + listener = &device.net.ipv4 + case ipv6.Version: + listener = &device.net.ipv6 + default: + return + } + for { // wait for new conn @@ -107,14 +118,15 @@ func (device *Device) RoutineReceiveIncomming() { case <-device.signal.stop: return - case <-device.signal.newUDPConn: + case <-listener.update: - // fetch connection + // fetch new socket device.net.mutex.RLock() - conn := device.net.conn + sock := listener.sock + okay := listener.active device.net.mutex.RUnlock() - if conn == nil { + if !okay { continue } @@ -124,11 +136,20 @@ func (device *Device) RoutineReceiveIncomming() { buffer := device.GetMessageBuffer() + var size int + var err error + for { // read next datagram - size, raddr, err := conn.ReadFromUDP(buffer[:]) + var endpoint Endpoint + + if IPVersion == ipv6.Version { + size, err = endpoint.ReceiveIPv4(sock, buffer[:]) + } else { + size, err = endpoint.ReceiveIPv6(sock, buffer[:]) + } if err != nil { break @@ -192,7 +213,7 @@ func (device *Device) RoutineReceiveIncomming() { buffer = device.GetMessageBuffer() continue - // otherwise it is a handshake related packet + // otherwise it is a fixed size & handshake related packet case MessageInitiationType: okay = len(packet) == MessageInitiationSize @@ -208,10 +229,10 @@ func (device *Device) RoutineReceiveIncomming() { device.addToHandshakeQueue( device.queue.handshake, QueueHandshakeElement{ - msgType: msgType, - buffer: buffer, - packet: packet, - source: raddr, + msgType: msgType, + buffer: buffer, + packet: packet, + endpoint: endpoint, }, ) buffer = device.GetMessageBuffer() @@ -293,8 +314,6 @@ func (device *Device) RoutineHandshake() { // unmarshal packet - logDebug.Println("Process cookie reply from:", elem.source.String()) - var reply MessageCookieReply reader := bytes.NewReader(elem.packet) err := binary.Read(reader, binary.LittleEndian, &reply) From a72b0f7ae5dda27d839bb317b7c01d11b215e77a Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sun, 8 Oct 2017 22:03:32 +0200 Subject: [PATCH 03/15] Added new UDPBind interface --- src/conn.go | 83 ++++---------- src/conn_linux.go | 271 ++++++++++++++++++++++++++++++---------------- src/cookie.go | 12 +- src/device.go | 22 ++-- src/peer.go | 5 +- src/receive.go | 57 +++++----- 6 files changed, 236 insertions(+), 214 deletions(-) diff --git a/src/conn.go b/src/conn.go index 61be3bf..db4020d 100644 --- a/src/conn.go +++ b/src/conn.go @@ -5,6 +5,14 @@ import ( "net" ) +type UDPBind interface { + SetMark(value uint32) error + ReceiveIPv6(buff []byte, end *Endpoint) (int, error) + ReceiveIPv4(buff []byte, end *Endpoint) (int, error) + Send(buff []byte, end *Endpoint) error + Close() error +} + func parseEndpoint(s string) (*net.UDPAddr, error) { // ensure that the host is an IP address @@ -26,19 +34,6 @@ func parseEndpoint(s string) (*net.UDPAddr, error) { return addr, err } -func ListenerClose(l *Listener) (err error) { - if l.active { - err = CloseIPv4Socket(l.sock) - l.active = false - } - return -} - -func (l *Listener) Init() { - l.update = make(chan struct{}, 1) - ListenerClose(l) -} - func ListeningUpdate(device *Device) error { netc := &device.net netc.mutex.Lock() @@ -46,11 +41,7 @@ func ListeningUpdate(device *Device) error { // close existing sockets - if err := ListenerClose(&netc.ipv4); err != nil { - return err - } - - if err := ListenerClose(&netc.ipv6); err != nil { + if err := device.net.bind.Close(); err != nil { return err } @@ -58,45 +49,22 @@ func ListeningUpdate(device *Device) error { if device.tun.isUp.Get() { - // listen on IPv4 + // bind to new port - { - list := &netc.ipv6 - sock, port, err := CreateIPv4Socket(netc.port) - if err != nil { - return err - } - netc.port = port - list.sock = sock - list.active = true - - if err := SetMark(list.sock, netc.fwmark); err != nil { - ListenerClose(list) - return err - } - signalSend(list.update) + var err error + netc.bind, netc.port, err = CreateUDPBind(netc.port) + if err != nil { + return err } - // listen on IPv6 + // set mark - { - list := &netc.ipv6 - sock, port, err := CreateIPv6Socket(netc.port) - if err != nil { - return err - } - netc.port = port - list.sock = sock - list.active = true - - if err := SetMark(list.sock, netc.fwmark); err != nil { - ListenerClose(list) - return err - } - signalSend(list.update) + err = netc.bind.SetMark(netc.fwmark) + if err != nil { + return err } - // TODO: clear endpoint caches + // TODO: clear endpoint (src) caches } return nil @@ -106,16 +74,5 @@ func ListeningClose(device *Device) error { netc := &device.net netc.mutex.Lock() defer netc.mutex.Unlock() - - if err := ListenerClose(&netc.ipv4); err != nil { - return err - } - signalSend(netc.ipv4.update) - - if err := ListenerClose(&netc.ipv6); err != nil { - return err - } - signalSend(netc.ipv6.update) - - return nil + return netc.bind.Close() } diff --git a/src/conn_linux.go b/src/conn_linux.go index 034fb8b..8942b03 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -14,35 +14,158 @@ import ( "unsafe" ) -import "fmt" - /* Supports source address caching * * Currently there is no way to achieve this within the net package: * See e.g. https://github.com/golang/go/issues/17930 - * So this code is platform dependent. - * - * It is important that the endpoint is only updated after the packet content has been authenticated! + * So this code is remains platform dependent. */ type Endpoint struct { - // source (selected based on dst type) - // (could use RawSockaddrAny and unsafe) - // TODO: Merge - src6 unix.RawSockaddrInet6 - src4 unix.RawSockaddrInet4 - src4if int32 - - dst unix.RawSockaddrAny + src unix.RawSockaddrInet6 + dst unix.RawSockaddrInet6 } -type Socket int +type IPv4Source struct { + src unix.RawSockaddrInet4 + Ifindex int32 +} -/* Returns a byte representation of the source field(s) - * for use in "under load" cookie computations. - */ -func (endpoint *Endpoint) Source() []byte { - return nil +type Bind struct { + sock4 int + sock6 int +} + +func CreateUDPBind(port uint16) (UDPBind, uint16, error) { + var err error + var bind Bind + + bind.sock6, port, err = create6(port) + if err != nil { + return nil, port, err + } + + bind.sock4, port, err = create4(port) + if err != nil { + unix.Close(bind.sock6) + } + return &bind, port, err +} + +func (bind *Bind) SetMark(value uint32) error { + err := unix.SetsockoptInt( + bind.sock6, + unix.SOL_SOCKET, + unix.SO_MARK, + int(value), + ) + + if err != nil { + return err + } + + return unix.SetsockoptInt( + bind.sock4, + unix.SOL_SOCKET, + unix.SO_MARK, + int(value), + ) +} + +func (bind *Bind) Close() error { + err1 := unix.Close(bind.sock6) + err2 := unix.Close(bind.sock4) + if err1 != nil { + return err1 + } + return err2 +} + +func (bind *Bind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) { + return receive6( + bind.sock6, + buff, + end, + ) +} + +func (bind *Bind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) { + return receive4( + bind.sock4, + buff, + end, + ) +} + +func (bind *Bind) Send(buff []byte, end *Endpoint) error { + switch end.src.Family { + case unix.AF_INET6: + return send6(bind.sock6, end, buff) + case unix.AF_INET: + return send4(bind.sock4, end, buff) + default: + return errors.New("Unknown address family of source") + } +} + +func sockaddrToString(addr unix.RawSockaddrInet6) string { + var udpAddr net.UDPAddr + + switch addr.Family { + case unix.AF_INET6: + udpAddr.Port = int(addr.Port) + udpAddr.IP = addr.Addr[:] + return udpAddr.String() + + case unix.AF_INET: + ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr)) + udpAddr.Port = int(ptr.Port) + udpAddr.IP = net.IPv4( + ptr.Addr[0], + ptr.Addr[1], + ptr.Addr[2], + ptr.Addr[3], + ) + return udpAddr.String() + + default: + return "" + } +} + +func (end *Endpoint) DestinationIP() net.IP { + switch end.dst.Family { + case unix.AF_INET6: + return end.dst.Addr[:] + case unix.AF_INET: + ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) + return net.IPv4( + ptr.Addr[0], + ptr.Addr[1], + ptr.Addr[2], + ptr.Addr[3], + ) + default: + return nil + } +} + +func (end *Endpoint) SourceToBytes() []byte { + ptr := unsafe.Pointer(&end.src) + arr := (*[unix.SizeofSockaddrInet6]byte)(ptr) + return arr[:] +} + +func (end *Endpoint) SourceToString() string { + return sockaddrToString(end.src) +} + +func (end *Endpoint) DestinationToString() string { + return sockaddrToString(end.dst) +} + +func (end *Endpoint) ClearSrc() { + end.src = unix.RawSockaddrInet6{} } func zoneToUint32(zone string) (uint32, error) { @@ -56,7 +179,7 @@ func zoneToUint32(zone string) (uint32, error) { return uint32(n), err } -func CreateIPv4Socket(port uint16) (Socket, uint16, error) { +func create4(port uint16) (int, uint16, error) { // create socket @@ -100,18 +223,10 @@ func CreateIPv4Socket(port uint16) (Socket, uint16, error) { unix.Close(fd) } - return Socket(fd), uint16(addr.Port), err + return fd, uint16(addr.Port), err } -func CloseIPv4Socket(sock Socket) error { - return unix.Close(int(sock)) -} - -func CloseIPv6Socket(sock Socket) error { - return unix.Close(int(sock)) -} - -func CreateIPv6Socket(port uint16) (Socket, uint16, error) { +func create6(port uint16) (int, uint16, error) { // create socket @@ -166,13 +281,7 @@ func CreateIPv6Socket(port uint16) (Socket, uint16, error) { unix.Close(fd) } - return Socket(fd), uint16(addr.Port), err -} - -func (end *Endpoint) ClearSrc() { - end.src4if = 0 - end.src4 = unix.RawSockaddrInet4{} - end.src6 = unix.RawSockaddrInet6{} + return fd, uint16(addr.Port), err } func (end *Endpoint) Set(s string) error { @@ -187,23 +296,23 @@ func (end *Endpoint) Set(s string) error { if err != nil { return err } - ptr := (*unix.RawSockaddrInet6)(unsafe.Pointer(&end.dst)) - ptr.Family = unix.AF_INET6 - ptr.Port = uint16(addr.Port) - ptr.Flowinfo = 0 - ptr.Scope_id = zone - copy(ptr.Addr[:], ipv6[:]) + dst := &end.dst + dst.Family = unix.AF_INET6 + dst.Port = uint16(addr.Port) + dst.Flowinfo = 0 + dst.Scope_id = zone + copy(dst.Addr[:], ipv6[:]) end.ClearSrc() return nil } ipv4 := addr.IP.To4() if ipv4 != nil { - ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) - ptr.Family = unix.AF_INET - ptr.Port = uint16(addr.Port) - ptr.Zero = [8]byte{} - copy(ptr.Addr[:], ipv4) + dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) + dst.Family = unix.AF_INET + dst.Port = uint16(addr.Port) + dst.Zero = [8]byte{} + copy(dst.Addr[:], ipv4) end.ClearSrc() return nil } @@ -211,7 +320,7 @@ func (end *Endpoint) Set(s string) error { return errors.New("Failed to recognize IP address format") } -func send6(sock uintptr, end *Endpoint, buff []byte) error { +func send6(sock int, end *Endpoint, buff []byte) error { // construct message header @@ -229,8 +338,8 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error { Len: unix.SizeofInet6Pktinfo, }, unix.Inet6Pktinfo{ - Addr: end.src6.Addr, - Ifindex: end.src6.Scope_id, + Addr: end.src.Addr, + Ifindex: end.src.Scope_id, }, } @@ -248,7 +357,7 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error { _, _, errno := unix.Syscall( unix.SYS_SENDMSG, - sock, + uintptr(sock), uintptr(unsafe.Pointer(&msghdr)), 0, ) @@ -258,7 +367,7 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error { return errno } -func send4(sock uintptr, end *Endpoint, buff []byte) error { +func send4(sock int, end *Endpoint, buff []byte) error { // construct message header @@ -266,6 +375,8 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error { iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) iovec.SetLen(len(buff)) + src4 := (*IPv4Source)(unsafe.Pointer(&end.src)) + cmsg := struct { cmsghdr unix.Cmsghdr pktinfo unix.Inet4Pktinfo @@ -276,8 +387,8 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error { Len: unix.SizeofInet4Pktinfo, }, unix.Inet4Pktinfo{ - Spec_dst: end.src4.Addr, - Ifindex: end.src4if, + Spec_dst: src4.src.Addr, + Ifindex: src4.Ifindex, }, } @@ -295,7 +406,7 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error { _, _, errno := unix.Syscall( unix.SYS_SENDMSG, - sock, + uintptr(sock), uintptr(unsafe.Pointer(&msghdr)), 0, ) @@ -305,28 +416,7 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error { return errno } -func (end *Endpoint) Send(c *net.UDPConn, buff []byte) error { - - // extract underlying file descriptor - - file, err := c.File() - if err != nil { - return err - } - sock := file.Fd() - - // send depending on address family of dst - - family := *((*uint16)(unsafe.Pointer(&end.dst))) - if family == unix.AF_INET { - return send4(sock, end, buff) - } else if family == unix.AF_INET6 { - return send6(sock, end, buff) - } - return errors.New("Unknown address family of source") -} - -func (end *Endpoint) ReceiveIPv4(sock Socket, buff []byte) (int, error) { +func receive4(sock int, buff []byte, end *Endpoint) (int, error) { // contruct message header @@ -360,22 +450,21 @@ func (end *Endpoint) ReceiveIPv4(sock Socket, buff []byte) (int, error) { return 0, errno } - fmt.Println(msghdr) - fmt.Println(cmsg) - // update source cache if cmsg.cmsghdr.Level == unix.IPPROTO_IP && cmsg.cmsghdr.Type == unix.IP_PKTINFO && cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { - end.src4.Addr = cmsg.pktinfo.Spec_dst - end.src4if = cmsg.pktinfo.Ifindex + src4 := (*IPv4Source)(unsafe.Pointer(&end.src)) + src4.src.Family = unix.AF_INET + src4.src.Addr = cmsg.pktinfo.Spec_dst + src4.Ifindex = cmsg.pktinfo.Ifindex } return int(size), nil } -func (end *Endpoint) ReceiveIPv6(sock Socket, buff []byte) (int, error) { +func receive6(sock int, buff []byte, end *Endpoint) (int, error) { // contruct message header @@ -414,18 +503,10 @@ func (end *Endpoint) ReceiveIPv6(sock Socket, buff []byte) (int, error) { if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { - end.src6.Addr = cmsg.pktinfo.Addr - end.src6.Scope_id = cmsg.pktinfo.Ifindex + end.src.Family = unix.AF_INET6 + end.src.Addr = cmsg.pktinfo.Addr + end.src.Scope_id = cmsg.pktinfo.Ifindex } return int(size), nil } - -func SetMark(sock Socket, value uint32) error { - return unix.SetsockoptInt( - int(sock), - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) -} diff --git a/src/cookie.go b/src/cookie.go index a81819b..a13ad49 100644 --- a/src/cookie.go +++ b/src/cookie.go @@ -5,10 +5,8 @@ import ( "crypto/rand" "golang.org/x/crypto/blake2s" "golang.org/x/crypto/chacha20poly1305" - "net" "sync" "time" - "unsafe" ) type CookieChecker struct { @@ -76,7 +74,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool { return hmac.Equal(mac1[:], msg[smac1:smac2]) } -func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool { +func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool { st.mutex.RLock() defer st.mutex.RUnlock() @@ -89,8 +87,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool { var cookie [blake2s.Size128]byte func() { mac, _ := blake2s.New128(st.mac2.secret[:]) - mac.Write(src.IP) - mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:]) + mac.Write(src) mac.Sum(cookie[:0]) }() @@ -111,7 +108,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool { func (st *CookieChecker) CreateReply( msg []byte, recv uint32, - src *net.UDPAddr, + src []byte, ) (*MessageCookieReply, error) { st.mutex.RLock() @@ -136,8 +133,7 @@ func (st *CookieChecker) CreateReply( var cookie [blake2s.Size128]byte func() { mac, _ := blake2s.New128(st.mac2.secret[:]) - mac.Write(src.IP) - mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:]) + mac.Write(src) mac.Sum(cookie[:0]) }() diff --git a/src/device.go b/src/device.go index 509e6a7..d1e0685 100644 --- a/src/device.go +++ b/src/device.go @@ -1,18 +1,14 @@ package main import ( + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" "runtime" "sync" "sync/atomic" "time" ) -type Listener struct { - sock Socket - active bool - update chan struct{} -} - type Device struct { log *Logger // collection of loggers for levels idCounter uint // for assigning debug ids to peers @@ -27,8 +23,7 @@ type Device struct { } net struct { mutex sync.RWMutex - ipv4 Listener - ipv6 Listener + bind UDPBind port uint16 fwmark uint32 } @@ -43,9 +38,8 @@ type Device struct { handshake chan QueueHandshakeElement } signal struct { - stop chan struct{} // halts all go routines - updateIPv4Socket chan struct{} // a net.conn was set (consumed by the receiver routine) - updateIPv6Socket chan struct{} // a net.conn was set (consumed by the receiver routine) + stop chan struct{} + updateBind chan struct{} } underLoadUntil atomic.Value ratelimiter Ratelimiter @@ -146,8 +140,6 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { device.tun.device = tun device.indices.Init() - device.net.ipv4.Init() - device.net.ipv6.Init() device.ratelimiter.Init() device.routingTable.Reset() @@ -181,8 +173,8 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { go device.RoutineReadFromTUN() go device.RoutineTUNEventReader() go device.ratelimiter.RoutineGarbageCollector(device.signal.stop) - go device.RoutineReceiveIncomming(&device.net.ipv4) - go device.RoutineReceiveIncomming(&device.net.ipv6) + go device.RoutineReceiveIncomming(ipv4.Version) + go device.RoutineReceiveIncomming(ipv6.Version) return device } diff --git a/src/peer.go b/src/peer.go index 6fea829..791c091 100644 --- a/src/peer.go +++ b/src/peer.go @@ -4,7 +4,6 @@ import ( "encoding/base64" "errors" "fmt" - "net" "sync" "time" ) @@ -15,8 +14,8 @@ type Peer struct { persistentKeepaliveInterval uint64 keyPairs KeyPairs handshake Handshake + endpoint Endpoint device *Device - endpoint *net.UDPAddr stats struct { txBytes uint64 // bytes send to peer (endpoint) rxBytes uint64 // bytes received from peer @@ -134,7 +133,7 @@ func (peer *Peer) String() string { return fmt.Sprintf( "peer(%d %s %s)", peer.id, - peer.endpoint.String(), + peer.endpoint.DestinationToString(), base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), ) } diff --git a/src/receive.go b/src/receive.go index 60c0f2c..664f1ba 100644 --- a/src/receive.go +++ b/src/receive.go @@ -97,17 +97,6 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) { logDebug := device.log.Debug logDebug.Println("Routine, receive incomming, started") - var listener *Listener - - switch IPVersion { - case ipv4.Version: - listener = &device.net.ipv4 - case ipv6.Version: - listener = &device.net.ipv6 - default: - return - } - for { // wait for new conn @@ -118,15 +107,14 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) { case <-device.signal.stop: return - case <-listener.update: + case <-device.signal.updateBind: // fetch new socket device.net.mutex.RLock() - sock := listener.sock - okay := listener.active + bind := device.net.bind device.net.mutex.RUnlock() - if !okay { + if bind == nil { continue } @@ -145,10 +133,13 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) { var endpoint Endpoint - if IPVersion == ipv6.Version { - size, err = endpoint.ReceiveIPv4(sock, buffer[:]) - } else { - size, err = endpoint.ReceiveIPv6(sock, buffer[:]) + switch IPVersion { + case ipv4.Version: + size, err = bind.ReceiveIPv4(buffer[:], &endpoint) + case ipv6.Version: + size, err = bind.ReceiveIPv6(buffer[:], &endpoint) + default: + return } if err != nil { @@ -340,15 +331,19 @@ func (device *Device) RoutineHandshake() { return } + srcBytes := elem.endpoint.SourceToBytes() if device.IsUnderLoad() { - if !device.mac.CheckMAC2(elem.packet, elem.source) { + + // verify MAC2 field + + if !device.mac.CheckMAC2(elem.packet, srcBytes) { // construct cookie reply - logDebug.Println("Sending cookie reply to:", elem.source.String()) + logDebug.Println("Sending cookie reply to:", elem.endpoint.SourceToString()) sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type" - reply, err := device.mac.CreateReply(elem.packet, sender, elem.source) + reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes) if err != nil { logError.Println("Failed to create cookie reply:", err) return @@ -358,9 +353,9 @@ func (device *Device) RoutineHandshake() { writer := bytes.NewBuffer(temp[:0]) binary.Write(writer, binary.LittleEndian, reply) - _, err = device.net.conn.WriteToUDP( + device.net.bind.Send( writer.Bytes(), - elem.source, + &elem.endpoint, ) if err != nil { logDebug.Println("Failed to send cookie reply:", err) @@ -368,7 +363,11 @@ func (device *Device) RoutineHandshake() { continue } - if !device.ratelimiter.Allow(elem.source.IP) { + // check ratelimiter + + if !device.ratelimiter.Allow( + elem.endpoint.DestinationIP(), + ) { continue } } @@ -399,8 +398,7 @@ func (device *Device) RoutineHandshake() { if peer == nil { logInfo.Println( "Recieved invalid initiation message from", - elem.source.IP.String(), - elem.source.Port, + elem.endpoint.DestinationToString(), ) continue } @@ -414,7 +412,7 @@ func (device *Device) RoutineHandshake() { // TODO: Discover destination address also, only update on change peer.mutex.Lock() - peer.endpoint = elem.source + peer.endpoint = elem.endpoint peer.mutex.Unlock() // create response @@ -460,8 +458,7 @@ func (device *Device) RoutineHandshake() { if peer == nil { logInfo.Println( "Recieved invalid response message from", - elem.source.IP.String(), - elem.source.Port, + elem.endpoint.DestinationToString(), ) continue } From e86d03dca23e5adcbd1c7bd30157bc7d19a932d7 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Mon, 16 Oct 2017 21:33:47 +0200 Subject: [PATCH 04/15] Initial implementation of source caching Yet untested. --- src/conn.go | 21 ++++++++++++---- src/conn_linux.go | 12 ++++++--- src/device.go | 2 +- src/main.go | 2 -- src/peer.go | 24 +++++++++++++++--- src/receive.go | 18 ++++++-------- src/send.go | 19 ++++---------- src/timers.go | 2 +- src/tun.go | 4 +-- src/uapi.go | 63 +++++++++++++++++------------------------------ 10 files changed, 84 insertions(+), 83 deletions(-) diff --git a/src/conn.go b/src/conn.go index db4020d..012e24e 100644 --- a/src/conn.go +++ b/src/conn.go @@ -34,15 +34,20 @@ func parseEndpoint(s string) (*net.UDPAddr, error) { return addr, err } -func ListeningUpdate(device *Device) error { +func UpdateUDPListener(device *Device) error { + device.mutex.Lock() + defer device.mutex.Unlock() + netc := &device.net netc.mutex.Lock() defer netc.mutex.Unlock() // close existing sockets - if err := device.net.bind.Close(); err != nil { - return err + if netc.bind != nil { + if err := netc.bind.Close(); err != nil { + return err + } } // open new sockets @@ -64,13 +69,19 @@ func ListeningUpdate(device *Device) error { return err } - // TODO: clear endpoint (src) caches + // clear cached source addresses + + for _, peer := range device.peers { + peer.mutex.Lock() + peer.endpoint.value.ClearSrc() + peer.mutex.Unlock() + } } return nil } -func ListeningClose(device *Device) error { +func CloseUDPListener(device *Device) error { netc := &device.net netc.mutex.Lock() defer netc.mutex.Unlock() diff --git a/src/conn_linux.go b/src/conn_linux.go index 8942b03..4a5a3f0 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -133,7 +133,7 @@ func sockaddrToString(addr unix.RawSockaddrInet6) string { } } -func (end *Endpoint) DestinationIP() net.IP { +func (end *Endpoint) DstIP() net.IP { switch end.dst.Family { case unix.AF_INET6: return end.dst.Addr[:] @@ -150,20 +150,24 @@ func (end *Endpoint) DestinationIP() net.IP { } } -func (end *Endpoint) SourceToBytes() []byte { +func (end *Endpoint) SrcToBytes() []byte { ptr := unsafe.Pointer(&end.src) arr := (*[unix.SizeofSockaddrInet6]byte)(ptr) return arr[:] } -func (end *Endpoint) SourceToString() string { +func (end *Endpoint) SrcToString() string { return sockaddrToString(end.src) } -func (end *Endpoint) DestinationToString() string { +func (end *Endpoint) DstToString() string { return sockaddrToString(end.dst) } +func (end *Endpoint) ClearDst() { + end.dst = unix.RawSockaddrInet6{} +} + func (end *Endpoint) ClearSrc() { end.src = unix.RawSockaddrInet6{} } diff --git a/src/device.go b/src/device.go index d1e0685..1aae448 100644 --- a/src/device.go +++ b/src/device.go @@ -205,7 +205,7 @@ func (device *Device) RemoveAllPeers() { func (device *Device) Close() { device.RemoveAllPeers() close(device.signal.stop) - ListeningClose(device) + CloseUDPListener(device) } func (device *Device) WaitChannel() chan struct{} { diff --git a/src/main.go b/src/main.go index a05dbba..5aaed9b 100644 --- a/src/main.go +++ b/src/main.go @@ -14,8 +14,6 @@ func printUsage() { } func main() { - test() - // parse arguments var foreground bool diff --git a/src/peer.go b/src/peer.go index 791c091..f24dcd8 100644 --- a/src/peer.go +++ b/src/peer.go @@ -14,9 +14,12 @@ type Peer struct { persistentKeepaliveInterval uint64 keyPairs KeyPairs handshake Handshake - endpoint Endpoint device *Device - stats struct { + endpoint struct { + set bool // has a known endpoint been discovered + value Endpoint // source / destination cache + } + stats struct { txBytes uint64 // bytes send to peer (endpoint) rxBytes uint64 // bytes received from peer lastHandshakeNano int64 // nano seconds since epoch @@ -105,6 +108,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic) handshake.mutex.Unlock() + // reset endpoint + + peer.endpoint.set = false + peer.endpoint.value.ClearDst() + peer.endpoint.value.ClearSrc() + // prepare queuing peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize) @@ -129,11 +138,20 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { return peer, nil } +/* Returns a short string identification for logging + */ func (peer *Peer) String() string { + if !peer.endpoint.set { + return fmt.Sprintf( + "peer(%d unknown %s)", + peer.id, + base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), + ) + } return fmt.Sprintf( "peer(%d %s %s)", peer.id, - peer.endpoint.DestinationToString(), + peer.endpoint.value.DstToString(), base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), ) } diff --git a/src/receive.go b/src/receive.go index 664f1ba..1f05b2f 100644 --- a/src/receive.go +++ b/src/receive.go @@ -331,7 +331,7 @@ func (device *Device) RoutineHandshake() { return } - srcBytes := elem.endpoint.SourceToBytes() + srcBytes := elem.endpoint.SrcToBytes() if device.IsUnderLoad() { // verify MAC2 field @@ -340,8 +340,7 @@ func (device *Device) RoutineHandshake() { // construct cookie reply - logDebug.Println("Sending cookie reply to:", elem.endpoint.SourceToString()) - + logDebug.Println("Sending cookie reply to:", elem.endpoint.SrcToString()) sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type" reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes) if err != nil { @@ -365,9 +364,7 @@ func (device *Device) RoutineHandshake() { // check ratelimiter - if !device.ratelimiter.Allow( - elem.endpoint.DestinationIP(), - ) { + if !device.ratelimiter.Allow(elem.endpoint.DstIP()) { continue } } @@ -398,7 +395,7 @@ func (device *Device) RoutineHandshake() { if peer == nil { logInfo.Println( "Recieved invalid initiation message from", - elem.endpoint.DestinationToString(), + elem.endpoint.DstToString(), ) continue } @@ -412,7 +409,8 @@ func (device *Device) RoutineHandshake() { // TODO: Discover destination address also, only update on change peer.mutex.Lock() - peer.endpoint = elem.endpoint + peer.endpoint.set = true + peer.endpoint.value = elem.endpoint peer.mutex.Unlock() // create response @@ -435,7 +433,7 @@ func (device *Device) RoutineHandshake() { // send response - _, err = peer.SendBuffer(packet) + err = peer.SendBuffer(packet) if err == nil { peer.TimerAnyAuthenticatedPacketTraversal() } @@ -458,7 +456,7 @@ func (device *Device) RoutineHandshake() { if peer == nil { logInfo.Println( "Recieved invalid response message from", - elem.endpoint.DestinationToString(), + elem.endpoint.DstToString(), ) continue } diff --git a/src/send.go b/src/send.go index 5c88ead..e37a736 100644 --- a/src/send.go +++ b/src/send.go @@ -105,24 +105,15 @@ func addToEncryptionQueue( } } -func (peer *Peer) SendBuffer(buffer []byte) (int, error) { +func (peer *Peer) SendBuffer(buffer []byte) error { peer.device.net.mutex.RLock() defer peer.device.net.mutex.RUnlock() - peer.mutex.RLock() defer peer.mutex.RUnlock() - - endpoint := peer.endpoint - if endpoint == nil { - return 0, errors.New("No known endpoint for peer") + if !peer.endpoint.set { + return errors.New("No known endpoint for peer") } - - conn := peer.device.net.conn - if conn == nil { - return 0, errors.New("No UDP socket for device") - } - - return conn.WriteToUDP(buffer, endpoint) + return peer.device.net.bind.Send(buffer, &peer.endpoint.value) } /* Reads packets from the TUN and inserts @@ -343,7 +334,7 @@ func (peer *Peer) RoutineSequentialSender() { // send message and return buffer to pool length := uint64(len(elem.packet)) - _, err := peer.SendBuffer(elem.packet) + err := peer.SendBuffer(elem.packet) device.PutMessageBuffer(elem.buffer) if err != nil { logDebug.Println("Failed to send authenticated packet to peer", peer.String()) diff --git a/src/timers.go b/src/timers.go index 99695ba..2a94005 100644 --- a/src/timers.go +++ b/src/timers.go @@ -288,7 +288,7 @@ func (peer *Peer) RoutineHandshakeInitiator() { packet := writer.Bytes() peer.mac.AddMacs(packet) - _, err = peer.SendBuffer(packet) + err = peer.SendBuffer(packet) if err != nil { logError.Println( "Failed to send handshake initiation message to", diff --git a/src/tun.go b/src/tun.go index 8e8c759..9eed987 100644 --- a/src/tun.go +++ b/src/tun.go @@ -47,7 +47,7 @@ func (device *Device) RoutineTUNEventReader() { if !device.tun.isUp.Get() { logInfo.Println("Interface set up") device.tun.isUp.Set(true) - updateUDPConn(device) + UpdateUDPListener(device) } } @@ -55,7 +55,7 @@ func (device *Device) RoutineTUNEventReader() { if device.tun.isUp.Get() { logInfo.Println("Interface set down") device.tun.isUp.Set(false) - closeUDPConn(device) + CloseUDPListener(device) } } } diff --git a/src/uapi.go b/src/uapi.go index 7d08e56..2de26ee 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -39,9 +39,10 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { send("private_key=" + device.privateKey.ToHex()) } - if device.net.addr != nil { - send(fmt.Sprintf("listen_port=%d", device.net.addr.Port)) + if device.net.port != 0 { + send(fmt.Sprintf("listen_port=%d", device.net.port)) } + if device.net.fwmark != 0 { send(fmt.Sprintf("fwmark=%d", device.net.fwmark)) } @@ -52,8 +53,8 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { defer peer.mutex.RUnlock() send("public_key=" + peer.handshake.remoteStatic.ToHex()) send("preshared_key=" + peer.handshake.presharedKey.ToHex()) - if peer.endpoint != nil { - send("endpoint=" + peer.endpoint.String()) + if peer.endpoint.set { + send("endpoint=" + peer.endpoint.value.DstToString()) } nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano) @@ -137,53 +138,24 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { logError.Println("Failed to set listen_port:", err) return &IPCError{Code: ipcErrorInvalid} } - - addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port)) - if err != nil { - logError.Println("Failed to set listen_port:", err) - return &IPCError{Code: ipcErrorInvalid} - } - - device.net.mutex.Lock() - device.net.addr = addr - device.net.mutex.Unlock() - - err = updateUDPConn(device) - if err != nil { + device.net.port = uint16(port) + if err := UpdateUDPListener(device); err != nil { logError.Println("Failed to set listen_port:", err) return &IPCError{Code: ipcErrorPortInUse} } - // TODO: Clear source address of all peers - case "fwmark": fwmark, err := strconv.ParseUint(value, 10, 32) if err != nil { logError.Println("Invalid fwmark", err) return &IPCError{Code: ipcErrorInvalid} } - device.net.mutex.Lock() - if fwmark > 0 || device.net.fwmark > 0 { - device.net.fwmark = uint32(fwmark) - err := SetMark( - device.net.conn, - device.net.fwmark, - ) - if err != nil { - logError.Println("Failed to set fwmark:", err) - device.net.mutex.Unlock() - return &IPCError{Code: ipcErrorIO} - } - - // TODO: Clear source address of all peers - } + device.net.fwmark = uint32(fwmark) device.net.mutex.Unlock() case "public_key": - // switch to peer configuration - deviceConfig = false case "replace_peers": @@ -218,7 +190,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { device.mutex.RLock() if device.publicKey.Equals(pubKey) { - // create dummy instance + // create dummy instance (not added to device) peer = &Peer{} dummy = true @@ -244,6 +216,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { } case "remove": + + // remove currently selected peer from device + if value != "true" { logError.Println("Failed to set remove, invalid value:", value) return &IPCError{Code: ipcErrorInvalid} @@ -256,6 +231,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { dummy = true case "preshared_key": + + // update PSK + peer.mutex.Lock() err := peer.handshake.presharedKey.FromHex(value) peer.mutex.Unlock() @@ -265,14 +243,17 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { } case "endpoint": - addr, err := parseEndpoint(value) + + // set endpoint destination and reset handshake timer + + peer.mutex.Lock() + err := peer.endpoint.value.Set(value) + peer.endpoint.set = (err == nil) + peer.mutex.Unlock() if err != nil { logError.Println("Failed to set endpoint:", value) return &IPCError{Code: ipcErrorInvalid} } - peer.mutex.Lock() - peer.endpoint = addr - peer.mutex.Unlock() signalSend(peer.signal.handshakeReset) case "persistent_keepalive_interval": From fd6f2e1f554cb545c7c554b56e2ac77308822680 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Tue, 17 Oct 2017 16:50:23 +0200 Subject: [PATCH 05/15] Fixed timer issue when failing to send handshake + Identified send4 issue --- src/conn_linux.go | 62 ++++++++++++++++++++++++++++++----------------- src/timers.go | 27 +++++++++------------ src/uapi.go | 3 ++- 3 files changed, 54 insertions(+), 38 deletions(-) diff --git a/src/conn_linux.go b/src/conn_linux.go index 4a5a3f0..51ca4f3 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -8,6 +8,7 @@ package main import ( "errors" + "fmt" "golang.org/x/sys/unix" "net" "strconv" @@ -31,14 +32,14 @@ type IPv4Source struct { Ifindex int32 } -type Bind struct { +type NativeBind struct { sock4 int sock6 int } func CreateUDPBind(port uint16) (UDPBind, uint16, error) { var err error - var bind Bind + var bind NativeBind bind.sock6, port, err = create6(port) if err != nil { @@ -52,7 +53,7 @@ func CreateUDPBind(port uint16) (UDPBind, uint16, error) { return &bind, port, err } -func (bind *Bind) SetMark(value uint32) error { +func (bind *NativeBind) SetMark(value uint32) error { err := unix.SetsockoptInt( bind.sock6, unix.SOL_SOCKET, @@ -72,7 +73,7 @@ func (bind *Bind) SetMark(value uint32) error { ) } -func (bind *Bind) Close() error { +func (bind *NativeBind) Close() error { err1 := unix.Close(bind.sock6) err2 := unix.Close(bind.sock4) if err1 != nil { @@ -81,7 +82,7 @@ func (bind *Bind) Close() error { return err2 } -func (bind *Bind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) { +func (bind *NativeBind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) { return receive6( bind.sock6, buff, @@ -89,7 +90,7 @@ func (bind *Bind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) { ) } -func (bind *Bind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) { +func (bind *NativeBind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) { return receive4( bind.sock4, buff, @@ -97,14 +98,14 @@ func (bind *Bind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) { ) } -func (bind *Bind) Send(buff []byte, end *Endpoint) error { - switch end.src.Family { +func (bind *NativeBind) Send(buff []byte, end *Endpoint) error { + switch end.dst.Family { case unix.AF_INET6: return send6(bind.sock6, end, buff) case unix.AF_INET: return send4(bind.sock4, end, buff) default: - return errors.New("Unknown address family of source") + return errors.New("Unknown address family of destination") } } @@ -288,12 +289,25 @@ func create6(port uint16) (int, uint16, error) { return fd, uint16(addr.Port), err } -func (end *Endpoint) Set(s string) error { +func (end *Endpoint) SetDst(s string) error { addr, err := parseEndpoint(s) if err != nil { return err } + fmt.Println(addr, err) + + ipv4 := addr.IP.To4() + if ipv4 != nil { + dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) + dst.Family = unix.AF_INET + dst.Port = uint16(addr.Port) + dst.Zero = [8]byte{} + copy(dst.Addr[:], ipv4) + end.ClearSrc() + return nil + } + ipv6 := addr.IP.To16() if ipv6 != nil { zone, err := zoneToUint32(addr.Zone) @@ -310,17 +324,6 @@ func (end *Endpoint) Set(s string) error { return nil } - ipv4 := addr.IP.To4() - if ipv4 != nil { - dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) - dst.Family = unix.AF_INET - dst.Port = uint16(addr.Port) - dst.Zero = [8]byte{} - copy(dst.Addr[:], ipv4) - end.ClearSrc() - return nil - } - return errors.New("Failed to recognize IP address format") } @@ -372,6 +375,8 @@ func send6(sock int, end *Endpoint, buff []byte) error { } func send4(sock int, end *Endpoint, buff []byte) error { + println("send 4") + println(end.DstToString()) // construct message header @@ -403,7 +408,6 @@ func send4(sock int, end *Endpoint, buff []byte) error { Namelen: unix.SizeofSockaddrInet4, Control: (*byte)(unsafe.Pointer(&cmsg)), } - msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) // sendmsg(sock, &msghdr, 0) @@ -414,9 +418,23 @@ func send4(sock int, end *Endpoint, buff []byte) error { uintptr(unsafe.Pointer(&msghdr)), 0, ) + + println(sock) + fmt.Println(errno) + + // clear source cache and try again + if errno == unix.EINVAL { end.ClearSrc() + cmsg.pktinfo = unix.Inet4Pktinfo{} + _, _, errno = unix.Syscall( + unix.SYS_SENDMSG, + uintptr(sock), + uintptr(unsafe.Pointer(&msghdr)), + 0, + ) } + return errno } diff --git a/src/timers.go b/src/timers.go index 2a94005..31165a3 100644 --- a/src/timers.go +++ b/src/timers.go @@ -279,34 +279,31 @@ func (peer *Peer) RoutineHandshakeInitiator() { break AttemptHandshakes } - jitter := time.Millisecond * time.Duration(rand.Uint32()%334) - - // marshal and send + // marshal handshake message writer := bytes.NewBuffer(temp[:0]) binary.Write(writer, binary.LittleEndian, msg) packet := writer.Bytes() peer.mac.AddMacs(packet) + // send to endpoint + err = peer.SendBuffer(packet) - if err != nil { + jitter := time.Millisecond * time.Duration(rand.Uint32()%334) + timeout := time.NewTimer(RekeyTimeout + jitter) + if err == nil { + peer.TimerAnyAuthenticatedPacketTraversal() + logDebug.Println( + "Handshake initiation attempt", + attempts, "sent to", peer.String(), + ) + } else { logError.Println( "Failed to send handshake initiation message to", peer.String(), ":", err, ) - continue } - peer.TimerAnyAuthenticatedPacketTraversal() - - // set handshake timeout - - timeout := time.NewTimer(RekeyTimeout + jitter) - logDebug.Println( - "Handshake initiation attempt", - attempts, "sent to", peer.String(), - ) - // wait for handshake or timeout select { diff --git a/src/uapi.go b/src/uapi.go index 2de26ee..accffd1 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -247,7 +247,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { // set endpoint destination and reset handshake timer peer.mutex.Lock() - err := peer.endpoint.value.Set(value) + err := peer.endpoint.value.SetDst(value) + fmt.Println(peer.endpoint.value.DstToString(), err) peer.endpoint.set = (err == nil) peer.mutex.Unlock() if err != nil { From 0485c34c8e20e4f7ea19bd3c3f52d2f4717caead Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Fri, 27 Oct 2017 10:43:37 +0200 Subject: [PATCH 06/15] Fixed message header length in conn_linux --- src/conn.go | 7 +++++++ src/conn_linux.go | 43 +++++++++++++++++++++++++++++++------------ src/main.go | 5 ++++- src/peer.go | 11 +++++++++++ src/send.go | 12 ------------ src/uapi.go | 2 +- 6 files changed, 54 insertions(+), 26 deletions(-) diff --git a/src/conn.go b/src/conn.go index 012e24e..b2caffb 100644 --- a/src/conn.go +++ b/src/conn.go @@ -45,15 +45,20 @@ func UpdateUDPListener(device *Device) error { // close existing sockets if netc.bind != nil { + println("close bind") if err := netc.bind.Close(); err != nil { return err } + netc.bind = nil + println("closed") } // open new sockets if device.tun.isUp.Get() { + println("creat") + // bind to new port var err error @@ -69,6 +74,8 @@ func UpdateUDPListener(device *Device) error { return err } + println("okay") + // clear cached source addresses for _, peer := range device.peers { diff --git a/src/conn_linux.go b/src/conn_linux.go index 51ca4f3..8cda460 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -50,10 +50,12 @@ func CreateUDPBind(port uint16) (UDPBind, uint16, error) { if err != nil { unix.Close(bind.sock6) } - return &bind, port, err + println(bind.sock6) + println(bind.sock4) + return bind, port, err } -func (bind *NativeBind) SetMark(value uint32) error { +func (bind NativeBind) SetMark(value uint32) error { err := unix.SetsockoptInt( bind.sock6, unix.SOL_SOCKET, @@ -73,7 +75,7 @@ func (bind *NativeBind) SetMark(value uint32) error { ) } -func (bind *NativeBind) Close() error { +func (bind NativeBind) Close() error { err1 := unix.Close(bind.sock6) err2 := unix.Close(bind.sock4) if err1 != nil { @@ -82,7 +84,7 @@ func (bind *NativeBind) Close() error { return err2 } -func (bind *NativeBind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) { +func (bind NativeBind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) { return receive6( bind.sock6, buff, @@ -90,7 +92,7 @@ func (bind *NativeBind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) { ) } -func (bind *NativeBind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) { +func (bind NativeBind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) { return receive4( bind.sock4, buff, @@ -98,7 +100,7 @@ func (bind *NativeBind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) { ) } -func (bind *NativeBind) Send(buff []byte, end *Endpoint) error { +func (bind NativeBind) Send(buff []byte, end *Endpoint) error { switch end.dst.Family { case unix.AF_INET6: return send6(bind.sock6, end, buff) @@ -236,7 +238,7 @@ func create6(port uint16) (int, uint16, error) { // create socket fd, err := unix.Socket( - unix.AF_INET, + unix.AF_INET6, unix.SOCK_DGRAM, 0, ) @@ -342,7 +344,7 @@ func send6(sock int, end *Endpoint, buff []byte) error { unix.Cmsghdr{ Level: unix.IPPROTO_IPV6, Type: unix.IPV6_PKTINFO, - Len: unix.SizeofInet6Pktinfo, + Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr, }, unix.Inet6Pktinfo{ Addr: end.src.Addr, @@ -368,15 +370,31 @@ func send6(sock int, end *Endpoint, buff []byte) error { uintptr(unsafe.Pointer(&msghdr)), 0, ) + + if errno == 0 { + return nil + } + + // clear src and retry + if errno == unix.EINVAL { end.ClearSrc() + cmsg.pktinfo = unix.Inet6Pktinfo{} + _, _, errno = unix.Syscall( + unix.SYS_SENDMSG, + uintptr(sock), + uintptr(unsafe.Pointer(&msghdr)), + 0, + ) } + return errno } func send4(sock int, end *Endpoint, buff []byte) error { println("send 4") println(end.DstToString()) + println(sock) // construct message header @@ -393,7 +411,7 @@ func send4(sock int, end *Endpoint, buff []byte) error { unix.Cmsghdr{ Level: unix.IPPROTO_IP, Type: unix.IP_PKTINFO, - Len: unix.SizeofInet4Pktinfo, + Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, }, unix.Inet4Pktinfo{ Spec_dst: src4.src.Addr, @@ -419,10 +437,11 @@ func send4(sock int, end *Endpoint, buff []byte) error { 0, ) - println(sock) - fmt.Println(errno) + if errno == 0 { + return nil + } - // clear source cache and try again + // clear source and try again if errno == unix.EINVAL { end.ClearSrc() diff --git a/src/main.go b/src/main.go index 5aaed9b..05d56eb 100644 --- a/src/main.go +++ b/src/main.go @@ -84,7 +84,10 @@ func main() { logInfo := device.log.Info logError := device.log.Error - logInfo.Println("Starting device") + logDebug := device.log.Debug + + logInfo.Println("Device started") + logDebug.Println("Debug log enabled") // start configuration lister diff --git a/src/peer.go b/src/peer.go index f24dcd8..a98fc97 100644 --- a/src/peer.go +++ b/src/peer.go @@ -138,6 +138,17 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { return peer, nil } +func (peer *Peer) SendBuffer(buffer []byte) error { + peer.device.net.mutex.RLock() + defer peer.device.net.mutex.RUnlock() + peer.mutex.RLock() + defer peer.mutex.RUnlock() + if !peer.endpoint.set { + return errors.New("No known endpoint for peer") + } + return peer.device.net.bind.Send(buffer, &peer.endpoint.value) +} + /* Returns a short string identification for logging */ func (peer *Peer) String() string { diff --git a/src/send.go b/src/send.go index e37a736..52872f6 100644 --- a/src/send.go +++ b/src/send.go @@ -2,7 +2,6 @@ package main import ( "encoding/binary" - "errors" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" @@ -105,17 +104,6 @@ func addToEncryptionQueue( } } -func (peer *Peer) SendBuffer(buffer []byte) error { - peer.device.net.mutex.RLock() - defer peer.device.net.mutex.RUnlock() - peer.mutex.RLock() - defer peer.mutex.RUnlock() - if !peer.endpoint.set { - return errors.New("No known endpoint for peer") - } - return peer.device.net.bind.Send(buffer, &peer.endpoint.value) -} - /* Reads packets from the TUN and inserts * into nonce queue for peer * diff --git a/src/uapi.go b/src/uapi.go index accffd1..5098e3d 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -135,7 +135,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { case "listen_port": port, err := strconv.ParseUint(value, 10, 16) if err != nil { - logError.Println("Failed to set listen_port:", err) + logError.Println("Failed to parse listen_port:", err) return &IPCError{Code: ipcErrorInvalid} } device.net.port = uint16(port) From 892276aa64ca9b14d2e96186b83145ab2f5ce25a Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sat, 11 Nov 2017 15:43:55 +0100 Subject: [PATCH 07/15] Fixed port endianness --- src/conn.go | 45 ++++++---- src/conn_linux.go | 35 ++++---- src/device.go | 16 ++-- src/receive.go | 203 ++++++++++++++++++++++------------------------ 4 files changed, 159 insertions(+), 140 deletions(-) diff --git a/src/conn.go b/src/conn.go index b2caffb..aa0b72b 100644 --- a/src/conn.go +++ b/src/conn.go @@ -34,6 +34,21 @@ func parseEndpoint(s string) (*net.UDPAddr, error) { return addr, err } +/* Must hold device and net lock + */ +func unsafeCloseUDPListener(device *Device) error { + netc := &device.net + if netc.bind != nil { + if err := netc.bind.Close(); err != nil { + return err + } + netc.bind = nil + netc.update.Broadcast() + } + return nil +} + +// must inform all listeners func UpdateUDPListener(device *Device) error { device.mutex.Lock() defer device.mutex.Unlock() @@ -44,26 +59,22 @@ func UpdateUDPListener(device *Device) error { // close existing sockets - if netc.bind != nil { - println("close bind") - if err := netc.bind.Close(); err != nil { - return err - } - netc.bind = nil - println("closed") + if err := unsafeCloseUDPListener(device); err != nil { + return err } + // wait for reader + // open new sockets if device.tun.isUp.Get() { - println("creat") - // bind to new port var err error netc.bind, netc.port, err = CreateUDPBind(netc.port) if err != nil { + netc.bind = nil return err } @@ -74,8 +85,6 @@ func UpdateUDPListener(device *Device) error { return err } - println("okay") - // clear cached source addresses for _, peer := range device.peers { @@ -83,14 +92,20 @@ func UpdateUDPListener(device *Device) error { peer.endpoint.value.ClearSrc() peer.mutex.Unlock() } + + // inform readers of updated bind + + netc.update.Broadcast() } return nil } func CloseUDPListener(device *Device) error { - netc := &device.net - netc.mutex.Lock() - defer netc.mutex.Unlock() - return netc.bind.Close() + device.mutex.Lock() + device.net.mutex.Lock() + err := unsafeCloseUDPListener(device) + device.net.mutex.Unlock() + device.mutex.Unlock() + return err } diff --git a/src/conn_linux.go b/src/conn_linux.go index 8cda460..05f9347 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -7,8 +7,8 @@ package main import ( + "encoding/binary" "errors" - "fmt" "golang.org/x/sys/unix" "net" "strconv" @@ -37,6 +37,17 @@ type NativeBind struct { sock6 int } +func htons(val uint16) uint16 { + var out [unsafe.Sizeof(val)]byte + binary.BigEndian.PutUint16(out[:], val) + return *((*uint16)(unsafe.Pointer(&out[0]))) +} + +func ntohs(val uint16) uint16 { + tmp := ((*[unsafe.Sizeof(val)]byte)(unsafe.Pointer(&val))) + return binary.BigEndian.Uint16((*tmp)[:]) +} + func CreateUDPBind(port uint16) (UDPBind, uint16, error) { var err error var bind NativeBind @@ -50,8 +61,6 @@ func CreateUDPBind(port uint16) (UDPBind, uint16, error) { if err != nil { unix.Close(bind.sock6) } - println(bind.sock6) - println(bind.sock4) return bind, port, err } @@ -297,13 +306,11 @@ func (end *Endpoint) SetDst(s string) error { return err } - fmt.Println(addr, err) - ipv4 := addr.IP.To4() if ipv4 != nil { dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) dst.Family = unix.AF_INET - dst.Port = uint16(addr.Port) + dst.Port = htons(uint16(addr.Port)) dst.Zero = [8]byte{} copy(dst.Addr[:], ipv4) end.ClearSrc() @@ -318,7 +325,7 @@ func (end *Endpoint) SetDst(s string) error { } dst := &end.dst dst.Family = unix.AF_INET6 - dst.Port = uint16(addr.Port) + dst.Port = htons(uint16(addr.Port)) dst.Flowinfo = 0 dst.Scope_id = zone copy(dst.Addr[:], ipv6[:]) @@ -392,9 +399,6 @@ func send6(sock int, end *Endpoint, buff []byte) error { } func send4(sock int, end *Endpoint, buff []byte) error { - println("send 4") - println(end.DstToString()) - println(sock) // construct message header @@ -425,6 +429,7 @@ func send4(sock int, end *Endpoint, buff []byte) error { Name: (*byte)(unsafe.Pointer(&end.dst)), Namelen: unix.SizeofSockaddrInet4, Control: (*byte)(unsafe.Pointer(&cmsg)), + Flags: 0, } msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) @@ -437,10 +442,6 @@ func send4(sock int, end *Endpoint, buff []byte) error { 0, ) - if errno == 0 { - return nil - } - // clear source and try again if errno == unix.EINVAL { @@ -454,6 +455,12 @@ func send4(sock int, end *Endpoint, buff []byte) error { ) } + // errno = 0 is still an error instance + + if errno == 0 { + return nil + } + return errno } diff --git a/src/device.go b/src/device.go index 1aae448..a348c68 100644 --- a/src/device.go +++ b/src/device.go @@ -23,9 +23,10 @@ type Device struct { } net struct { mutex sync.RWMutex - bind UDPBind - port uint16 - fwmark uint32 + bind UDPBind // bind interface + port uint16 // listening port + fwmark uint32 // mark value (0 = disabled) + update *sync.Cond // the bind was updated } mutex sync.RWMutex privateKey NoisePrivateKey @@ -38,8 +39,7 @@ type Device struct { handshake chan QueueHandshakeElement } signal struct { - stop chan struct{} - updateBind chan struct{} + stop chan struct{} } underLoadUntil atomic.Value ratelimiter Ratelimiter @@ -163,6 +163,12 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { device.signal.stop = make(chan struct{}) + // prepare net + + device.net.port = 0 + device.net.bind = nil + device.net.update = sync.NewCond(&device.net.mutex) + // start workers for i := 0; i < runtime.NumCPU(); i += 1 { diff --git a/src/receive.go b/src/receive.go index 1f05b2f..cb53f80 100644 --- a/src/receive.go +++ b/src/receive.go @@ -99,135 +99,126 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) { for { - // wait for new conn + // wait for bind - logDebug.Println("Waiting for udp socket") + logDebug.Println("Waiting for udp bind") + device.net.mutex.Lock() + device.net.update.Wait() + bind := device.net.bind + device.net.mutex.Unlock() + if bind == nil { + continue + } - select { - case <-device.signal.stop: - return + logDebug.Println("LISTEN\n\n\n") - case <-device.signal.updateBind: + // receive datagrams until conn is closed - // fetch new socket + buffer := device.GetMessageBuffer() - device.net.mutex.RLock() - bind := device.net.bind - device.net.mutex.RUnlock() - if bind == nil { + var size int + var err error + + for { + + // read next datagram + + var endpoint Endpoint + + switch IPVersion { + case ipv4.Version: + size, err = bind.ReceiveIPv4(buffer[:], &endpoint) + case ipv6.Version: + size, err = bind.ReceiveIPv6(buffer[:], &endpoint) + default: + return + } + + if err != nil { + break + } + + if size < MinMessageSize { continue } - logDebug.Println("Listening for inbound packets") + // check size of packet - // receive datagrams until conn is closed + packet := buffer[:size] + msgType := binary.LittleEndian.Uint32(packet[:4]) - buffer := device.GetMessageBuffer() + var okay bool - var size int - var err error + switch msgType { - for { + // check if transport - // read next datagram + case MessageTransportType: - var endpoint Endpoint + // check size - switch IPVersion { - case ipv4.Version: - size, err = bind.ReceiveIPv4(buffer[:], &endpoint) - case ipv6.Version: - size, err = bind.ReceiveIPv6(buffer[:], &endpoint) - default: - return - } - - if err != nil { - break - } - - if size < MinMessageSize { + if len(packet) < MessageTransportType { continue } - // check size of packet + // lookup key pair - packet := buffer[:size] - msgType := binary.LittleEndian.Uint32(packet[:4]) - - var okay bool - - switch msgType { - - // check if transport - - case MessageTransportType: - - // check size - - if len(packet) < MessageTransportType { - continue - } - - // lookup key pair - - receiver := binary.LittleEndian.Uint32( - packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], - ) - value := device.indices.Lookup(receiver) - keyPair := value.keyPair - if keyPair == nil { - continue - } - - // check key-pair expiry - - if keyPair.created.Add(RejectAfterTime).Before(time.Now()) { - continue - } - - // create work element - - peer := value.peer - elem := &QueueInboundElement{ - packet: packet, - buffer: buffer, - keyPair: keyPair, - dropped: AtomicFalse, - } - elem.mutex.Lock() - - // add to decryption queues - - device.addToDecryptionQueue(device.queue.decryption, elem) - device.addToInboundQueue(peer.queue.inbound, elem) - buffer = device.GetMessageBuffer() + receiver := binary.LittleEndian.Uint32( + packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], + ) + value := device.indices.Lookup(receiver) + keyPair := value.keyPair + if keyPair == nil { continue - - // otherwise it is a fixed size & handshake related packet - - case MessageInitiationType: - okay = len(packet) == MessageInitiationSize - - case MessageResponseType: - okay = len(packet) == MessageResponseSize - - case MessageCookieReplyType: - okay = len(packet) == MessageCookieReplySize } - if okay { - device.addToHandshakeQueue( - device.queue.handshake, - QueueHandshakeElement{ - msgType: msgType, - buffer: buffer, - packet: packet, - endpoint: endpoint, - }, - ) - buffer = device.GetMessageBuffer() + // check key-pair expiry + + if keyPair.created.Add(RejectAfterTime).Before(time.Now()) { + continue } + + // create work element + + peer := value.peer + elem := &QueueInboundElement{ + packet: packet, + buffer: buffer, + keyPair: keyPair, + dropped: AtomicFalse, + } + elem.mutex.Lock() + + // add to decryption queues + + device.addToDecryptionQueue(device.queue.decryption, elem) + device.addToInboundQueue(peer.queue.inbound, elem) + buffer = device.GetMessageBuffer() + continue + + // otherwise it is a fixed size & handshake related packet + + case MessageInitiationType: + okay = len(packet) == MessageInitiationSize + + case MessageResponseType: + okay = len(packet) == MessageResponseSize + + case MessageCookieReplyType: + okay = len(packet) == MessageCookieReplySize + } + + if okay { + device.addToHandshakeQueue( + device.queue.handshake, + QueueHandshakeElement{ + msgType: msgType, + buffer: buffer, + packet: packet, + endpoint: endpoint, + }, + ) + buffer = device.GetMessageBuffer() } } } From 566269275ed97812ec909b10ec77c7c037d9e2ea Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sat, 11 Nov 2017 23:26:44 +0100 Subject: [PATCH 08/15] Fixed blocking reader on closed socket --- src/conn.go | 16 ++++++++-------- src/conn_linux.go | 14 ++++++++++---- src/device.go | 12 +++++++----- src/receive.go | 13 +++++++------ 4 files changed, 32 insertions(+), 23 deletions(-) diff --git a/src/conn.go b/src/conn.go index aa0b72b..0347262 100644 --- a/src/conn.go +++ b/src/conn.go @@ -37,15 +37,14 @@ func parseEndpoint(s string) (*net.UDPAddr, error) { /* Must hold device and net lock */ func unsafeCloseUDPListener(device *Device) error { + var err error netc := &device.net if netc.bind != nil { - if err := netc.bind.Close(); err != nil { - return err - } + err = netc.bind.Close() netc.bind = nil - netc.update.Broadcast() + netc.update.Add(1) } - return nil + return err } // must inform all listeners @@ -63,7 +62,7 @@ func UpdateUDPListener(device *Device) error { return err } - // wait for reader + // assumption: netc.update WaitGroup should be exactly 1 // open new sockets @@ -93,9 +92,10 @@ func UpdateUDPListener(device *Device) error { peer.mutex.Unlock() } - // inform readers of updated bind + // decrease waitgroup to 0 - netc.update.Broadcast() + device.log.Debug.Println("UDP bind has been updated") + netc.update.Done() } return nil diff --git a/src/conn_linux.go b/src/conn_linux.go index 05f9347..383ff7e 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -84,9 +84,15 @@ func (bind NativeBind) SetMark(value uint32) error { ) } +func closeUnblock(fd int) error { + // shutdown to unblock readers + unix.Shutdown(fd, unix.SHUT_RD) + return unix.Close(fd) +} + func (bind NativeBind) Close() error { - err1 := unix.Close(bind.sock6) - err2 := unix.Close(bind.sock4) + err1 := closeUnblock(bind.sock6) + err2 := closeUnblock(bind.sock4) if err1 != nil { return err1 } @@ -125,13 +131,13 @@ func sockaddrToString(addr unix.RawSockaddrInet6) string { switch addr.Family { case unix.AF_INET6: - udpAddr.Port = int(addr.Port) + udpAddr.Port = int(ntohs(addr.Port)) udpAddr.IP = addr.Addr[:] return udpAddr.String() case unix.AF_INET: ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr)) - udpAddr.Port = int(ptr.Port) + udpAddr.Port = int(ntohs(ptr.Port)) udpAddr.IP = net.IPv4( ptr.Addr[0], ptr.Addr[1], diff --git a/src/device.go b/src/device.go index a348c68..033a387 100644 --- a/src/device.go +++ b/src/device.go @@ -23,10 +23,10 @@ type Device struct { } net struct { mutex sync.RWMutex - bind UDPBind // bind interface - port uint16 // listening port - fwmark uint32 // mark value (0 = disabled) - update *sync.Cond // the bind was updated + bind UDPBind // bind interface + port uint16 // listening port + fwmark uint32 // mark value (0 = disabled) + update sync.WaitGroup // the bind was updated (acting as a barrier) } mutex sync.RWMutex privateKey NoisePrivateKey @@ -167,7 +167,7 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { device.net.port = 0 device.net.bind = nil - device.net.update = sync.NewCond(&device.net.mutex) + device.net.update.Add(1) // start workers @@ -209,9 +209,11 @@ func (device *Device) RemoveAllPeers() { } func (device *Device) Close() { + device.log.Info.Println("Closing device") device.RemoveAllPeers() close(device.signal.stop) CloseUDPListener(device) + device.tun.device.Close() } func (device *Device) WaitChannel() chan struct{} { diff --git a/src/receive.go b/src/receive.go index cb53f80..3e88be3 100644 --- a/src/receive.go +++ b/src/receive.go @@ -95,23 +95,22 @@ func (device *Device) addToHandshakeQueue( func (device *Device) RoutineReceiveIncomming(IPVersion int) { logDebug := device.log.Debug - logDebug.Println("Routine, receive incomming, started") + logDebug.Println("Routine, receive incomming, IP version:", IPVersion) for { // wait for bind - logDebug.Println("Waiting for udp bind") - device.net.mutex.Lock() + logDebug.Println("Waiting for UDP socket, IP version:", IPVersion) + device.net.update.Wait() + device.net.mutex.RLock() bind := device.net.bind - device.net.mutex.Unlock() + device.net.mutex.RUnlock() if bind == nil { continue } - logDebug.Println("LISTEN\n\n\n") - // receive datagrams until conn is closed buffer := device.GetMessageBuffer() @@ -427,6 +426,8 @@ func (device *Device) RoutineHandshake() { err = peer.SendBuffer(packet) if err == nil { peer.TimerAnyAuthenticatedPacketTraversal() + } else { + logError.Println("Failed to send response to:", peer.String(), err) } case MessageResponseType: From 69fe86edf0ba371b9b0a54e522ec20d33e0ae129 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Tue, 14 Nov 2017 16:27:53 +0100 Subject: [PATCH 09/15] Initial working source caching --- src/conn.go | 9 +++++-- src/device.go | 12 +++------ src/main.go | 1 + src/receive.go | 61 +++++++++++++++++++++++++--------------------- src/tests/netns.sh | 33 ++++++++++++++----------- src/uapi.go | 1 - 6 files changed, 63 insertions(+), 54 deletions(-) diff --git a/src/conn.go b/src/conn.go index 0347262..a047bb6 100644 --- a/src/conn.go +++ b/src/conn.go @@ -2,6 +2,8 @@ package main import ( "errors" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" "net" ) @@ -42,7 +44,6 @@ func unsafeCloseUDPListener(device *Device) error { if netc.bind != nil { err = netc.bind.Close() netc.bind = nil - netc.update.Add(1) } return err } @@ -68,6 +69,8 @@ func UpdateUDPListener(device *Device) error { if device.tun.isUp.Get() { + device.log.Debug.Println("UDP bind updating") + // bind to new port var err error @@ -94,8 +97,10 @@ func UpdateUDPListener(device *Device) error { // decrease waitgroup to 0 + go device.RoutineReceiveIncomming(ipv4.Version, netc.bind) + go device.RoutineReceiveIncomming(ipv6.Version, netc.bind) + device.log.Debug.Println("UDP bind has been updated") - netc.update.Done() } return nil diff --git a/src/device.go b/src/device.go index 033a387..9422d49 100644 --- a/src/device.go +++ b/src/device.go @@ -1,8 +1,6 @@ package main import ( - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" "runtime" "sync" "sync/atomic" @@ -23,10 +21,9 @@ type Device struct { } net struct { mutex sync.RWMutex - bind UDPBind // bind interface - port uint16 // listening port - fwmark uint32 // mark value (0 = disabled) - update sync.WaitGroup // the bind was updated (acting as a barrier) + bind UDPBind // bind interface + port uint16 // listening port + fwmark uint32 // mark value (0 = disabled) } mutex sync.RWMutex privateKey NoisePrivateKey @@ -167,7 +164,6 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { device.net.port = 0 device.net.bind = nil - device.net.update.Add(1) // start workers @@ -179,8 +175,6 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { go device.RoutineReadFromTUN() go device.RoutineTUNEventReader() go device.ratelimiter.RoutineGarbageCollector(device.signal.stop) - go device.RoutineReceiveIncomming(ipv4.Version) - go device.RoutineReceiveIncomming(ipv6.Version) return device } diff --git a/src/main.go b/src/main.go index 05d56eb..eb3c67f 100644 --- a/src/main.go +++ b/src/main.go @@ -14,6 +14,7 @@ func printUsage() { } func main() { + // parse arguments var foreground bool diff --git a/src/receive.go b/src/receive.go index 3e88be3..ff3b7bd 100644 --- a/src/receive.go +++ b/src/receive.go @@ -20,12 +20,13 @@ type QueueHandshakeElement struct { } type QueueInboundElement struct { - dropped int32 - mutex sync.Mutex - buffer *[MaxMessageSize]byte - packet []byte - counter uint64 - keyPair *KeyPair + dropped int32 + mutex sync.Mutex + buffer *[MaxMessageSize]byte + packet []byte + counter uint64 + keyPair *KeyPair + endpoint Endpoint } func (elem *QueueInboundElement) Drop() { @@ -92,25 +93,13 @@ func (device *Device) addToHandshakeQueue( } } -func (device *Device) RoutineReceiveIncomming(IPVersion int) { +func (device *Device) RoutineReceiveIncomming(IP int, bind UDPBind) { logDebug := device.log.Debug - logDebug.Println("Routine, receive incomming, IP version:", IPVersion) + logDebug.Println("Routine, receive incomming, IP version:", IP) for { - // wait for bind - - logDebug.Println("Waiting for UDP socket, IP version:", IPVersion) - - device.net.update.Wait() - device.net.mutex.RLock() - bind := device.net.bind - device.net.mutex.RUnlock() - if bind == nil { - continue - } - // receive datagrams until conn is closed buffer := device.GetMessageBuffer() @@ -124,7 +113,7 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) { var endpoint Endpoint - switch IPVersion { + switch IP { case ipv4.Version: size, err = bind.ReceiveIPv4(buffer[:], &endpoint) case ipv6.Version: @@ -181,10 +170,11 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) { peer := value.peer elem := &QueueInboundElement{ - packet: packet, - buffer: buffer, - keyPair: keyPair, - dropped: AtomicFalse, + packet: packet, + buffer: buffer, + keyPair: keyPair, + dropped: AtomicFalse, + endpoint: endpoint, } elem.mutex.Lock() @@ -396,7 +386,6 @@ func (device *Device) RoutineHandshake() { peer.TimerAnyAuthenticatedPacketReceived() // update endpoint - // TODO: Discover destination address also, only update on change peer.mutex.Lock() peer.endpoint.set = true @@ -453,6 +442,13 @@ func (device *Device) RoutineHandshake() { continue } + // update endpoint + + peer.mutex.Lock() + peer.endpoint.set = true + peer.endpoint.value = elem.endpoint + peer.mutex.Unlock() + logDebug.Println("Received handshake initation from", peer) peer.TimerEphemeralKeyCreated() @@ -521,6 +517,13 @@ func (peer *Peer) RoutineSequentialReceiver() { } kp.mutex.Unlock() + // update endpoint + + peer.mutex.Lock() + peer.endpoint.set = true + peer.endpoint.value = elem.endpoint + peer.mutex.Unlock() + // check for keep-alive if len(elem.packet) == 0 { @@ -552,7 +555,8 @@ func (peer *Peer) RoutineSequentialReceiver() { src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] if device.routingTable.LookupIPv4(src) != peer { - logInfo.Println("Packet with unallowed source IP from", peer.String()) + logInfo.Println(src) + logInfo.Println("Packet with unallowed source IPv4 from", peer.String()) continue } @@ -577,7 +581,8 @@ func (peer *Peer) RoutineSequentialReceiver() { src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] if device.routingTable.LookupIPv6(src) != peer { - logInfo.Println("Packet with unallowed source IP from", peer.String()) + logInfo.Println(src) + logInfo.Println("Packet with unallowed source IPv6 from", peer.String()) continue } diff --git a/src/tests/netns.sh b/src/tests/netns.sh index 043da3e..9124b80 100755 --- a/src/tests/netns.sh +++ b/src/tests/netns.sh @@ -28,7 +28,7 @@ netns0="wg-test-$$-0" netns1="wg-test-$$-1" netns2="wg-test-$$-2" program="../wireguard-go" -export LOG_LEVEL="error" +export LOG_LEVEL="debug" pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; } pp() { pretty "" "$*"; "$@"; } @@ -147,6 +147,8 @@ tests() { n1 iperf3 -Z -n 1G -b 0 -u -c fd00::2 } +echo "4" + [[ $(ip1 link show dev wg1) =~ mtu\ ([0-9]+) ]] && orig_mtu="${BASH_REMATCH[1]}" big_mtu=$(( 34816 - 1500 + $orig_mtu )) @@ -185,14 +187,14 @@ ip0 -4 addr del 127.0.0.1/8 dev lo ip0 -4 addr add 127.212.121.99/8 dev lo n0 wg set wg1 listen-port 9999 n0 wg set wg1 peer "$pub2" endpoint 127.0.0.1:20000 -n1 ping6 -W 1 -c 1 fd00::20000 -[[ $(n2 wg show wg2 endpoints) == "$pub1 127.212.121.99:9999" ]] +n1 ping6 -W 1 -c 1 fd00::2 +[[ $(n2 wg show wg2 endpoints) == "$pub1 127.212.121.99:9999" ]] # Test using IPv6 that roaming works n1 wg set wg1 listen-port 9998 n1 wg set wg1 peer "$pub2" endpoint [::1]:20000 n1 ping -W 1 -c 1 192.168.241.2 -[[ $(n2 wg show wg2 endpoints) == "$pub1 [::1]:9998" ]] +[[ $(n2 wg show wg2 endpoints) == "$pub1 [::1]:9998" ]] # Test that crypto-RP filter works n1 wg set wg1 peer "$pub2" allowed-ips 192.168.241.0/24 @@ -212,7 +214,7 @@ n2 ncat -u 192.168.241.1 1111 <<<"X" ! read -r -N 1 -t 1 out <&4 kill $nmap_pid n0 wg set wg1 peer "$more_specific_key" remove -[[ $(n1 wg show wg1 endpoints) == "$pub2 [::1]:9997" ]] +[[ $(n1 wg show wg1 endpoints) == "$pub2 [::1]:9997" ]] ip1 link del wg1 ip2 link del wg2 @@ -232,8 +234,9 @@ ip2 link del wg2 # ip1 link add dev wg1 type wireguard # ip2 link add dev wg1 type wireguard -n1 $program wg1 -n2 $program wg2 +n1 $program -f wg1 & +n2 $program -f wg2 & +sleep 5 configure_peers @@ -263,7 +266,7 @@ n0 iptables -t nat -A POSTROUTING -s 192.168.1.0/24 -d 10.0.0.0/24 -j SNAT --to n0 wg set wg1 peer "$pub2" endpoint 10.0.0.100:20000 persistent-keepalive 1 n1 ping -W 1 -c 1 192.168.241.2 n2 ping -W 1 -c 1 192.168.241.1 -[[ $(n2 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]] +[[ $(n2 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]] # Demonstrate n2 can still send packets to n1, since persistent-keepalive will prevent connection tracking entry from expiring (to see entries: `n0 conntrack -L`). pp sleep 3 n2 ping -W 1 -c 1 192.168.241.1 @@ -288,8 +291,9 @@ ip2 link del wg2 # ip1 link add dev wg1 type wireguard # ip2 link add dev wg1 type wireguard -n1 $program wg1 -n2 $program wg1 +n1 $program -f wg1 & +n2 $program -f wg2 & +sleep 5 configure_peers @@ -336,17 +340,18 @@ waitiface $netns1 veth1 waitiface $netns2 veth2 n0 wg set wg2 peer "$pub1" endpoint 10.0.0.1:10000 n2 ping -W 1 -c 1 192.168.241.1 -[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]] +[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]] n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::1]:10000 n2 ping -W 1 -c 1 192.168.241.1 -[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::1]:10000" ]] +[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::1]:10000" ]] n0 wg set wg2 peer "$pub1" endpoint 10.0.0.2:10000 n2 ping -W 1 -c 1 192.168.241.1 -[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.2:10000" ]] +[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.2:10000" ]] n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::2]:10000 n2 ping -W 1 -c 1 192.168.241.1 -[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::2]:10000" ]] +[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::2]:10000" ]] ip1 link del veth1 ip1 link del wg1 ip2 link del wg2 +echo "done" diff --git a/src/uapi.go b/src/uapi.go index 5098e3d..5e40939 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -248,7 +248,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { peer.mutex.Lock() err := peer.endpoint.value.SetDst(value) - fmt.Println(peer.endpoint.value.DstToString(), err) peer.endpoint.set = (err == nil) peer.mutex.Unlock() if err != nil { From 88801529fd4097993f7c448b1c3eee0abc8cb51c Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Tue, 14 Nov 2017 18:26:28 +0100 Subject: [PATCH 10/15] Moved TUN device creation to pre-fork --- src/daemon_linux.go | 11 +---- src/device.go | 4 +- src/main.go | 104 +++++++++++++++++++++++++++++--------------- src/tests/netns.sh | 21 ++++----- src/tun.go | 2 + src/tun_linux.go | 28 ++++++++++++ 6 files changed, 111 insertions(+), 59 deletions(-) diff --git a/src/daemon_linux.go b/src/daemon_linux.go index 730f89e..8210f8b 100644 --- a/src/daemon_linux.go +++ b/src/daemon_linux.go @@ -11,18 +11,9 @@ import ( * TODO: Use env variable to spawn in background */ -func Daemonize() error { +func Daemonize(attr *os.ProcAttr) error { argv := []string{os.Args[0], "--foreground"} argv = append(argv, os.Args[1:]...) - attr := &os.ProcAttr{ - Dir: ".", - Env: os.Environ(), - Files: []*os.File{ - os.Stdin, - nil, - nil, - }, - } process, err := os.StartProcess( argv[0], argv, diff --git a/src/device.go b/src/device.go index 9422d49..429ee46 100644 --- a/src/device.go +++ b/src/device.go @@ -126,13 +126,13 @@ func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { device.pool.messageBuffers.Put(msg) } -func NewDevice(tun TUNDevice, logLevel int) *Device { +func NewDevice(tun TUNDevice, logger *Logger) *Device { device := new(Device) device.mutex.Lock() defer device.mutex.Unlock() - device.log = NewLogger(logLevel, "("+tun.Name()+") ") + device.log = logger device.peers = make(map[NoisePublicKey]*Peer) device.tun.device = tun diff --git a/src/main.go b/src/main.go index eb3c67f..3808c9c 100644 --- a/src/main.go +++ b/src/main.go @@ -2,10 +2,14 @@ package main import ( "fmt" - "log" "os" "os/signal" "runtime" + "strconv" +) + +const ( + EnvWGTunFD = "WG_TUN_FD" ) func printUsage() { @@ -43,28 +47,6 @@ func main() { interfaceName = os.Args[1] } - // daemonize the process - - if !foreground { - err := Daemonize() - if err != nil { - log.Println("Failed to daemonize:", err) - } - return - } - - // increase number of go workers (for Go <1.5) - - runtime.GOMAXPROCS(runtime.NumCPU()) - - // open TUN device - - tun, err := CreateTUN(interfaceName) - if err != nil { - log.Println("Failed to create tun device:", err) - return - } - // get log level (default: info) logLevel := func() int { @@ -79,22 +61,76 @@ func main() { return LogLevelInfo }() + logger := NewLogger( + logLevel, + fmt.Sprintf("(%s) ", interfaceName), + ) + logger.Debug.Println("Debug log enabled") + + // open TUN device + + tun, err := func() (TUNDevice, error) { + tunFdStr := os.Getenv(EnvWGTunFD) + if tunFdStr == "" { + return CreateTUN(interfaceName) + } + + // construct tun device from supplied FD + + fd, err := strconv.ParseUint(tunFdStr, 10, 32) + if err != nil { + return nil, err + } + + file := os.NewFile(uintptr(fd), "/dev/net/tun") + return CreateTUNFromFile(interfaceName, file) + }() + + if err != nil { + logger.Error.Println("Failed to create TUN device:", err) + } + + // daemonize the process + + if !foreground { + env := os.Environ() + _, ok := os.LookupEnv(EnvWGTunFD) + if !ok { + kvp := fmt.Sprintf("%s=3", EnvWGTunFD) + env = append(env, kvp) + } + attr := &os.ProcAttr{ + Files: []*os.File{ + nil, // stdin + nil, // stdout + nil, // stderr + tun.File(), + }, + Dir: ".", + Env: env, + } + err = Daemonize(attr) + if err != nil { + logger.Error.Println("Failed to daemonize:", err) + } + return + } + + // increase number of go workers (for Go <1.5) + + runtime.GOMAXPROCS(runtime.NumCPU()) + // create wireguard device - device := NewDevice(tun, logLevel) - - logInfo := device.log.Info - logError := device.log.Error - logDebug := device.log.Debug - - logInfo.Println("Device started") - logDebug.Println("Debug log enabled") + device := NewDevice(tun, logger) + logger.Info.Println("Device started") // start configuration lister uapi, err := NewUAPIListener(interfaceName) if err != nil { - logError.Fatal("UAPI listen error:", err) + logger.Error.Println("UAPI listen error:", err) + return } errs := make(chan error) @@ -112,7 +148,7 @@ func main() { } }() - logInfo.Println("UAPI listener started") + logger.Info.Println("UAPI listener started") // wait for program to terminate @@ -129,5 +165,5 @@ func main() { uapi.Close() - logInfo.Println("Closing") + logger.Info.Println("Shutting down") } diff --git a/src/tests/netns.sh b/src/tests/netns.sh index 9124b80..b5c2f9c 100755 --- a/src/tests/netns.sh +++ b/src/tests/netns.sh @@ -28,7 +28,7 @@ netns0="wg-test-$$-0" netns1="wg-test-$$-1" netns2="wg-test-$$-2" program="../wireguard-go" -export LOG_LEVEL="debug" +export LOG_LEVEL="info" pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; } pp() { pretty "" "$*"; "$@"; } @@ -72,13 +72,11 @@ pp ip netns add $netns2 ip0 link set up dev lo # ip0 link add dev wg1 type wireguard -n0 $program -f wg1 & -sleep 1 +n0 $program wg1 ip0 link set wg1 netns $netns1 # ip0 link add dev wg1 type wireguard -n0 $program -f wg2 & -sleep 1 +n0 $program wg2 ip0 link set wg2 netns $netns2 key1="$(pp wg genkey)" @@ -147,8 +145,6 @@ tests() { n1 iperf3 -Z -n 1G -b 0 -u -c fd00::2 } -echo "4" - [[ $(ip1 link show dev wg1) =~ mtu\ ([0-9]+) ]] && orig_mtu="${BASH_REMATCH[1]}" big_mtu=$(( 34816 - 1500 + $orig_mtu )) @@ -234,9 +230,8 @@ ip2 link del wg2 # ip1 link add dev wg1 type wireguard # ip2 link add dev wg1 type wireguard -n1 $program -f wg1 & -n2 $program -f wg2 & -sleep 5 +n1 $program wg1 +n2 $program wg2 configure_peers @@ -291,9 +286,8 @@ ip2 link del wg2 # ip1 link add dev wg1 type wireguard # ip2 link add dev wg1 type wireguard -n1 $program -f wg1 & -n2 $program -f wg2 & -sleep 5 +n1 $program wg1 +n2 $program wg2 configure_peers @@ -354,4 +348,5 @@ n2 ping -W 1 -c 1 192.168.241.1 ip1 link del veth1 ip1 link del wg1 ip2 link del wg2 + echo "done" diff --git a/src/tun.go b/src/tun.go index 9eed987..5bdac0e 100644 --- a/src/tun.go +++ b/src/tun.go @@ -1,6 +1,7 @@ package main import ( + "os" "sync/atomic" ) @@ -15,6 +16,7 @@ const ( ) type TUNDevice interface { + File() *os.File // returns the file descriptor of the device Read([]byte) (int, error) // read a packet from the device (without any additional headers) Write([]byte) (int, error) // writes a packet to the device (without any additional headers) MTU() (int, error) // returns the MTU of the device diff --git a/src/tun_linux.go b/src/tun_linux.go index accc6c6..ce6304c 100644 --- a/src/tun_linux.go +++ b/src/tun_linux.go @@ -56,6 +56,11 @@ type NativeTun struct { events chan TUNEvent // device related events } +func (tun *NativeTun) File() *os.File { + println(tun.fd.Name()) + return tun.fd +} + func (tun *NativeTun) RoutineNetlinkListener() { sock := int(C.bind_rtmgrp()) if sock < 0 { @@ -248,6 +253,29 @@ func (tun *NativeTun) Close() error { return nil } +func CreateTUNFromFile(name string, fd *os.File) (TUNDevice, error) { + device := &NativeTun{ + fd: fd, + name: name, + events: make(chan TUNEvent, 5), + errors: make(chan error, 5), + } + + // start event listener + + var err error + device.index, err = getIFIndex(device.name) + if err != nil { + return nil, err + } + + go device.RoutineNetlinkListener() + + // set default MTU + + return device, device.setMTU(DefaultMTU) +} + func CreateTUN(name string) (TUNDevice, error) { // open clone device From e1227d3af480eae72639cde842b4d538c58936dc Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Fri, 17 Nov 2017 14:36:08 +0100 Subject: [PATCH 11/15] Allows passing UAPI fd to service --- src/main.go | 59 ++++++++++++++++------- src/tun_linux.go | 2 +- src/uapi_linux.go | 117 +++++++++++++++++++++++++++------------------- 3 files changed, 111 insertions(+), 67 deletions(-) diff --git a/src/main.go b/src/main.go index 3808c9c..7d86716 100644 --- a/src/main.go +++ b/src/main.go @@ -9,7 +9,8 @@ import ( ) const ( - EnvWGTunFD = "WG_TUN_FD" + ENV_WG_TUN_FD = "WG_TUN_FD" + ENV_WG_UAPI_FD = "WG_UAPI_FD" ) func printUsage() { @@ -65,46 +66,69 @@ func main() { logLevel, fmt.Sprintf("(%s) ", interfaceName), ) + logger.Debug.Println("Debug log enabled") - // open TUN device + // open TUN device (or use supplied fd) tun, err := func() (TUNDevice, error) { - tunFdStr := os.Getenv(EnvWGTunFD) + tunFdStr := os.Getenv(ENV_WG_TUN_FD) if tunFdStr == "" { return CreateTUN(interfaceName) } - // construct tun device from supplied FD + // construct tun device from supplied fd fd, err := strconv.ParseUint(tunFdStr, 10, 32) if err != nil { return nil, err } - file := os.NewFile(uintptr(fd), "/dev/net/tun") + file := os.NewFile(uintptr(fd), "") return CreateTUNFromFile(interfaceName, file) }() if err != nil { logger.Error.Println("Failed to create TUN device:", err) + os.Exit(ExitSetupFailed) } + // open UAPI file (or use supplied fd) + + fileUAPI, err := func() (*os.File, error) { + uapiFdStr := os.Getenv(ENV_WG_UAPI_FD) + if uapiFdStr == "" { + return UAPIOpen(interfaceName) + } + + // use supplied fd + + fd, err := strconv.ParseUint(uapiFdStr, 10, 32) + if err != nil { + return nil, err + } + + return os.NewFile(uintptr(fd), ""), nil + }() + + if err != nil { + logger.Error.Println("UAPI listen error:", err) + os.Exit(ExitSetupFailed) + return + } // daemonize the process if !foreground { env := os.Environ() - _, ok := os.LookupEnv(EnvWGTunFD) - if !ok { - kvp := fmt.Sprintf("%s=3", EnvWGTunFD) - env = append(env, kvp) - } + env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD)) + env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD)) attr := &os.ProcAttr{ Files: []*os.File{ nil, // stdin nil, // stdout nil, // stderr tun.File(), + fileUAPI, }, Dir: ".", Env: env, @@ -112,6 +136,7 @@ func main() { err = Daemonize(attr) if err != nil { logger.Error.Println("Failed to daemonize:", err) + os.Exit(ExitSetupFailed) } return } @@ -123,20 +148,17 @@ func main() { // create wireguard device device := NewDevice(tun, logger) + logger.Info.Println("Device started") - // start configuration lister - - uapi, err := NewUAPIListener(interfaceName) - if err != nil { - logger.Error.Println("UAPI listen error:", err) - return - } + // start uapi listener errs := make(chan error) term := make(chan os.Signal) wait := device.WaitChannel() + uapi, err := UAPIListen(interfaceName, fileUAPI) + go func() { for { conn, err := uapi.Accept() @@ -161,9 +183,10 @@ func main() { case <-errs: } - // clean up UAPI bind + // clean up uapi.Close() + device.Close() logger.Info.Println("Shutting down") } diff --git a/src/tun_linux.go b/src/tun_linux.go index ce6304c..2a5b276 100644 --- a/src/tun_linux.go +++ b/src/tun_linux.go @@ -227,7 +227,7 @@ func (tun *NativeTun) MTU() (int, error) { val := binary.LittleEndian.Uint32(ifr[16:20]) if val >= (1 << 31) { - return int(val-(1<<31)) - (1 << 31), nil + return int(toInt32(val)), nil } return int(val), nil } diff --git a/src/uapi_linux.go b/src/uapi_linux.go index cb9d858..f97a18a 100644 --- a/src/uapi_linux.go +++ b/src/uapi_linux.go @@ -10,12 +10,12 @@ import ( ) const ( - ipcErrorIO = -int64(unix.EIO) - ipcErrorProtocol = -int64(unix.EPROTO) - ipcErrorInvalid = -int64(unix.EINVAL) - ipcErrorPortInUse = -int64(unix.EADDRINUSE) - socketDirectory = "/var/run/wireguard" - socketName = "%s.sock" + ipcErrorIO = -int64(unix.EIO) + ipcErrorProtocol = -int64(unix.EPROTO) + ipcErrorInvalid = -int64(unix.EINVAL) + ipcErrorPortInUse = -int64(unix.EADDRINUSE) + socketDirectory = "/var/run/wireguard" + socketName = "%s.sock" ) type UAPIListener struct { @@ -50,49 +50,11 @@ func (l *UAPIListener) Addr() net.Addr { return nil } -func connectUnixSocket(path string) (net.Listener, error) { +func UAPIListen(name string, file *os.File) (net.Listener, error) { - // attempt inital connection + // wrap file in listener - listener, err := net.Listen("unix", path) - if err == nil { - return listener, nil - } - - // check if active - - _, err = net.Dial("unix", path) - if err == nil { - return nil, errors.New("Unix socket in use") - } - - // attempt cleanup - - err = os.Remove(path) - if err != nil { - return nil, err - } - - return net.Listen("unix", path) -} - -func NewUAPIListener(name string) (net.Listener, error) { - - // check if path exist - - err := os.MkdirAll(socketDirectory, 077) - if err != nil && !os.IsExist(err) { - return nil, err - } - - // open UNIX socket - - socketPath := path.Join( - socketDirectory, - fmt.Sprintf(socketName, name), - ) - - listener, err := connectUnixSocket(socketPath) + listener, err := net.FileListener(file) if err != nil { return nil, err } @@ -105,6 +67,11 @@ func NewUAPIListener(name string) (net.Listener, error) { // watch for deletion of socket + socketPath := path.Join( + socketDirectory, + fmt.Sprintf(socketName, name), + ) + uapi.inotifyFd, err = unix.InotifyInit() if err != nil { return nil, err @@ -125,11 +92,12 @@ func NewUAPIListener(name string) (net.Listener, error) { go func(l *UAPIListener) { var buff [4096]byte for { - unix.Read(uapi.inotifyFd, buff[:]) + // start with lstat to avoid race condition if _, err := os.Lstat(socketPath); os.IsNotExist(err) { l.connErr <- err return } + unix.Read(uapi.inotifyFd, buff[:]) } }(uapi) @@ -148,3 +116,56 @@ func NewUAPIListener(name string) (net.Listener, error) { return uapi, nil } + +func UAPIOpen(name string) (*os.File, error) { + + // check if path exist + + err := os.MkdirAll(socketDirectory, 0600) + if err != nil && !os.IsExist(err) { + return nil, err + } + + // open UNIX socket + + socketPath := path.Join( + socketDirectory, + fmt.Sprintf(socketName, name), + ) + + addr, err := net.ResolveUnixAddr("unix", socketPath) + if err != nil { + return nil, err + } + + listener, err := func() (*net.UnixListener, error) { + + // initial connection attempt + + listener, err := net.ListenUnix("unix", addr) + if err == nil { + return listener, nil + } + + // check if socket already active + + _, err = net.Dial("unix", socketPath) + if err == nil { + return nil, errors.New("unix socket in use") + } + + // cleanup & attempt again + + err = os.Remove(socketPath) + if err != nil { + return nil, err + } + return net.ListenUnix("unix", addr) + }() + + if err != nil { + return nil, err + } + + return listener.File() +} From fa399a91d5da9874cbf248e00db8dbd87b587e91 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Fri, 17 Nov 2017 17:25:45 +0100 Subject: [PATCH 12/15] Ported remaining netns.sh - Ported remaining netns.sh tests - Begin work on generic implementation of bind interface --- src/conn.go | 16 ++++++++++ src/conn_default.go | 35 +++++++++++++++++++++ src/conn_linux.go | 2 +- src/cookie_test.go | 7 ++--- src/daemon_linux.go | 13 +++++--- src/device.go | 8 +++-- src/helper_test.go | 8 ++++- src/misc.go | 8 +++++ src/noise_test.go | 8 ++--- src/receive.go | 27 +++++++++++----- src/tests/netns.sh | 76 +++++++++++++++++++++++++++++++++++++++++++-- src/tun_linux.go | 1 - src/uapi.go | 13 +++++++- 13 files changed, 194 insertions(+), 28 deletions(-) diff --git a/src/conn.go b/src/conn.go index a047bb6..3cf00ab 100644 --- a/src/conn.go +++ b/src/conn.go @@ -15,6 +15,22 @@ type UDPBind interface { Close() error } +/* An Endpoint maintains the source/destination caching for a peer + * + * dst : the remote address of a peer + * src : the local address from which datagrams originate going to the peer + * + */ +type UDPEndpoint interface { + ClearSrc() // clears the source address + ClearDst() // clears the destination address + SrcToString() string // returns the local source address (ip:port) + DstToString() string // returns the destination address (ip:port) + DstToBytes() []byte // used for mac2 cookie calculations + DstIP() net.IP + SrcIP() net.IP +} + func parseEndpoint(s string) (*net.UDPAddr, error) { // ensure that the host is an IP address diff --git a/src/conn_default.go b/src/conn_default.go index 279643e..31cab5c 100644 --- a/src/conn_default.go +++ b/src/conn_default.go @@ -6,6 +6,41 @@ import ( "net" ) +/* This code is meant to be a temporary solution + * on platforms for which the sticky socket / source caching behavior + * has not yet been implemented. + * + * See conn_linux.go for an implementation on the linux platform. + */ + +type Endpoint *net.UDPAddr + +type NativeBind *net.UDPConn + +func CreateUDPBind(port uint16) (UDPBind, uint16, error) { + + // listen + + addr := UDPAddr{ + Port: int(port), + } + conn, err := net.ListenUDP("udp", &addr) + if err != nil { + return nil, 0, err + } + + // retrieve port + + laddr := conn.LocalAddr() + uaddr, _ = net.ResolveUDPAddr( + laddr.Network(), + laddr.String(), + ) + return uaddr.Port +} + +func (_ Endpoint) ClearSrc() {} + func SetMark(conn *net.UDPConn, value uint32) error { return nil } diff --git a/src/conn_linux.go b/src/conn_linux.go index 383ff7e..fb576b1 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -168,7 +168,7 @@ func (end *Endpoint) DstIP() net.IP { } } -func (end *Endpoint) SrcToBytes() []byte { +func (end *Endpoint) DstToBytes() []byte { ptr := unsafe.Pointer(&end.src) arr := (*[unix.SizeofSockaddrInet6]byte)(ptr) return arr[:] diff --git a/src/cookie_test.go b/src/cookie_test.go index 193a76e..d745fe7 100644 --- a/src/cookie_test.go +++ b/src/cookie_test.go @@ -1,7 +1,6 @@ package main import ( - "net" "testing" ) @@ -25,7 +24,7 @@ func TestCookieMAC1(t *testing.T) { // check mac1 - src, _ := net.ResolveUDPAddr("udp", "192.168.13.37:4000") + src := []byte{192, 168, 13, 37, 10, 10, 10} checkMAC1 := func(msg []byte) { generator.AddMacs(msg) @@ -128,12 +127,12 @@ func TestCookieMAC1(t *testing.T) { msg[5] ^= 0x20 - srcBad1, _ := net.ResolveUDPAddr("udp", "192.168.13.37:4001") + srcBad1 := []byte{192, 168, 13, 37, 40, 01} if checker.CheckMAC2(msg, srcBad1) { t.Fatal("MAC2 generation/verification failed") } - srcBad2, _ := net.ResolveUDPAddr("udp", "192.168.13.38:4000") + srcBad2 := []byte{192, 168, 13, 38, 40, 01} if checker.CheckMAC2(msg, srcBad2) { t.Fatal("MAC2 generation/verification failed") } diff --git a/src/daemon_linux.go b/src/daemon_linux.go index 8210f8b..e1aaede 100644 --- a/src/daemon_linux.go +++ b/src/daemon_linux.go @@ -2,20 +2,25 @@ package main import ( "os" + "os/exec" ) /* Daemonizes the process on linux * * This is done by spawning and releasing a copy with the --foreground flag - * - * TODO: Use env variable to spawn in background */ - func Daemonize(attr *os.ProcAttr) error { + // I would like to use os.Executable, + // however this means dropping support for Go <1.8 + path, err := exec.LookPath(os.Args[0]) + if err != nil { + return err + } + argv := []string{os.Args[0], "--foreground"} argv = append(argv, os.Args[1:]...) process, err := os.StartProcess( - argv[0], + path, argv, attr, ) diff --git a/src/device.go b/src/device.go index 429ee46..0085cee 100644 --- a/src/device.go +++ b/src/device.go @@ -8,8 +8,9 @@ import ( ) type Device struct { - log *Logger // collection of loggers for levels - idCounter uint // for assigning debug ids to peers + closed AtomicBool // device is closed? (acting as guard) + log *Logger // collection of loggers for levels + idCounter uint // for assigning debug ids to peers fwMark uint32 tun struct { device TUNDevice @@ -203,6 +204,9 @@ func (device *Device) RemoveAllPeers() { } func (device *Device) Close() { + if device.closed.Swap(true) { + return + } device.log.Info.Println("Closing device") device.RemoveAllPeers() close(device.signal.stop) diff --git a/src/helper_test.go b/src/helper_test.go index fc171e8..8548121 100644 --- a/src/helper_test.go +++ b/src/helper_test.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "os" "testing" ) @@ -15,6 +16,10 @@ type DummyTUN struct { events chan TUNEvent } +func (tun *DummyTUN) File() *os.File { + return nil +} + func (tun *DummyTUN) Name() string { return tun.name } @@ -67,7 +72,8 @@ func randDevice(t *testing.T) *Device { t.Fatal(err) } tun, _ := CreateDummyTUN("dummy") - device := NewDevice(tun, LogLevelError) + logger := NewLogger(LogLevelError, "") + device := NewDevice(tun, logger) device.SetPrivateKey(sk) return device } diff --git a/src/misc.go b/src/misc.go index bbe0d68..b43e97e 100644 --- a/src/misc.go +++ b/src/misc.go @@ -21,6 +21,14 @@ func (a *AtomicBool) Get() bool { return atomic.LoadInt32(&a.flag) == AtomicTrue } +func (a *AtomicBool) Swap(val bool) bool { + flag := AtomicFalse + if val { + flag = AtomicTrue + } + return atomic.SwapInt32(&a.flag, flag) == AtomicTrue +} + func (a *AtomicBool) Set(val bool) { flag := AtomicFalse if val { diff --git a/src/noise_test.go b/src/noise_test.go index 48408f9..0d7f0e9 100644 --- a/src/noise_test.go +++ b/src/noise_test.go @@ -117,8 +117,8 @@ func TestNoiseHandshake(t *testing.T) { var err error var out []byte var nonce [12]byte - out = key1.send.aead.Seal(out, nonce[:], testMsg, nil) - out, err = key2.receive.aead.Open(out[:0], nonce[:], out, nil) + out = key1.send.Seal(out, nonce[:], testMsg, nil) + out, err = key2.receive.Open(out[:0], nonce[:], out, nil) assertNil(t, err) assertEqual(t, out, testMsg) }() @@ -128,8 +128,8 @@ func TestNoiseHandshake(t *testing.T) { var err error var out []byte var nonce [12]byte - out = key2.send.aead.Seal(out, nonce[:], testMsg, nil) - out, err = key1.receive.aead.Open(out[:0], nonce[:], out, nil) + out = key2.send.Seal(out, nonce[:], testMsg, nil) + out, err = key1.receive.Open(out[:0], nonce[:], out, nil) assertNil(t, err) assertEqual(t, out, testMsg) }() diff --git a/src/receive.go b/src/receive.go index ff3b7bd..b8b06f7 100644 --- a/src/receive.go +++ b/src/receive.go @@ -311,7 +311,10 @@ func (device *Device) RoutineHandshake() { return } - srcBytes := elem.endpoint.SrcToBytes() + // endpoints destination address is the source of the datagram + + srcBytes := elem.endpoint.DstToBytes() + if device.IsUnderLoad() { // verify MAC2 field @@ -320,8 +323,12 @@ func (device *Device) RoutineHandshake() { // construct cookie reply - logDebug.Println("Sending cookie reply to:", elem.endpoint.SrcToString()) - sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type" + logDebug.Println( + "Sending cookie reply to:", + elem.endpoint.DstToString(), + ) + + sender := binary.LittleEndian.Uint32(elem.packet[4:8]) reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes) if err != nil { logError.Println("Failed to create cookie reply:", err) @@ -555,8 +562,10 @@ func (peer *Peer) RoutineSequentialReceiver() { src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] if device.routingTable.LookupIPv4(src) != peer { - logInfo.Println(src) - logInfo.Println("Packet with unallowed source IPv4 from", peer.String()) + logInfo.Println( + "IPv4 packet with unallowed source address from", + peer.String(), + ) continue } @@ -581,8 +590,10 @@ func (peer *Peer) RoutineSequentialReceiver() { src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] if device.routingTable.LookupIPv6(src) != peer { - logInfo.Println(src) - logInfo.Println("Packet with unallowed source IPv6 from", peer.String()) + logInfo.Println( + "IPv6 packet with unallowed source address from", + peer.String(), + ) continue } @@ -591,7 +602,7 @@ func (peer *Peer) RoutineSequentialReceiver() { continue } - // write to tun + // write to tun device atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) _, err := device.tun.device.Write(elem.packet) diff --git a/src/tests/netns.sh b/src/tests/netns.sh index b5c2f9c..22abea8 100755 --- a/src/tests/netns.sh +++ b/src/tests/netns.sh @@ -20,6 +20,14 @@ # wireguard peers in $ns1 and $ns2. Note that $ns0 is the endpoint for the wg1 # interfaces in $ns1 and $ns2. See https://www.wireguard.com/netns/ for further # details on how this is accomplished. + +# This code is ported to the WireGuard-Go directly from the kernel project. +# +# Please ensure that you have installed the newest version of the WireGuard +# tools from the WireGuard project and before running these tests as: +# +# ./netns.sh + set -e exec 3>&1 @@ -27,7 +35,7 @@ export WG_HIDE_KEYS=never netns0="wg-test-$$-0" netns1="wg-test-$$-1" netns2="wg-test-$$-2" -program="../wireguard-go" +program=$1 export LOG_LEVEL="info" pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; } @@ -349,4 +357,68 @@ ip1 link del veth1 ip1 link del wg1 ip2 link del wg2 -echo "done" +# Test that Netlink/IPC is working properly by doing things that usually cause split responses + +n0 $program wg0 +sleep 5 +config=( "[Interface]" "PrivateKey=$(wg genkey)" "[Peer]" "PublicKey=$(wg genkey)" ) +for a in {1..255}; do + for b in {0..255}; do + config+=( "AllowedIPs=$a.$b.0.0/16,$a::$b/128" ) + done +done +n0 wg setconf wg0 <(printf '%s\n' "${config[@]}") +i=0 +for ip in $(n0 wg show wg0 allowed-ips); do + ((++i)) +done +((i == 255*256*2+1)) +ip0 link del wg0 + +n0 $program wg0 +config=( "[Interface]" "PrivateKey=$(wg genkey)" ) +for a in {1..40}; do + config+=( "[Peer]" "PublicKey=$(wg genkey)" ) + for b in {1..52}; do + config+=( "AllowedIPs=$a.$b.0.0/16" ) + done +done +n0 wg setconf wg0 <(printf '%s\n' "${config[@]}") +i=0 +while read -r line; do + j=0 + for ip in $line; do + ((++j)) + done + ((j == 53)) + ((++i)) +done < <(n0 wg show wg0 allowed-ips) +((i == 40)) +ip0 link del wg0 + +n0 $program wg0 +config=( ) +for i in {1..29}; do + config+=( "[Peer]" "PublicKey=$(wg genkey)" ) +done +config+=( "[Peer]" "PublicKey=$(wg genkey)" "AllowedIPs=255.2.3.4/32,abcd::255/128" ) +n0 wg setconf wg0 <(printf '%s\n' "${config[@]}") +n0 wg showconf wg0 > /dev/null +ip0 link del wg0 + +! n0 wg show doesnotexist || false + +declare -A objects +while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do + [[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ [0-9]+)\ .*(created|destroyed).* ]] || continue + objects["${BASH_REMATCH[1]}"]+="${BASH_REMATCH[2]}" +done < /dev/kmsg +alldeleted=1 +for object in "${!objects[@]}"; do + if [[ ${objects["$object"]} != *createddestroyed ]]; then + echo "Error: $object: merely ${objects["$object"]}" >&3 + alldeleted=0 + fi +done +[[ $alldeleted -eq 1 ]] +pretty "" "Objects that were created were also destroyed." diff --git a/src/tun_linux.go b/src/tun_linux.go index 2a5b276..a728a48 100644 --- a/src/tun_linux.go +++ b/src/tun_linux.go @@ -57,7 +57,6 @@ type NativeTun struct { } func (tun *NativeTun) File() *os.File { - println(tun.fd.Name()) return tun.fd } diff --git a/src/uapi.go b/src/uapi.go index 5e40939..e1d0929 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -145,11 +145,22 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { } case "fwmark": - fwmark, err := strconv.ParseUint(value, 10, 32) + + // parse fwmark field + + fwmark, err := func() (uint32, error) { + if value == "" { + return 0, nil + } + mark, err := strconv.ParseUint(value, 10, 32) + return uint32(mark), err + }() + if err != nil { logError.Println("Invalid fwmark", err) return &IPCError{Code: ipcErrorInvalid} } + device.net.mutex.Lock() device.net.fwmark = uint32(fwmark) device.net.mutex.Unlock() From d10126f883ad39567248540347b5469956ab8b2e Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sat, 18 Nov 2017 23:34:02 +0100 Subject: [PATCH 13/15] Moved endpoint into interface and simplified peer --- src/conn.go | 20 +++++++----- src/conn_linux.go | 83 +++++++++++++++++++++++++++++------------------ src/device.go | 6 ++-- src/peer.go | 19 ++++------- src/receive.go | 29 +++++++---------- src/uapi.go | 24 +++++++++----- 6 files changed, 101 insertions(+), 80 deletions(-) diff --git a/src/conn.go b/src/conn.go index 3cf00ab..74bb075 100644 --- a/src/conn.go +++ b/src/conn.go @@ -7,26 +7,28 @@ import ( "net" ) -type UDPBind interface { +/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic + */ +type Bind interface { SetMark(value uint32) error - ReceiveIPv6(buff []byte, end *Endpoint) (int, error) - ReceiveIPv4(buff []byte, end *Endpoint) (int, error) - Send(buff []byte, end *Endpoint) error + ReceiveIPv6(buff []byte) (int, Endpoint, error) + ReceiveIPv4(buff []byte) (int, Endpoint, error) + Send(buff []byte, end Endpoint) error Close() error } /* An Endpoint maintains the source/destination caching for a peer * - * dst : the remote address of a peer + * dst : the remote address of a peer ("endpoint" in uapi terminology) * src : the local address from which datagrams originate going to the peer - * */ -type UDPEndpoint interface { +type Endpoint interface { ClearSrc() // clears the source address ClearDst() // clears the destination address SrcToString() string // returns the local source address (ip:port) DstToString() string // returns the destination address (ip:port) DstToBytes() []byte // used for mac2 cookie calculations + SetDst(string) error // used for manually setting the endpoint (uapi) DstIP() net.IP SrcIP() net.IP } @@ -107,7 +109,9 @@ func UpdateUDPListener(device *Device) error { for _, peer := range device.peers { peer.mutex.Lock() - peer.endpoint.value.ClearSrc() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } peer.mutex.Unlock() } diff --git a/src/conn_linux.go b/src/conn_linux.go index fb576b1..46f873f 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -21,22 +21,24 @@ import ( * See e.g. https://github.com/golang/go/issues/17930 * So this code is remains platform dependent. */ - -type Endpoint struct { +type NativeEndpoint struct { src unix.RawSockaddrInet6 dst unix.RawSockaddrInet6 } -type IPv4Source struct { - src unix.RawSockaddrInet4 - Ifindex int32 -} - type NativeBind struct { sock4 int sock6 int } +var _ Endpoint = (*NativeEndpoint)(nil) +var _ Bind = NativeBind{} + +type IPv4Source struct { + src unix.RawSockaddrInet4 + Ifindex int32 +} + func htons(val uint16) uint16 { var out [unsafe.Sizeof(val)]byte binary.BigEndian.PutUint16(out[:], val) @@ -48,7 +50,11 @@ func ntohs(val uint16) uint16 { return binary.BigEndian.Uint16((*tmp)[:]) } -func CreateUDPBind(port uint16) (UDPBind, uint16, error) { +func NewEndpoint() Endpoint { + return &NativeEndpoint{} +} + +func CreateUDPBind(port uint16) (Bind, uint16, error) { var err error var bind NativeBind @@ -99,28 +105,33 @@ func (bind NativeBind) Close() error { return err2 } -func (bind NativeBind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) { - return receive6( +func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { + var end NativeEndpoint + n, err := receive6( bind.sock6, buff, - end, + &end, ) + return n, &end, err } -func (bind NativeBind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) { - return receive4( +func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { + var end NativeEndpoint + n, err := receive4( bind.sock4, buff, - end, + &end, ) + return n, &end, err } -func (bind NativeBind) Send(buff []byte, end *Endpoint) error { - switch end.dst.Family { +func (bind NativeBind) Send(buff []byte, end Endpoint) error { + nend := end.(*NativeEndpoint) + switch nend.dst.Family { case unix.AF_INET6: - return send6(bind.sock6, end, buff) + return send6(bind.sock6, nend, buff) case unix.AF_INET: - return send4(bind.sock4, end, buff) + return send4(bind.sock4, nend, buff) default: return errors.New("Unknown address family of destination") } @@ -151,12 +162,12 @@ func sockaddrToString(addr unix.RawSockaddrInet6) string { } } -func (end *Endpoint) DstIP() net.IP { - switch end.dst.Family { +func rawAddrToIP(addr unix.RawSockaddrInet6) net.IP { + switch addr.Family { case unix.AF_INET6: - return end.dst.Addr[:] + return addr.Addr[:] case unix.AF_INET: - ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) + ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr)) return net.IPv4( ptr.Addr[0], ptr.Addr[1], @@ -168,25 +179,33 @@ func (end *Endpoint) DstIP() net.IP { } } -func (end *Endpoint) DstToBytes() []byte { +func (end *NativeEndpoint) SrcIP() net.IP { + return rawAddrToIP(end.src) +} + +func (end *NativeEndpoint) DstIP() net.IP { + return rawAddrToIP(end.dst) +} + +func (end *NativeEndpoint) DstToBytes() []byte { ptr := unsafe.Pointer(&end.src) arr := (*[unix.SizeofSockaddrInet6]byte)(ptr) return arr[:] } -func (end *Endpoint) SrcToString() string { +func (end *NativeEndpoint) SrcToString() string { return sockaddrToString(end.src) } -func (end *Endpoint) DstToString() string { +func (end *NativeEndpoint) DstToString() string { return sockaddrToString(end.dst) } -func (end *Endpoint) ClearDst() { +func (end *NativeEndpoint) ClearDst() { end.dst = unix.RawSockaddrInet6{} } -func (end *Endpoint) ClearSrc() { +func (end *NativeEndpoint) ClearSrc() { end.src = unix.RawSockaddrInet6{} } @@ -306,7 +325,7 @@ func create6(port uint16) (int, uint16, error) { return fd, uint16(addr.Port), err } -func (end *Endpoint) SetDst(s string) error { +func (end *NativeEndpoint) SetDst(s string) error { addr, err := parseEndpoint(s) if err != nil { return err @@ -342,7 +361,7 @@ func (end *Endpoint) SetDst(s string) error { return errors.New("Failed to recognize IP address format") } -func send6(sock int, end *Endpoint, buff []byte) error { +func send6(sock int, end *NativeEndpoint, buff []byte) error { // construct message header @@ -404,7 +423,7 @@ func send6(sock int, end *Endpoint, buff []byte) error { return errno } -func send4(sock int, end *Endpoint, buff []byte) error { +func send4(sock int, end *NativeEndpoint, buff []byte) error { // construct message header @@ -470,7 +489,7 @@ func send4(sock int, end *Endpoint, buff []byte) error { return errno } -func receive4(sock int, buff []byte, end *Endpoint) (int, error) { +func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { // contruct message header @@ -518,7 +537,7 @@ func receive4(sock int, buff []byte, end *Endpoint) (int, error) { return int(size), nil } -func receive6(sock int, buff []byte, end *Endpoint) (int, error) { +func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { // contruct message header diff --git a/src/device.go b/src/device.go index 0085cee..76235bd 100644 --- a/src/device.go +++ b/src/device.go @@ -22,9 +22,9 @@ type Device struct { } net struct { mutex sync.RWMutex - bind UDPBind // bind interface - port uint16 // listening port - fwmark uint32 // mark value (0 = disabled) + bind Bind // bind interface + port uint16 // listening port + fwmark uint32 // mark value (0 = disabled) } mutex sync.RWMutex privateKey NoisePrivateKey diff --git a/src/peer.go b/src/peer.go index a98fc97..f3eb6c2 100644 --- a/src/peer.go +++ b/src/peer.go @@ -15,11 +15,8 @@ type Peer struct { keyPairs KeyPairs handshake Handshake device *Device - endpoint struct { - set bool // has a known endpoint been discovered - value Endpoint // source / destination cache - } - stats struct { + endpoint Endpoint + stats struct { txBytes uint64 // bytes send to peer (endpoint) rxBytes uint64 // bytes received from peer lastHandshakeNano int64 // nano seconds since epoch @@ -110,9 +107,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { // reset endpoint - peer.endpoint.set = false - peer.endpoint.value.ClearDst() - peer.endpoint.value.ClearSrc() + peer.endpoint = nil // prepare queuing @@ -143,16 +138,16 @@ func (peer *Peer) SendBuffer(buffer []byte) error { defer peer.device.net.mutex.RUnlock() peer.mutex.RLock() defer peer.mutex.RUnlock() - if !peer.endpoint.set { + if peer.endpoint == nil { return errors.New("No known endpoint for peer") } - return peer.device.net.bind.Send(buffer, &peer.endpoint.value) + return peer.device.net.bind.Send(buffer, peer.endpoint) } /* Returns a short string identification for logging */ func (peer *Peer) String() string { - if !peer.endpoint.set { + if peer.endpoint == nil { return fmt.Sprintf( "peer(%d unknown %s)", peer.id, @@ -162,7 +157,7 @@ func (peer *Peer) String() string { return fmt.Sprintf( "peer(%d %s %s)", peer.id, - peer.endpoint.value.DstToString(), + peer.endpoint.DstToString(), base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), ) } diff --git a/src/receive.go b/src/receive.go index b8b06f7..27fdb8a 100644 --- a/src/receive.go +++ b/src/receive.go @@ -93,7 +93,7 @@ func (device *Device) addToHandshakeQueue( } } -func (device *Device) RoutineReceiveIncomming(IP int, bind UDPBind) { +func (device *Device) RoutineReceiveIncomming(IP int, bind Bind) { logDebug := device.log.Debug logDebug.Println("Routine, receive incomming, IP version:", IP) @@ -104,20 +104,21 @@ func (device *Device) RoutineReceiveIncomming(IP int, bind UDPBind) { buffer := device.GetMessageBuffer() - var size int - var err error + var ( + err error + size int + endpoint Endpoint + ) for { // read next datagram - var endpoint Endpoint - switch IP { case ipv4.Version: - size, err = bind.ReceiveIPv4(buffer[:], &endpoint) + size, endpoint, err = bind.ReceiveIPv4(buffer[:]) case ipv6.Version: - size, err = bind.ReceiveIPv6(buffer[:], &endpoint) + size, endpoint, err = bind.ReceiveIPv6(buffer[:]) default: return } @@ -339,10 +340,7 @@ func (device *Device) RoutineHandshake() { writer := bytes.NewBuffer(temp[:0]) binary.Write(writer, binary.LittleEndian, reply) - device.net.bind.Send( - writer.Bytes(), - &elem.endpoint, - ) + device.net.bind.Send(writer.Bytes(), elem.endpoint) if err != nil { logDebug.Println("Failed to send cookie reply:", err) } @@ -395,8 +393,7 @@ func (device *Device) RoutineHandshake() { // update endpoint peer.mutex.Lock() - peer.endpoint.set = true - peer.endpoint.value = elem.endpoint + peer.endpoint = elem.endpoint peer.mutex.Unlock() // create response @@ -452,8 +449,7 @@ func (device *Device) RoutineHandshake() { // update endpoint peer.mutex.Lock() - peer.endpoint.set = true - peer.endpoint.value = elem.endpoint + peer.endpoint = elem.endpoint peer.mutex.Unlock() logDebug.Println("Received handshake initation from", peer) @@ -527,8 +523,7 @@ func (peer *Peer) RoutineSequentialReceiver() { // update endpoint peer.mutex.Lock() - peer.endpoint.set = true - peer.endpoint.value = elem.endpoint + peer.endpoint = elem.endpoint peer.mutex.Unlock() // check for keep-alive diff --git a/src/uapi.go b/src/uapi.go index e1d0929..670ecc4 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -53,8 +53,8 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { defer peer.mutex.RUnlock() send("public_key=" + peer.handshake.remoteStatic.ToHex()) send("preshared_key=" + peer.handshake.presharedKey.ToHex()) - if peer.endpoint.set { - send("endpoint=" + peer.endpoint.value.DstToString()) + if peer.endpoint != nil { + send("endpoint=" + peer.endpoint.DstToString()) } nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano) @@ -255,17 +255,25 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { case "endpoint": - // set endpoint destination and reset handshake timer + // set endpoint destination + + err := func() error { + peer.mutex.Lock() + defer peer.mutex.Unlock() + + endpoint := NewEndpoint() + if err := endpoint.SetDst(value); err != nil { + return err + } + peer.endpoint = endpoint + signalSend(peer.signal.handshakeReset) + return nil + }() - peer.mutex.Lock() - err := peer.endpoint.value.SetDst(value) - peer.endpoint.set = (err == nil) - peer.mutex.Unlock() if err != nil { logError.Println("Failed to set endpoint:", value) return &IPCError{Code: ipcErrorInvalid} } - signalSend(peer.signal.handshakeReset) case "persistent_keepalive_interval": From a79fdc13a2d7be07b20ea499da9210ebe69f1958 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sun, 19 Nov 2017 00:21:58 +0100 Subject: [PATCH 14/15] Begin generic Bind implementation --- src/conn.go | 4 +-- src/conn_default.go | 69 +++++++++++++++++++++++++++++++++++------ src/conn_linux.go | 75 ++++++++++++++++++++++----------------------- src/uapi.go | 5 ++- 4 files changed, 99 insertions(+), 54 deletions(-) diff --git a/src/conn.go b/src/conn.go index 74bb075..5b40a23 100644 --- a/src/conn.go +++ b/src/conn.go @@ -24,11 +24,9 @@ type Bind interface { */ type Endpoint interface { ClearSrc() // clears the source address - ClearDst() // clears the destination address SrcToString() string // returns the local source address (ip:port) DstToString() string // returns the destination address (ip:port) DstToBytes() []byte // used for mac2 cookie calculations - SetDst(string) error // used for manually setting the endpoint (uapi) DstIP() net.IP SrcIP() net.IP } @@ -92,7 +90,7 @@ func UpdateUDPListener(device *Device) error { // bind to new port var err error - netc.bind, netc.port, err = CreateUDPBind(netc.port) + netc.bind, netc.port, err = CreateBind(netc.port) if err != nil { netc.bind = nil return err diff --git a/src/conn_default.go b/src/conn_default.go index 31cab5c..34168c6 100644 --- a/src/conn_default.go +++ b/src/conn_default.go @@ -13,11 +13,68 @@ import ( * See conn_linux.go for an implementation on the linux platform. */ -type Endpoint *net.UDPAddr +type NativeBind struct { + ipv4 *net.UDPConn + ipv6 *net.UDPConn +} -type NativeBind *net.UDPConn +type NativeEndpoint net.UDPAddr -func CreateUDPBind(port uint16) (UDPBind, uint16, error) { +var _ Bind = (*NativeBind)(nil) +var _ Endpoint = (*NativeEndpoint)(nil) + +func CreateEndpoint(s string) (Endpoint, error) { + addr, err := parseEndpoint(s) + return (addr).(*NativeEndpoint), err +} + +func (_ *NativeEndpoint) ClearSrc() {} + +func (e *NativeEndpoint) DstIP() net.IP { + return (*net.UDPAddr)(e).IP +} + +func (e *NativeEndpoint) SrcIP() net.IP { + return nil // not supported +} + +func (e *NativeEndpoint) DstToBytes() []byte { + addr := (*net.UDPAddr)(e) + out := addr.IP.([]byte) + out = append(out, byte(addr.Port&0xff)) + out = append(out, byte((addr.Port>>8)&0xff)) + return out +} + +func (e *NativeEndpoint) DstToString() string { + return (*net.UDPAddr)(e).String() +} + +func (e *NativeEndpoint) SrcToString() string { + return "" +} + +func listenNet(net string, port int) (*net.UDPConn, int, error) { + + // listen + + conn, err := net.ListenUDP("udp", &UDPAddr{Port: port}) + if err != nil { + return nil, 0, err + } + + // retrieve port + + laddr := conn.LocalAddr() + uaddr, _ = net.ResolveUDPAddr( + laddr.Network(), + laddr.String(), + ) + + return conn, uaddr.Port, nil +} + +func CreateBind(port uint16) (Bind, uint16, error) { // listen @@ -38,9 +95,3 @@ func CreateUDPBind(port uint16) (UDPBind, uint16, error) { ) return uaddr.Port } - -func (_ Endpoint) ClearSrc() {} - -func SetMark(conn *net.UDPConn, value uint32) error { - return nil -} diff --git a/src/conn_linux.go b/src/conn_linux.go index 46f873f..cdba74f 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -50,11 +50,44 @@ func ntohs(val uint16) uint16 { return binary.BigEndian.Uint16((*tmp)[:]) } -func NewEndpoint() Endpoint { - return &NativeEndpoint{} +func CreateEndpoint(s string) (Endpoint, error) { + var end NativeEndpoint + addr, err := parseEndpoint(s) + if err != nil { + return nil, err + } + + ipv4 := addr.IP.To4() + if ipv4 != nil { + dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) + dst.Family = unix.AF_INET + dst.Port = htons(uint16(addr.Port)) + dst.Zero = [8]byte{} + copy(dst.Addr[:], ipv4) + end.ClearSrc() + return &end, nil + } + + ipv6 := addr.IP.To16() + if ipv6 != nil { + zone, err := zoneToUint32(addr.Zone) + if err != nil { + return nil, err + } + dst := &end.dst + dst.Family = unix.AF_INET6 + dst.Port = htons(uint16(addr.Port)) + dst.Flowinfo = 0 + dst.Scope_id = zone + copy(dst.Addr[:], ipv6[:]) + end.ClearSrc() + return &end, nil + } + + return nil, errors.New("Failed to recognize IP address format") } -func CreateUDPBind(port uint16) (Bind, uint16, error) { +func CreateBind(port uint16) (Bind, uint16, error) { var err error var bind NativeBind @@ -325,42 +358,6 @@ func create6(port uint16) (int, uint16, error) { return fd, uint16(addr.Port), err } -func (end *NativeEndpoint) SetDst(s string) error { - addr, err := parseEndpoint(s) - if err != nil { - return err - } - - ipv4 := addr.IP.To4() - if ipv4 != nil { - dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) - dst.Family = unix.AF_INET - dst.Port = htons(uint16(addr.Port)) - dst.Zero = [8]byte{} - copy(dst.Addr[:], ipv4) - end.ClearSrc() - return nil - } - - ipv6 := addr.IP.To16() - if ipv6 != nil { - zone, err := zoneToUint32(addr.Zone) - if err != nil { - return err - } - dst := &end.dst - dst.Family = unix.AF_INET6 - dst.Port = htons(uint16(addr.Port)) - dst.Flowinfo = 0 - dst.Scope_id = zone - copy(dst.Addr[:], ipv6[:]) - end.ClearSrc() - return nil - } - - return errors.New("Failed to recognize IP address format") -} - func send6(sock int, end *NativeEndpoint, buff []byte) error { // construct message header diff --git a/src/uapi.go b/src/uapi.go index 670ecc4..dc8be66 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -260,9 +260,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { err := func() error { peer.mutex.Lock() defer peer.mutex.Unlock() - - endpoint := NewEndpoint() - if err := endpoint.SetDst(value); err != nil { + endpoint, err := CreateEndpoint(value) + if err != nil { return err } peer.endpoint = endpoint From 9ebab57c417d4fd19db6cf69f920a3adb1a1e092 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sun, 19 Nov 2017 13:14:15 +0100 Subject: [PATCH 15/15] Implemented missing methods for Bind and Endpoint --- src/conn_default.go | 72 +++++++++++++++++++++++++++++++++------------ 1 file changed, 53 insertions(+), 19 deletions(-) diff --git a/src/conn_default.go b/src/conn_default.go index 34168c6..5b73c90 100644 --- a/src/conn_default.go +++ b/src/conn_default.go @@ -25,7 +25,7 @@ var _ Endpoint = (*NativeEndpoint)(nil) func CreateEndpoint(s string) (Endpoint, error) { addr, err := parseEndpoint(s) - return (addr).(*NativeEndpoint), err + return (*NativeEndpoint)(addr), err } func (_ *NativeEndpoint) ClearSrc() {} @@ -40,7 +40,7 @@ func (e *NativeEndpoint) SrcIP() net.IP { func (e *NativeEndpoint) DstToBytes() []byte { addr := (*net.UDPAddr)(e) - out := addr.IP.([]byte) + out := addr.IP out = append(out, byte(addr.Port&0xff)) out = append(out, byte((addr.Port>>8)&0xff)) return out @@ -54,11 +54,11 @@ func (e *NativeEndpoint) SrcToString() string { return "" } -func listenNet(net string, port int) (*net.UDPConn, int, error) { +func listenNet(network string, port int) (*net.UDPConn, int, error) { // listen - conn, err := net.ListenUDP("udp", &UDPAddr{Port: port}) + conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) if err != nil { return nil, 0, err } @@ -66,32 +66,66 @@ func listenNet(net string, port int) (*net.UDPConn, int, error) { // retrieve port laddr := conn.LocalAddr() - uaddr, _ = net.ResolveUDPAddr( + uaddr, err := net.ResolveUDPAddr( laddr.Network(), laddr.String(), ) - + if err != nil { + return nil, 0, err + } return conn, uaddr.Port, nil } -func CreateBind(port uint16) (Bind, uint16, error) { +func CreateBind(uport uint16) (Bind, uint16, error) { + var err error + var bind NativeBind - // listen + port := int(uport) - addr := UDPAddr{ - Port: int(port), - } - conn, err := net.ListenUDP("udp", &addr) + bind.ipv4, port, err = listenNet("udp4", port) if err != nil { return nil, 0, err } - // retrieve port + bind.ipv6, port, err = listenNet("udp6", port) + if err != nil { + bind.ipv4.Close() + return nil, 0, err + } - laddr := conn.LocalAddr() - uaddr, _ = net.ResolveUDPAddr( - laddr.Network(), - laddr.String(), - ) - return uaddr.Port + return &bind, uint16(port), nil +} + +func (bind *NativeBind) Close() error { + err1 := bind.ipv4.Close() + err2 := bind.ipv6.Close() + if err1 != nil { + return err1 + } + return err2 +} + +func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { + n, endpoint, err := bind.ipv4.ReadFromUDP(buff) + return n, (*NativeEndpoint)(endpoint), err +} + +func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { + n, endpoint, err := bind.ipv6.ReadFromUDP(buff) + return n, (*NativeEndpoint)(endpoint), err +} + +func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error { + var err error + nend := endpoint.(*NativeEndpoint) + if nend.IP.To16() != nil { + _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend)) + } else { + _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend)) + } + return err +} + +func (bind *NativeBind) SetMark(_ uint32) error { + return nil }