diff --git a/conn/bind_linux.go b/conn/bind_linux.go index 70ea609..9eec384 100644 --- a/conn/bind_linux.go +++ b/conn/bind_linux.go @@ -55,10 +55,11 @@ func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 { // LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux. type LinuxSocketBind struct { - sock4 int - sock6 int - lastMark uint32 - closing sync.RWMutex + // mu guards sock4 and sock6 and the associated fds. + // As long as someone holds mu (read or write), the associated fds are valid. + mu sync.RWMutex + sock4 int + sock6 int } func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} } @@ -102,54 +103,67 @@ func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) { return nil, errors.New("invalid IP address") } -func (bind *LinuxSocketBind) Open(port uint16) (uint16, error) { +func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) { + bind.mu.Lock() + defer bind.mu.Unlock() + var err error var newPort uint16 var tries int if bind.sock4 != -1 || bind.sock6 != -1 { - return 0, ErrBindAlreadyOpen + return nil, 0, ErrBindAlreadyOpen } originalPort := port again: port = originalPort + var sock4, sock6 int // Attempt ipv6 bind, update port if successful. - bind.sock6, newPort, err = create6(port) + sock6, newPort, err = create6(port) if err != nil { - if err != syscall.EAFNOSUPPORT { - return 0, err + if !errors.Is(err, syscall.EAFNOSUPPORT) { + return nil, 0, err } } else { port = newPort } // Attempt ipv4 bind, update port if successful. - bind.sock4, newPort, err = create4(port) + sock4, newPort, err = create4(port) if err != nil { - if originalPort == 0 && err == syscall.EADDRINUSE && tries < 100 { - unix.Close(bind.sock6) + if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { + unix.Close(sock6) tries++ goto again } - if err != syscall.EAFNOSUPPORT { - unix.Close(bind.sock6) - return 0, err + if !errors.Is(err, syscall.EAFNOSUPPORT) { + unix.Close(sock6) + return nil, 0, err } } else { port = newPort } - if bind.sock4 == -1 && bind.sock6 == -1 { - return 0, syscall.EAFNOSUPPORT + var fns []ReceiveFunc + if sock4 != -1 { + fns = append(fns, makeReceiveIPv4(sock4)) + bind.sock4 = sock4 } - return port, nil + if sock6 != -1 { + fns = append(fns, makeReceiveIPv6(sock6)) + bind.sock6 = sock6 + } + if len(fns) == 0 { + return nil, 0, syscall.EAFNOSUPPORT + } + return fns, port, nil } func (bind *LinuxSocketBind) SetMark(value uint32) error { - bind.closing.RLock() - defer bind.closing.RUnlock() + bind.mu.RLock() + defer bind.mu.RUnlock() if bind.sock6 != -1 { err := unix.SetsockoptInt( @@ -177,21 +191,24 @@ func (bind *LinuxSocketBind) SetMark(value uint32) error { } } - bind.lastMark = value return nil } func (bind *LinuxSocketBind) Close() error { - var err1, err2 error - bind.closing.RLock() + // Take a readlock to shut down the sockets... + bind.mu.RLock() if bind.sock6 != -1 { unix.Shutdown(bind.sock6, unix.SHUT_RDWR) } if bind.sock4 != -1 { unix.Shutdown(bind.sock4, unix.SHUT_RDWR) } - bind.closing.RUnlock() - bind.closing.Lock() + bind.mu.RUnlock() + // ...and a write lock to close the fd. + // This ensures that no one else is using the fd. + bind.mu.Lock() + defer bind.mu.Unlock() + var err1, err2 error if bind.sock6 != -1 { err1 = unix.Close(bind.sock6) bind.sock6 = -1 @@ -200,7 +217,6 @@ func (bind *LinuxSocketBind) Close() error { err2 = unix.Close(bind.sock4) bind.sock4 = -1 } - bind.closing.Unlock() if err1 != nil { return err1 @@ -208,46 +224,29 @@ func (bind *LinuxSocketBind) Close() error { return err2 } -func (bind *LinuxSocketBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { - bind.closing.RLock() - defer bind.closing.RUnlock() - - var end LinuxSocketEndpoint - if bind.sock6 == -1 { - return 0, nil, net.ErrClosed +func makeReceiveIPv6(sock int) ReceiveFunc { + return func(buff []byte) (int, Endpoint, error) { + var end LinuxSocketEndpoint + n, err := receive6(sock, buff, &end) + return n, &end, err } - n, err := receive6( - bind.sock6, - buff, - &end, - ) - return n, &end, err } -func (bind *LinuxSocketBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { - bind.closing.RLock() - defer bind.closing.RUnlock() - - var end LinuxSocketEndpoint - if bind.sock4 == -1 { - return 0, nil, net.ErrClosed +func makeReceiveIPv4(sock int) ReceiveFunc { + return func(buff []byte) (int, Endpoint, error) { + var end LinuxSocketEndpoint + n, err := receive4(sock, buff, &end) + return n, &end, err } - n, err := receive4( - bind.sock4, - buff, - &end, - ) - return n, &end, err } func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error { - bind.closing.RLock() - defer bind.closing.RUnlock() - nend, ok := end.(*LinuxSocketEndpoint) if !ok { return ErrWrongEndpointType } + bind.mu.RLock() + defer bind.mu.RUnlock() if !nend.isV6 { if bind.sock4 == -1 { return net.ErrClosed diff --git a/conn/bind_std.go b/conn/bind_std.go index f8b8a1b..5261779 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -8,6 +8,7 @@ package conn import ( "errors" "net" + "sync" "syscall" ) @@ -16,6 +17,7 @@ import ( // It uses the Go's net package to implement networking. // See LinuxSocketBind for a proper implementation on the Linux platform. type StdNetBind struct { + mu sync.Mutex // protects following fields ipv4 *net.UDPConn ipv6 *net.UDPConn blackhole4 bool @@ -81,12 +83,15 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) { return conn, uaddr.Port, nil } -func (bind *StdNetBind) Open(uport uint16) (uint16, error) { +func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { + bind.mu.Lock() + defer bind.mu.Unlock() + var err error var tries int if bind.ipv4 != nil || bind.ipv6 != nil { - return 0, ErrBindAlreadyOpen + return nil, 0, ErrBindAlreadyOpen } // Attempt to open ipv4 and ipv6 listeners on the same port. @@ -97,7 +102,7 @@ again: ipv4, port, err = listenNet("udp4", port) if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { - return 0, err + return nil, 0, err } // Listen on the same port as we're using for ipv4. @@ -109,17 +114,27 @@ again: } if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { ipv4.Close() - return 0, err + return nil, 0, err } - if ipv4 == nil && ipv6 == nil { - return 0, syscall.EAFNOSUPPORT + var fns []ReceiveFunc + if ipv4 != nil { + fns = append(fns, makeReceiveFunc(ipv4, true)) + bind.ipv4 = ipv4 } - bind.ipv4 = ipv4 - bind.ipv6 = ipv6 - return uint16(port), nil + if ipv6 != nil { + fns = append(fns, makeReceiveFunc(ipv6, false)) + bind.ipv6 = ipv6 + } + if len(fns) == 0 { + return nil, 0, syscall.EAFNOSUPPORT + } + return fns, uint16(port), nil } func (bind *StdNetBind) Close() error { + bind.mu.Lock() + defer bind.mu.Unlock() + var err1, err2 error if bind.ipv4 != nil { err1 = bind.ipv4.Close() @@ -137,23 +152,14 @@ func (bind *StdNetBind) Close() error { return err2 } -func (bind *StdNetBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { - if bind.ipv4 == nil { - return 0, nil, syscall.EAFNOSUPPORT +func makeReceiveFunc(conn *net.UDPConn, isIPv4 bool) ReceiveFunc { + return func(buff []byte) (int, Endpoint, error) { + n, endpoint, err := conn.ReadFromUDP(buff) + if isIPv4 && endpoint != nil { + endpoint.IP = endpoint.IP.To4() + } + return n, (*StdNetEndpoint)(endpoint), err } - n, endpoint, err := bind.ipv4.ReadFromUDP(buff) - if endpoint != nil { - endpoint.IP = endpoint.IP.To4() - } - return n, (*StdNetEndpoint)(endpoint), err -} - -func (bind *StdNetBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { - if bind.ipv6 == nil { - return 0, nil, syscall.EAFNOSUPPORT - } - n, endpoint, err := bind.ipv6.ReadFromUDP(buff) - return n, (*StdNetEndpoint)(endpoint), err } func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { @@ -162,15 +168,16 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { if !ok { return ErrWrongEndpointType } - var conn *net.UDPConn - var blackhole bool - if nend.IP.To4() != nil { - blackhole = bind.blackhole4 - conn = bind.ipv4 - } else { + + bind.mu.Lock() + blackhole := bind.blackhole4 + conn := bind.ipv4 + if nend.IP.To4() == nil { blackhole = bind.blackhole6 conn = bind.ipv6 } + bind.mu.Unlock() + if blackhole { return nil } diff --git a/conn/bind_windows.go b/conn/bind_windows.go index 1e2712e..6cabee1 100644 --- a/conn/bind_windows.go +++ b/conn/bind_windows.go @@ -266,7 +266,7 @@ func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sock return sa, nil } -func (bind *WinRingBind) Open(port uint16) (selectedPort uint16, err error) { +func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) { bind.mu.Lock() defer bind.mu.Unlock() defer func() { @@ -275,30 +275,30 @@ func (bind *WinRingBind) Open(port uint16) (selectedPort uint16, err error) { } }() if atomic.LoadUint32(&bind.isOpen) != 0 { - return 0, ErrBindAlreadyOpen + return nil, 0, ErrBindAlreadyOpen } var sa windows.Sockaddr sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)}) if err != nil { - return 0, err + return nil, 0, err } sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port}) if err != nil { - return 0, err + return nil, 0, err } selectedPort = uint16(sa.(*windows.SockaddrInet6).Port) for i := 0; i < packetsPerRing; i++ { err = bind.v4.InsertReceiveRequest() if err != nil { - return 0, err + return nil, 0, err } err = bind.v6.InsertReceiveRequest() if err != nil { - return 0, err + return nil, 0, err } } atomic.StoreUint32(&bind.isOpen, 1) - return + return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err } func (bind *WinRingBind) Close() error { @@ -395,13 +395,13 @@ func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, e return n, &ep, nil } -func (bind *WinRingBind) ReceiveIPv4(buf []byte) (int, Endpoint, error) { +func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) { bind.mu.RLock() defer bind.mu.RUnlock() return bind.v4.Receive(buf, &bind.isOpen) } -func (bind *WinRingBind) ReceiveIPv6(buf []byte) (int, Endpoint, error) { +func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) { bind.mu.RLock() defer bind.mu.RUnlock() return bind.v6.Receive(buf, &bind.isOpen) @@ -482,6 +482,8 @@ func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error { } func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { + bind.mu.Lock() + defer bind.mu.Unlock() sysconn, err := bind.ipv4.SyscallConn() if err != nil { return err @@ -500,6 +502,8 @@ func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole } func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { + bind.mu.Lock() + defer bind.mu.Unlock() sysconn, err := bind.ipv6.SyscallConn() if err != nil { return err diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go index ad8fa05..7d43fb3 100644 --- a/conn/bindtest/bindtest.go +++ b/conn/bindtest/bindtest.go @@ -65,12 +65,14 @@ func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) } func (c ChannelEndpoint) SrcIP() net.IP { return nil } -func (c *ChannelBind) Open(port uint16) (actualPort uint16, err error) { +func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { c.closeSignal = make(chan bool) + fns = append(fns, c.makeReceiveFunc(*c.rx4)) + fns = append(fns, c.makeReceiveFunc(*c.rx6)) if rand.Uint32()&1 == 0 { - return uint16(c.source4), nil + return fns, uint16(c.source4), nil } else { - return uint16(c.source6), nil + return fns, uint16(c.source6), nil } } @@ -87,21 +89,14 @@ func (c *ChannelBind) Close() error { func (c *ChannelBind) SetMark(mark uint32) error { return nil } -func (c *ChannelBind) ReceiveIPv6(b []byte) (n int, ep conn.Endpoint, err error) { - select { - case <-c.closeSignal: - return 0, nil, net.ErrClosed - case rx := <-*c.rx6: - return copy(b, rx), c.target6, nil - } -} - -func (c *ChannelBind) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) { - select { - case <-c.closeSignal: - return 0, nil, net.ErrClosed - case rx := <-*c.rx4: - return copy(b, rx), c.target4, nil +func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { + return func(b []byte) (n int, ep conn.Endpoint, err error) { + select { + case <-c.closeSignal: + return 0, nil, net.ErrClosed + case rx := <-ch: + return copy(b, rx), c.target6, nil + } } } diff --git a/conn/conn.go b/conn/conn.go index 6fd232f..3c7fcd0 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -12,6 +12,11 @@ import ( "strings" ) +// A ReceiveFunc receives a single inbound packet from the network. +// It writes the data into b. n is the length of the packet. +// ep is the remote endpoint. +type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error) + // A Bind listens on a port for both IPv6 and IPv4 UDP traffic. // // A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface, @@ -19,23 +24,17 @@ import ( type Bind interface { // Open puts the Bind into a listening state on a given port and reports the actual // port that it bound to. Passing zero results in a random selection. - Open(port uint16) (actualPort uint16, err error) + // fns is the set of functions that will be called to receive packets. + Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error) // Close closes the Bind listener. + // All fns returned by Open must return net.ErrClosed after a call to Close. Close() error // SetMark sets the mark for each packet sent through this Bind. // This mark is passed to the kernel as the socket option SO_MARK. SetMark(mark uint32) error - // ReceiveIPv6 reads an IPv6 UDP packet into b. It reports the number of bytes read, - // n, the packet source address ep, and any error. - ReceiveIPv6(b []byte) (n int, ep Endpoint, err error) - - // ReceiveIPv4 reads an IPv4 UDP packet into b. It reports the number of bytes read, - // n, the packet source address ep, and any error. - ReceiveIPv4(b []byte) (n int, ep Endpoint, err error) - // Send writes a packet b to address ep. Send(b []byte, ep Endpoint) error diff --git a/device/device.go b/device/device.go index 1e32db6..a635e68 100644 --- a/device/device.go +++ b/device/device.go @@ -11,9 +11,6 @@ import ( "sync/atomic" "time" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/ratelimiter" "golang.zx2c4.com/wireguard/rwcancel" @@ -468,8 +465,9 @@ func (device *Device) BindUpdate() error { // bind to new port var err error + var recvFns []conn.ReceiveFunc netc := &device.net - netc.port, err = netc.bind.Open(netc.port) + recvFns, netc.port, err = netc.bind.Open(netc.port) if err != nil { netc.port = 0 return err @@ -501,11 +499,12 @@ func (device *Device) BindUpdate() error { device.peers.RUnlock() // start receiving routines - device.net.stopping.Add(2) - device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption - device.queue.handshake.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake - go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) - go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) + device.net.stopping.Add(len(recvFns)) + device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption + device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake + for _, fn := range recvFns { + go device.RoutineReceiveIncoming(fn) + } device.log.Verbosef("UDP bind has been updated") return nil diff --git a/device/receive.go b/device/receive.go index 5ddb66c..fa5c0a6 100644 --- a/device/receive.go +++ b/device/receive.go @@ -68,15 +68,15 @@ func (peer *Peer) keepKeyFreshReceiving() { * Every time the bind is updated a new routine is started for * IPv4 and IPv6 (separately) */ -func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) { +func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) { defer func() { - device.log.Verbosef("Routine: receive incoming IPv%d - stopped", IP) + device.log.Verbosef("Routine: receive incoming %p - stopped", recv) device.queue.decryption.wg.Done() device.queue.handshake.wg.Done() device.net.stopping.Done() }() - device.log.Verbosef("Routine: receive incoming IPv%d - started", IP) + device.log.Verbosef("Routine: receive incoming %p - started", recv) // receive datagrams until conn is closed @@ -90,14 +90,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) { ) for { - switch IP { - case ipv4.Version: - size, endpoint, err = bind.ReceiveIPv4(buffer[:]) - case ipv6.Version: - size, endpoint, err = bind.ReceiveIPv6(buffer[:]) - default: - panic("invalid IP version") - } + size, endpoint, err = recv(buffer[:]) if err != nil { device.PutMessageBuffer(buffer)