diff --git a/src/conn.go b/src/conn.go index db4020d..012e24e 100644 --- a/src/conn.go +++ b/src/conn.go @@ -34,15 +34,20 @@ func parseEndpoint(s string) (*net.UDPAddr, error) { return addr, err } -func ListeningUpdate(device *Device) error { +func UpdateUDPListener(device *Device) error { + device.mutex.Lock() + defer device.mutex.Unlock() + netc := &device.net netc.mutex.Lock() defer netc.mutex.Unlock() // close existing sockets - if err := device.net.bind.Close(); err != nil { - return err + if netc.bind != nil { + if err := netc.bind.Close(); err != nil { + return err + } } // open new sockets @@ -64,13 +69,19 @@ func ListeningUpdate(device *Device) error { return err } - // TODO: clear endpoint (src) caches + // clear cached source addresses + + for _, peer := range device.peers { + peer.mutex.Lock() + peer.endpoint.value.ClearSrc() + peer.mutex.Unlock() + } } return nil } -func ListeningClose(device *Device) error { +func CloseUDPListener(device *Device) error { netc := &device.net netc.mutex.Lock() defer netc.mutex.Unlock() diff --git a/src/conn_linux.go b/src/conn_linux.go index 8942b03..4a5a3f0 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -133,7 +133,7 @@ func sockaddrToString(addr unix.RawSockaddrInet6) string { } } -func (end *Endpoint) DestinationIP() net.IP { +func (end *Endpoint) DstIP() net.IP { switch end.dst.Family { case unix.AF_INET6: return end.dst.Addr[:] @@ -150,20 +150,24 @@ func (end *Endpoint) DestinationIP() net.IP { } } -func (end *Endpoint) SourceToBytes() []byte { +func (end *Endpoint) SrcToBytes() []byte { ptr := unsafe.Pointer(&end.src) arr := (*[unix.SizeofSockaddrInet6]byte)(ptr) return arr[:] } -func (end *Endpoint) SourceToString() string { +func (end *Endpoint) SrcToString() string { return sockaddrToString(end.src) } -func (end *Endpoint) DestinationToString() string { +func (end *Endpoint) DstToString() string { return sockaddrToString(end.dst) } +func (end *Endpoint) ClearDst() { + end.dst = unix.RawSockaddrInet6{} +} + func (end *Endpoint) ClearSrc() { end.src = unix.RawSockaddrInet6{} } diff --git a/src/device.go b/src/device.go index d1e0685..1aae448 100644 --- a/src/device.go +++ b/src/device.go @@ -205,7 +205,7 @@ func (device *Device) RemoveAllPeers() { func (device *Device) Close() { device.RemoveAllPeers() close(device.signal.stop) - ListeningClose(device) + CloseUDPListener(device) } func (device *Device) WaitChannel() chan struct{} { diff --git a/src/main.go b/src/main.go index a05dbba..5aaed9b 100644 --- a/src/main.go +++ b/src/main.go @@ -14,8 +14,6 @@ func printUsage() { } func main() { - test() - // parse arguments var foreground bool diff --git a/src/peer.go b/src/peer.go index 791c091..f24dcd8 100644 --- a/src/peer.go +++ b/src/peer.go @@ -14,9 +14,12 @@ type Peer struct { persistentKeepaliveInterval uint64 keyPairs KeyPairs handshake Handshake - endpoint Endpoint device *Device - stats struct { + endpoint struct { + set bool // has a known endpoint been discovered + value Endpoint // source / destination cache + } + stats struct { txBytes uint64 // bytes send to peer (endpoint) rxBytes uint64 // bytes received from peer lastHandshakeNano int64 // nano seconds since epoch @@ -105,6 +108,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic) handshake.mutex.Unlock() + // reset endpoint + + peer.endpoint.set = false + peer.endpoint.value.ClearDst() + peer.endpoint.value.ClearSrc() + // prepare queuing peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize) @@ -129,11 +138,20 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { return peer, nil } +/* Returns a short string identification for logging + */ func (peer *Peer) String() string { + if !peer.endpoint.set { + return fmt.Sprintf( + "peer(%d unknown %s)", + peer.id, + base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), + ) + } return fmt.Sprintf( "peer(%d %s %s)", peer.id, - peer.endpoint.DestinationToString(), + peer.endpoint.value.DstToString(), base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), ) } diff --git a/src/receive.go b/src/receive.go index 664f1ba..1f05b2f 100644 --- a/src/receive.go +++ b/src/receive.go @@ -331,7 +331,7 @@ func (device *Device) RoutineHandshake() { return } - srcBytes := elem.endpoint.SourceToBytes() + srcBytes := elem.endpoint.SrcToBytes() if device.IsUnderLoad() { // verify MAC2 field @@ -340,8 +340,7 @@ func (device *Device) RoutineHandshake() { // construct cookie reply - logDebug.Println("Sending cookie reply to:", elem.endpoint.SourceToString()) - + logDebug.Println("Sending cookie reply to:", elem.endpoint.SrcToString()) sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type" reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes) if err != nil { @@ -365,9 +364,7 @@ func (device *Device) RoutineHandshake() { // check ratelimiter - if !device.ratelimiter.Allow( - elem.endpoint.DestinationIP(), - ) { + if !device.ratelimiter.Allow(elem.endpoint.DstIP()) { continue } } @@ -398,7 +395,7 @@ func (device *Device) RoutineHandshake() { if peer == nil { logInfo.Println( "Recieved invalid initiation message from", - elem.endpoint.DestinationToString(), + elem.endpoint.DstToString(), ) continue } @@ -412,7 +409,8 @@ func (device *Device) RoutineHandshake() { // TODO: Discover destination address also, only update on change peer.mutex.Lock() - peer.endpoint = elem.endpoint + peer.endpoint.set = true + peer.endpoint.value = elem.endpoint peer.mutex.Unlock() // create response @@ -435,7 +433,7 @@ func (device *Device) RoutineHandshake() { // send response - _, err = peer.SendBuffer(packet) + err = peer.SendBuffer(packet) if err == nil { peer.TimerAnyAuthenticatedPacketTraversal() } @@ -458,7 +456,7 @@ func (device *Device) RoutineHandshake() { if peer == nil { logInfo.Println( "Recieved invalid response message from", - elem.endpoint.DestinationToString(), + elem.endpoint.DstToString(), ) continue } diff --git a/src/send.go b/src/send.go index 5c88ead..e37a736 100644 --- a/src/send.go +++ b/src/send.go @@ -105,24 +105,15 @@ func addToEncryptionQueue( } } -func (peer *Peer) SendBuffer(buffer []byte) (int, error) { +func (peer *Peer) SendBuffer(buffer []byte) error { peer.device.net.mutex.RLock() defer peer.device.net.mutex.RUnlock() - peer.mutex.RLock() defer peer.mutex.RUnlock() - - endpoint := peer.endpoint - if endpoint == nil { - return 0, errors.New("No known endpoint for peer") + if !peer.endpoint.set { + return errors.New("No known endpoint for peer") } - - conn := peer.device.net.conn - if conn == nil { - return 0, errors.New("No UDP socket for device") - } - - return conn.WriteToUDP(buffer, endpoint) + return peer.device.net.bind.Send(buffer, &peer.endpoint.value) } /* Reads packets from the TUN and inserts @@ -343,7 +334,7 @@ func (peer *Peer) RoutineSequentialSender() { // send message and return buffer to pool length := uint64(len(elem.packet)) - _, err := peer.SendBuffer(elem.packet) + err := peer.SendBuffer(elem.packet) device.PutMessageBuffer(elem.buffer) if err != nil { logDebug.Println("Failed to send authenticated packet to peer", peer.String()) diff --git a/src/timers.go b/src/timers.go index 99695ba..2a94005 100644 --- a/src/timers.go +++ b/src/timers.go @@ -288,7 +288,7 @@ func (peer *Peer) RoutineHandshakeInitiator() { packet := writer.Bytes() peer.mac.AddMacs(packet) - _, err = peer.SendBuffer(packet) + err = peer.SendBuffer(packet) if err != nil { logError.Println( "Failed to send handshake initiation message to", diff --git a/src/tun.go b/src/tun.go index 8e8c759..9eed987 100644 --- a/src/tun.go +++ b/src/tun.go @@ -47,7 +47,7 @@ func (device *Device) RoutineTUNEventReader() { if !device.tun.isUp.Get() { logInfo.Println("Interface set up") device.tun.isUp.Set(true) - updateUDPConn(device) + UpdateUDPListener(device) } } @@ -55,7 +55,7 @@ func (device *Device) RoutineTUNEventReader() { if device.tun.isUp.Get() { logInfo.Println("Interface set down") device.tun.isUp.Set(false) - closeUDPConn(device) + CloseUDPListener(device) } } } diff --git a/src/uapi.go b/src/uapi.go index 7d08e56..2de26ee 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -39,9 +39,10 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { send("private_key=" + device.privateKey.ToHex()) } - if device.net.addr != nil { - send(fmt.Sprintf("listen_port=%d", device.net.addr.Port)) + if device.net.port != 0 { + send(fmt.Sprintf("listen_port=%d", device.net.port)) } + if device.net.fwmark != 0 { send(fmt.Sprintf("fwmark=%d", device.net.fwmark)) } @@ -52,8 +53,8 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { defer peer.mutex.RUnlock() send("public_key=" + peer.handshake.remoteStatic.ToHex()) send("preshared_key=" + peer.handshake.presharedKey.ToHex()) - if peer.endpoint != nil { - send("endpoint=" + peer.endpoint.String()) + if peer.endpoint.set { + send("endpoint=" + peer.endpoint.value.DstToString()) } nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano) @@ -137,53 +138,24 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { logError.Println("Failed to set listen_port:", err) return &IPCError{Code: ipcErrorInvalid} } - - addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port)) - if err != nil { - logError.Println("Failed to set listen_port:", err) - return &IPCError{Code: ipcErrorInvalid} - } - - device.net.mutex.Lock() - device.net.addr = addr - device.net.mutex.Unlock() - - err = updateUDPConn(device) - if err != nil { + device.net.port = uint16(port) + if err := UpdateUDPListener(device); err != nil { logError.Println("Failed to set listen_port:", err) return &IPCError{Code: ipcErrorPortInUse} } - // TODO: Clear source address of all peers - case "fwmark": fwmark, err := strconv.ParseUint(value, 10, 32) if err != nil { logError.Println("Invalid fwmark", err) return &IPCError{Code: ipcErrorInvalid} } - device.net.mutex.Lock() - if fwmark > 0 || device.net.fwmark > 0 { - device.net.fwmark = uint32(fwmark) - err := SetMark( - device.net.conn, - device.net.fwmark, - ) - if err != nil { - logError.Println("Failed to set fwmark:", err) - device.net.mutex.Unlock() - return &IPCError{Code: ipcErrorIO} - } - - // TODO: Clear source address of all peers - } + device.net.fwmark = uint32(fwmark) device.net.mutex.Unlock() case "public_key": - // switch to peer configuration - deviceConfig = false case "replace_peers": @@ -218,7 +190,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { device.mutex.RLock() if device.publicKey.Equals(pubKey) { - // create dummy instance + // create dummy instance (not added to device) peer = &Peer{} dummy = true @@ -244,6 +216,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { } case "remove": + + // remove currently selected peer from device + if value != "true" { logError.Println("Failed to set remove, invalid value:", value) return &IPCError{Code: ipcErrorInvalid} @@ -256,6 +231,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { dummy = true case "preshared_key": + + // update PSK + peer.mutex.Lock() err := peer.handshake.presharedKey.FromHex(value) peer.mutex.Unlock() @@ -265,14 +243,17 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { } case "endpoint": - addr, err := parseEndpoint(value) + + // set endpoint destination and reset handshake timer + + peer.mutex.Lock() + err := peer.endpoint.value.Set(value) + peer.endpoint.set = (err == nil) + peer.mutex.Unlock() if err != nil { logError.Println("Failed to set endpoint:", value) return &IPCError{Code: ipcErrorInvalid} } - peer.mutex.Lock() - peer.endpoint = addr - peer.mutex.Unlock() signalSend(peer.signal.handshakeReset) case "persistent_keepalive_interval":