diff --git a/src/conn.go b/src/conn.go index 3cf00ab..74bb075 100644 --- a/src/conn.go +++ b/src/conn.go @@ -7,26 +7,28 @@ import ( "net" ) -type UDPBind interface { +/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic + */ +type Bind interface { SetMark(value uint32) error - ReceiveIPv6(buff []byte, end *Endpoint) (int, error) - ReceiveIPv4(buff []byte, end *Endpoint) (int, error) - Send(buff []byte, end *Endpoint) error + ReceiveIPv6(buff []byte) (int, Endpoint, error) + ReceiveIPv4(buff []byte) (int, Endpoint, error) + Send(buff []byte, end Endpoint) error Close() error } /* An Endpoint maintains the source/destination caching for a peer * - * dst : the remote address of a peer + * dst : the remote address of a peer ("endpoint" in uapi terminology) * src : the local address from which datagrams originate going to the peer - * */ -type UDPEndpoint interface { +type Endpoint interface { ClearSrc() // clears the source address ClearDst() // clears the destination address SrcToString() string // returns the local source address (ip:port) DstToString() string // returns the destination address (ip:port) DstToBytes() []byte // used for mac2 cookie calculations + SetDst(string) error // used for manually setting the endpoint (uapi) DstIP() net.IP SrcIP() net.IP } @@ -107,7 +109,9 @@ func UpdateUDPListener(device *Device) error { for _, peer := range device.peers { peer.mutex.Lock() - peer.endpoint.value.ClearSrc() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } peer.mutex.Unlock() } diff --git a/src/conn_linux.go b/src/conn_linux.go index fb576b1..46f873f 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -21,22 +21,24 @@ import ( * See e.g. https://github.com/golang/go/issues/17930 * So this code is remains platform dependent. */ - -type Endpoint struct { +type NativeEndpoint struct { src unix.RawSockaddrInet6 dst unix.RawSockaddrInet6 } -type IPv4Source struct { - src unix.RawSockaddrInet4 - Ifindex int32 -} - type NativeBind struct { sock4 int sock6 int } +var _ Endpoint = (*NativeEndpoint)(nil) +var _ Bind = NativeBind{} + +type IPv4Source struct { + src unix.RawSockaddrInet4 + Ifindex int32 +} + func htons(val uint16) uint16 { var out [unsafe.Sizeof(val)]byte binary.BigEndian.PutUint16(out[:], val) @@ -48,7 +50,11 @@ func ntohs(val uint16) uint16 { return binary.BigEndian.Uint16((*tmp)[:]) } -func CreateUDPBind(port uint16) (UDPBind, uint16, error) { +func NewEndpoint() Endpoint { + return &NativeEndpoint{} +} + +func CreateUDPBind(port uint16) (Bind, uint16, error) { var err error var bind NativeBind @@ -99,28 +105,33 @@ func (bind NativeBind) Close() error { return err2 } -func (bind NativeBind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) { - return receive6( +func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { + var end NativeEndpoint + n, err := receive6( bind.sock6, buff, - end, + &end, ) + return n, &end, err } -func (bind NativeBind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) { - return receive4( +func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { + var end NativeEndpoint + n, err := receive4( bind.sock4, buff, - end, + &end, ) + return n, &end, err } -func (bind NativeBind) Send(buff []byte, end *Endpoint) error { - switch end.dst.Family { +func (bind NativeBind) Send(buff []byte, end Endpoint) error { + nend := end.(*NativeEndpoint) + switch nend.dst.Family { case unix.AF_INET6: - return send6(bind.sock6, end, buff) + return send6(bind.sock6, nend, buff) case unix.AF_INET: - return send4(bind.sock4, end, buff) + return send4(bind.sock4, nend, buff) default: return errors.New("Unknown address family of destination") } @@ -151,12 +162,12 @@ func sockaddrToString(addr unix.RawSockaddrInet6) string { } } -func (end *Endpoint) DstIP() net.IP { - switch end.dst.Family { +func rawAddrToIP(addr unix.RawSockaddrInet6) net.IP { + switch addr.Family { case unix.AF_INET6: - return end.dst.Addr[:] + return addr.Addr[:] case unix.AF_INET: - ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) + ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr)) return net.IPv4( ptr.Addr[0], ptr.Addr[1], @@ -168,25 +179,33 @@ func (end *Endpoint) DstIP() net.IP { } } -func (end *Endpoint) DstToBytes() []byte { +func (end *NativeEndpoint) SrcIP() net.IP { + return rawAddrToIP(end.src) +} + +func (end *NativeEndpoint) DstIP() net.IP { + return rawAddrToIP(end.dst) +} + +func (end *NativeEndpoint) DstToBytes() []byte { ptr := unsafe.Pointer(&end.src) arr := (*[unix.SizeofSockaddrInet6]byte)(ptr) return arr[:] } -func (end *Endpoint) SrcToString() string { +func (end *NativeEndpoint) SrcToString() string { return sockaddrToString(end.src) } -func (end *Endpoint) DstToString() string { +func (end *NativeEndpoint) DstToString() string { return sockaddrToString(end.dst) } -func (end *Endpoint) ClearDst() { +func (end *NativeEndpoint) ClearDst() { end.dst = unix.RawSockaddrInet6{} } -func (end *Endpoint) ClearSrc() { +func (end *NativeEndpoint) ClearSrc() { end.src = unix.RawSockaddrInet6{} } @@ -306,7 +325,7 @@ func create6(port uint16) (int, uint16, error) { return fd, uint16(addr.Port), err } -func (end *Endpoint) SetDst(s string) error { +func (end *NativeEndpoint) SetDst(s string) error { addr, err := parseEndpoint(s) if err != nil { return err @@ -342,7 +361,7 @@ func (end *Endpoint) SetDst(s string) error { return errors.New("Failed to recognize IP address format") } -func send6(sock int, end *Endpoint, buff []byte) error { +func send6(sock int, end *NativeEndpoint, buff []byte) error { // construct message header @@ -404,7 +423,7 @@ func send6(sock int, end *Endpoint, buff []byte) error { return errno } -func send4(sock int, end *Endpoint, buff []byte) error { +func send4(sock int, end *NativeEndpoint, buff []byte) error { // construct message header @@ -470,7 +489,7 @@ func send4(sock int, end *Endpoint, buff []byte) error { return errno } -func receive4(sock int, buff []byte, end *Endpoint) (int, error) { +func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { // contruct message header @@ -518,7 +537,7 @@ func receive4(sock int, buff []byte, end *Endpoint) (int, error) { return int(size), nil } -func receive6(sock int, buff []byte, end *Endpoint) (int, error) { +func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { // contruct message header diff --git a/src/device.go b/src/device.go index 0085cee..76235bd 100644 --- a/src/device.go +++ b/src/device.go @@ -22,9 +22,9 @@ type Device struct { } net struct { mutex sync.RWMutex - bind UDPBind // bind interface - port uint16 // listening port - fwmark uint32 // mark value (0 = disabled) + bind Bind // bind interface + port uint16 // listening port + fwmark uint32 // mark value (0 = disabled) } mutex sync.RWMutex privateKey NoisePrivateKey diff --git a/src/peer.go b/src/peer.go index a98fc97..f3eb6c2 100644 --- a/src/peer.go +++ b/src/peer.go @@ -15,11 +15,8 @@ type Peer struct { keyPairs KeyPairs handshake Handshake device *Device - endpoint struct { - set bool // has a known endpoint been discovered - value Endpoint // source / destination cache - } - stats struct { + endpoint Endpoint + stats struct { txBytes uint64 // bytes send to peer (endpoint) rxBytes uint64 // bytes received from peer lastHandshakeNano int64 // nano seconds since epoch @@ -110,9 +107,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { // reset endpoint - peer.endpoint.set = false - peer.endpoint.value.ClearDst() - peer.endpoint.value.ClearSrc() + peer.endpoint = nil // prepare queuing @@ -143,16 +138,16 @@ func (peer *Peer) SendBuffer(buffer []byte) error { defer peer.device.net.mutex.RUnlock() peer.mutex.RLock() defer peer.mutex.RUnlock() - if !peer.endpoint.set { + if peer.endpoint == nil { return errors.New("No known endpoint for peer") } - return peer.device.net.bind.Send(buffer, &peer.endpoint.value) + return peer.device.net.bind.Send(buffer, peer.endpoint) } /* Returns a short string identification for logging */ func (peer *Peer) String() string { - if !peer.endpoint.set { + if peer.endpoint == nil { return fmt.Sprintf( "peer(%d unknown %s)", peer.id, @@ -162,7 +157,7 @@ func (peer *Peer) String() string { return fmt.Sprintf( "peer(%d %s %s)", peer.id, - peer.endpoint.value.DstToString(), + peer.endpoint.DstToString(), base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), ) } diff --git a/src/receive.go b/src/receive.go index b8b06f7..27fdb8a 100644 --- a/src/receive.go +++ b/src/receive.go @@ -93,7 +93,7 @@ func (device *Device) addToHandshakeQueue( } } -func (device *Device) RoutineReceiveIncomming(IP int, bind UDPBind) { +func (device *Device) RoutineReceiveIncomming(IP int, bind Bind) { logDebug := device.log.Debug logDebug.Println("Routine, receive incomming, IP version:", IP) @@ -104,20 +104,21 @@ func (device *Device) RoutineReceiveIncomming(IP int, bind UDPBind) { buffer := device.GetMessageBuffer() - var size int - var err error + var ( + err error + size int + endpoint Endpoint + ) for { // read next datagram - var endpoint Endpoint - switch IP { case ipv4.Version: - size, err = bind.ReceiveIPv4(buffer[:], &endpoint) + size, endpoint, err = bind.ReceiveIPv4(buffer[:]) case ipv6.Version: - size, err = bind.ReceiveIPv6(buffer[:], &endpoint) + size, endpoint, err = bind.ReceiveIPv6(buffer[:]) default: return } @@ -339,10 +340,7 @@ func (device *Device) RoutineHandshake() { writer := bytes.NewBuffer(temp[:0]) binary.Write(writer, binary.LittleEndian, reply) - device.net.bind.Send( - writer.Bytes(), - &elem.endpoint, - ) + device.net.bind.Send(writer.Bytes(), elem.endpoint) if err != nil { logDebug.Println("Failed to send cookie reply:", err) } @@ -395,8 +393,7 @@ func (device *Device) RoutineHandshake() { // update endpoint peer.mutex.Lock() - peer.endpoint.set = true - peer.endpoint.value = elem.endpoint + peer.endpoint = elem.endpoint peer.mutex.Unlock() // create response @@ -452,8 +449,7 @@ func (device *Device) RoutineHandshake() { // update endpoint peer.mutex.Lock() - peer.endpoint.set = true - peer.endpoint.value = elem.endpoint + peer.endpoint = elem.endpoint peer.mutex.Unlock() logDebug.Println("Received handshake initation from", peer) @@ -527,8 +523,7 @@ func (peer *Peer) RoutineSequentialReceiver() { // update endpoint peer.mutex.Lock() - peer.endpoint.set = true - peer.endpoint.value = elem.endpoint + peer.endpoint = elem.endpoint peer.mutex.Unlock() // check for keep-alive diff --git a/src/uapi.go b/src/uapi.go index e1d0929..670ecc4 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -53,8 +53,8 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { defer peer.mutex.RUnlock() send("public_key=" + peer.handshake.remoteStatic.ToHex()) send("preshared_key=" + peer.handshake.presharedKey.ToHex()) - if peer.endpoint.set { - send("endpoint=" + peer.endpoint.value.DstToString()) + if peer.endpoint != nil { + send("endpoint=" + peer.endpoint.DstToString()) } nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano) @@ -255,17 +255,25 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { case "endpoint": - // set endpoint destination and reset handshake timer + // set endpoint destination + + err := func() error { + peer.mutex.Lock() + defer peer.mutex.Unlock() + + endpoint := NewEndpoint() + if err := endpoint.SetDst(value); err != nil { + return err + } + peer.endpoint = endpoint + signalSend(peer.signal.handshakeReset) + return nil + }() - peer.mutex.Lock() - err := peer.endpoint.value.SetDst(value) - peer.endpoint.set = (err == nil) - peer.mutex.Unlock() if err != nil { logError.Println("Failed to set endpoint:", value) return &IPCError{Code: ipcErrorInvalid} } - signalSend(peer.signal.handshakeReset) case "persistent_keepalive_interval":