diff --git a/src/conn.go b/src/conn.go index b2caffb..aa0b72b 100644 --- a/src/conn.go +++ b/src/conn.go @@ -34,6 +34,21 @@ func parseEndpoint(s string) (*net.UDPAddr, error) { return addr, err } +/* Must hold device and net lock + */ +func unsafeCloseUDPListener(device *Device) error { + netc := &device.net + if netc.bind != nil { + if err := netc.bind.Close(); err != nil { + return err + } + netc.bind = nil + netc.update.Broadcast() + } + return nil +} + +// must inform all listeners func UpdateUDPListener(device *Device) error { device.mutex.Lock() defer device.mutex.Unlock() @@ -44,26 +59,22 @@ func UpdateUDPListener(device *Device) error { // close existing sockets - if netc.bind != nil { - println("close bind") - if err := netc.bind.Close(); err != nil { - return err - } - netc.bind = nil - println("closed") + if err := unsafeCloseUDPListener(device); err != nil { + return err } + // wait for reader + // open new sockets if device.tun.isUp.Get() { - println("creat") - // bind to new port var err error netc.bind, netc.port, err = CreateUDPBind(netc.port) if err != nil { + netc.bind = nil return err } @@ -74,8 +85,6 @@ func UpdateUDPListener(device *Device) error { return err } - println("okay") - // clear cached source addresses for _, peer := range device.peers { @@ -83,14 +92,20 @@ func UpdateUDPListener(device *Device) error { peer.endpoint.value.ClearSrc() peer.mutex.Unlock() } + + // inform readers of updated bind + + netc.update.Broadcast() } return nil } func CloseUDPListener(device *Device) error { - netc := &device.net - netc.mutex.Lock() - defer netc.mutex.Unlock() - return netc.bind.Close() + device.mutex.Lock() + device.net.mutex.Lock() + err := unsafeCloseUDPListener(device) + device.net.mutex.Unlock() + device.mutex.Unlock() + return err } diff --git a/src/conn_linux.go b/src/conn_linux.go index 8cda460..05f9347 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -7,8 +7,8 @@ package main import ( + "encoding/binary" "errors" - "fmt" "golang.org/x/sys/unix" "net" "strconv" @@ -37,6 +37,17 @@ type NativeBind struct { sock6 int } +func htons(val uint16) uint16 { + var out [unsafe.Sizeof(val)]byte + binary.BigEndian.PutUint16(out[:], val) + return *((*uint16)(unsafe.Pointer(&out[0]))) +} + +func ntohs(val uint16) uint16 { + tmp := ((*[unsafe.Sizeof(val)]byte)(unsafe.Pointer(&val))) + return binary.BigEndian.Uint16((*tmp)[:]) +} + func CreateUDPBind(port uint16) (UDPBind, uint16, error) { var err error var bind NativeBind @@ -50,8 +61,6 @@ func CreateUDPBind(port uint16) (UDPBind, uint16, error) { if err != nil { unix.Close(bind.sock6) } - println(bind.sock6) - println(bind.sock4) return bind, port, err } @@ -297,13 +306,11 @@ func (end *Endpoint) SetDst(s string) error { return err } - fmt.Println(addr, err) - ipv4 := addr.IP.To4() if ipv4 != nil { dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) dst.Family = unix.AF_INET - dst.Port = uint16(addr.Port) + dst.Port = htons(uint16(addr.Port)) dst.Zero = [8]byte{} copy(dst.Addr[:], ipv4) end.ClearSrc() @@ -318,7 +325,7 @@ func (end *Endpoint) SetDst(s string) error { } dst := &end.dst dst.Family = unix.AF_INET6 - dst.Port = uint16(addr.Port) + dst.Port = htons(uint16(addr.Port)) dst.Flowinfo = 0 dst.Scope_id = zone copy(dst.Addr[:], ipv6[:]) @@ -392,9 +399,6 @@ func send6(sock int, end *Endpoint, buff []byte) error { } func send4(sock int, end *Endpoint, buff []byte) error { - println("send 4") - println(end.DstToString()) - println(sock) // construct message header @@ -425,6 +429,7 @@ func send4(sock int, end *Endpoint, buff []byte) error { Name: (*byte)(unsafe.Pointer(&end.dst)), Namelen: unix.SizeofSockaddrInet4, Control: (*byte)(unsafe.Pointer(&cmsg)), + Flags: 0, } msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) @@ -437,10 +442,6 @@ func send4(sock int, end *Endpoint, buff []byte) error { 0, ) - if errno == 0 { - return nil - } - // clear source and try again if errno == unix.EINVAL { @@ -454,6 +455,12 @@ func send4(sock int, end *Endpoint, buff []byte) error { ) } + // errno = 0 is still an error instance + + if errno == 0 { + return nil + } + return errno } diff --git a/src/device.go b/src/device.go index 1aae448..a348c68 100644 --- a/src/device.go +++ b/src/device.go @@ -23,9 +23,10 @@ type Device struct { } net struct { mutex sync.RWMutex - bind UDPBind - port uint16 - fwmark uint32 + bind UDPBind // bind interface + port uint16 // listening port + fwmark uint32 // mark value (0 = disabled) + update *sync.Cond // the bind was updated } mutex sync.RWMutex privateKey NoisePrivateKey @@ -38,8 +39,7 @@ type Device struct { handshake chan QueueHandshakeElement } signal struct { - stop chan struct{} - updateBind chan struct{} + stop chan struct{} } underLoadUntil atomic.Value ratelimiter Ratelimiter @@ -163,6 +163,12 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { device.signal.stop = make(chan struct{}) + // prepare net + + device.net.port = 0 + device.net.bind = nil + device.net.update = sync.NewCond(&device.net.mutex) + // start workers for i := 0; i < runtime.NumCPU(); i += 1 { diff --git a/src/receive.go b/src/receive.go index 1f05b2f..cb53f80 100644 --- a/src/receive.go +++ b/src/receive.go @@ -99,135 +99,126 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) { for { - // wait for new conn + // wait for bind - logDebug.Println("Waiting for udp socket") + logDebug.Println("Waiting for udp bind") + device.net.mutex.Lock() + device.net.update.Wait() + bind := device.net.bind + device.net.mutex.Unlock() + if bind == nil { + continue + } - select { - case <-device.signal.stop: - return + logDebug.Println("LISTEN\n\n\n") - case <-device.signal.updateBind: + // receive datagrams until conn is closed - // fetch new socket + buffer := device.GetMessageBuffer() - device.net.mutex.RLock() - bind := device.net.bind - device.net.mutex.RUnlock() - if bind == nil { + var size int + var err error + + for { + + // read next datagram + + var endpoint Endpoint + + switch IPVersion { + case ipv4.Version: + size, err = bind.ReceiveIPv4(buffer[:], &endpoint) + case ipv6.Version: + size, err = bind.ReceiveIPv6(buffer[:], &endpoint) + default: + return + } + + if err != nil { + break + } + + if size < MinMessageSize { continue } - logDebug.Println("Listening for inbound packets") + // check size of packet - // receive datagrams until conn is closed + packet := buffer[:size] + msgType := binary.LittleEndian.Uint32(packet[:4]) - buffer := device.GetMessageBuffer() + var okay bool - var size int - var err error + switch msgType { - for { + // check if transport - // read next datagram + case MessageTransportType: - var endpoint Endpoint + // check size - switch IPVersion { - case ipv4.Version: - size, err = bind.ReceiveIPv4(buffer[:], &endpoint) - case ipv6.Version: - size, err = bind.ReceiveIPv6(buffer[:], &endpoint) - default: - return - } - - if err != nil { - break - } - - if size < MinMessageSize { + if len(packet) < MessageTransportType { continue } - // check size of packet + // lookup key pair - packet := buffer[:size] - msgType := binary.LittleEndian.Uint32(packet[:4]) - - var okay bool - - switch msgType { - - // check if transport - - case MessageTransportType: - - // check size - - if len(packet) < MessageTransportType { - continue - } - - // lookup key pair - - receiver := binary.LittleEndian.Uint32( - packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], - ) - value := device.indices.Lookup(receiver) - keyPair := value.keyPair - if keyPair == nil { - continue - } - - // check key-pair expiry - - if keyPair.created.Add(RejectAfterTime).Before(time.Now()) { - continue - } - - // create work element - - peer := value.peer - elem := &QueueInboundElement{ - packet: packet, - buffer: buffer, - keyPair: keyPair, - dropped: AtomicFalse, - } - elem.mutex.Lock() - - // add to decryption queues - - device.addToDecryptionQueue(device.queue.decryption, elem) - device.addToInboundQueue(peer.queue.inbound, elem) - buffer = device.GetMessageBuffer() + receiver := binary.LittleEndian.Uint32( + packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], + ) + value := device.indices.Lookup(receiver) + keyPair := value.keyPair + if keyPair == nil { continue - - // otherwise it is a fixed size & handshake related packet - - case MessageInitiationType: - okay = len(packet) == MessageInitiationSize - - case MessageResponseType: - okay = len(packet) == MessageResponseSize - - case MessageCookieReplyType: - okay = len(packet) == MessageCookieReplySize } - if okay { - device.addToHandshakeQueue( - device.queue.handshake, - QueueHandshakeElement{ - msgType: msgType, - buffer: buffer, - packet: packet, - endpoint: endpoint, - }, - ) - buffer = device.GetMessageBuffer() + // check key-pair expiry + + if keyPair.created.Add(RejectAfterTime).Before(time.Now()) { + continue } + + // create work element + + peer := value.peer + elem := &QueueInboundElement{ + packet: packet, + buffer: buffer, + keyPair: keyPair, + dropped: AtomicFalse, + } + elem.mutex.Lock() + + // add to decryption queues + + device.addToDecryptionQueue(device.queue.decryption, elem) + device.addToInboundQueue(peer.queue.inbound, elem) + buffer = device.GetMessageBuffer() + continue + + // otherwise it is a fixed size & handshake related packet + + case MessageInitiationType: + okay = len(packet) == MessageInitiationSize + + case MessageResponseType: + okay = len(packet) == MessageResponseSize + + case MessageCookieReplyType: + okay = len(packet) == MessageCookieReplySize + } + + if okay { + device.addToHandshakeQueue( + device.queue.handshake, + QueueHandshakeElement{ + msgType: msgType, + buffer: buffer, + packet: packet, + endpoint: endpoint, + }, + ) + buffer = device.GetMessageBuffer() } } }