From 1dd590b91b893a413666b6daaed848d89bab7f05 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sat, 13 Jan 2018 09:00:37 +0100 Subject: [PATCH] Work on timer teardown + bug fixes Added waitgroups to peer struct for routine start / stop synchronisation --- src/conn.go | 11 ++++----- src/device.go | 23 ++++++++++++++----- src/peer.go | 62 ++++++++++++++++++++++++++++++++++---------------- src/receive.go | 2 +- src/send.go | 11 ++++++--- src/timers.go | 18 +++++++++------ src/tun.go | 6 ++--- src/uapi.go | 16 ++++++++++++- 8 files changed, 102 insertions(+), 47 deletions(-) diff --git a/src/conn.go b/src/conn.go index ddb7ed1..1d033ff 100644 --- a/src/conn.go +++ b/src/conn.go @@ -64,13 +64,9 @@ func unsafeCloseBind(device *Device) error { return err } -func updateBind(device *Device) error { - device.mutex.Lock() - defer device.mutex.Unlock() - - netc := &device.net - netc.mutex.Lock() - defer netc.mutex.Unlock() +/* Must hold device and net lock + */ +func unsafeUpdateBind(device *Device) error { // close existing sockets @@ -89,6 +85,7 @@ func updateBind(device *Device) error { // bind to new port var err error + netc := &device.net netc.bind, netc.port, err = CreateBind(netc.port) if err != nil { netc.bind = nil diff --git a/src/device.go b/src/device.go index f4a087c..5f8e91b 100644 --- a/src/device.go +++ b/src/device.go @@ -1,6 +1,7 @@ package main import ( + "github.com/sasha-s/go-deadlock" "runtime" "sync" "sync/atomic" @@ -21,12 +22,12 @@ type Device struct { messageBuffers sync.Pool } net struct { - mutex sync.RWMutex + mutex deadlock.RWMutex bind Bind // bind interface port uint16 // listening port fwmark uint32 // mark value (0 = disabled) } - mutex sync.RWMutex + mutex deadlock.RWMutex privateKey NoisePrivateKey publicKey NoisePublicKey routingTable RoutingTable @@ -49,8 +50,15 @@ func (device *Device) Up() { device.mutex.Lock() defer device.mutex.Unlock() - device.isUp.Set(true) - updateBind(device) + device.net.mutex.Lock() + defer device.net.mutex.Unlock() + + if device.isUp.Swap(true) { + return + } + + unsafeUpdateBind(device) + for _, peer := range device.peers { peer.Start() } @@ -60,8 +68,12 @@ func (device *Device) Down() { device.mutex.Lock() defer device.mutex.Unlock() - device.isUp.Set(false) + if !device.isUp.Swap(false) { + return + } + closeBind(device) + for _, peer := range device.peers { peer.Stop() } @@ -75,7 +87,6 @@ func removePeerUnsafe(device *Device, key NoisePublicKey) { if !ok { return } - peer.mutex.Lock() peer.Stop() device.routingTable.RemovePeer(peer) delete(device.peers, key) diff --git a/src/peer.go b/src/peer.go index 7c6ad47..3d82989 100644 --- a/src/peer.go +++ b/src/peer.go @@ -8,6 +8,10 @@ import ( "time" ) +const ( + PeerRoutineNumber = 4 +) + type Peer struct { id uint mutex sync.RWMutex @@ -34,7 +38,6 @@ type Peer struct { flushNonceQueue Signal // size 1, empty queued packets messageSend Signal // size 1, message was send to peer messageReceived Signal // size 1, authenticated message recv - stop Signal // size 0, stop all goroutines in peer } timer struct { // state related to WireGuard timers @@ -54,6 +57,12 @@ type Peer struct { outbound chan *QueueOutboundElement // sequential ordering of work inbound chan *QueueInboundElement // sequential ordering of work } + routines struct { + mutex sync.Mutex // held when stopping / starting routines + starting sync.WaitGroup // routines pending start + stopping sync.WaitGroup // routines pending stop + stop Signal // size 0, stop all goroutines in peer + } mac CookieGenerator } @@ -121,6 +130,10 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { peer.signal.handshakeCompleted = NewSignal() peer.signal.flushNonceQueue = NewSignal() + peer.routines.mutex.Lock() + peer.routines.stop = NewSignal() + peer.routines.mutex.Unlock() + return peer, nil } @@ -156,32 +169,43 @@ func (peer *Peer) String() string { ) } -/* Starts all routines for a given peer - * - * Requires that the caller holds the exclusive peer lock! - */ -func unsafePeerStart(peer *Peer) { - peer.signal.stop.Broadcast() - peer.signal.stop = NewSignal() +func (peer *Peer) Start() { - var wait sync.WaitGroup + peer.routines.mutex.Lock() + defer peer.routines.mutex.Lock() - wait.Add(1) + // stop & wait for ungoing routines (if any) + + peer.routines.stop.Broadcast() + peer.routines.starting.Wait() + peer.routines.stopping.Wait() + + // reset signal and start (new) routines + + peer.routines.stop = NewSignal() + peer.routines.starting.Add(PeerRoutineNumber) + peer.routines.stopping.Add(PeerRoutineNumber) go peer.RoutineNonce() - go peer.RoutineTimerHandler(&wait) + go peer.RoutineTimerHandler() go peer.RoutineSequentialSender() go peer.RoutineSequentialReceiver() - wait.Wait() -} - -func (peer *Peer) Start() { - peer.mutex.Lock() - unsafePeerStart(peer) - peer.mutex.Unlock() + peer.routines.starting.Wait() } func (peer *Peer) Stop() { - peer.signal.stop.Broadcast() + + peer.routines.mutex.Lock() + defer peer.routines.mutex.Lock() + + // stop & wait for ungoing routines (if any) + + peer.routines.stop.Broadcast() + peer.routines.starting.Wait() + peer.routines.stopping.Wait() + + // reset signal (to handle repeated stopping) + + peer.routines.stop = NewSignal() } diff --git a/src/receive.go b/src/receive.go index dbd2813..e6e8481 100644 --- a/src/receive.go +++ b/src/receive.go @@ -497,7 +497,7 @@ func (peer *Peer) RoutineSequentialReceiver() { select { - case <-peer.signal.stop.Wait(): + case <-peer.routines.stop.Wait(): logDebug.Println("Routine, sequential receiver, stopped for peer", peer.id) return diff --git a/src/send.go b/src/send.go index 9537f5e..fa13c91 100644 --- a/src/send.go +++ b/src/send.go @@ -192,7 +192,7 @@ func (peer *Peer) RoutineNonce() { for { NextPacket: select { - case <-peer.signal.stop.Wait(): + case <-peer.routines.stop.Wait(): return case elem := <-peer.queue.nonce: @@ -217,7 +217,7 @@ func (peer *Peer) RoutineNonce() { logDebug.Println("Clearing queue for", peer.String()) peer.FlushNonceQueue() goto NextPacket - case <-peer.signal.stop.Wait(): + case <-peer.routines.stop.Wait(): return } } @@ -309,15 +309,20 @@ func (device *Device) RoutineEncryption() { * The routine terminates then the outbound queue is closed. */ func (peer *Peer) RoutineSequentialSender() { + + defer peer.routines.stopping.Done() + device := peer.device logDebug := device.log.Debug logDebug.Println("Routine, sequential sender, started for", peer.String()) + peer.routines.starting.Done() + for { select { - case <-peer.signal.stop.Wait(): + case <-peer.routines.stop.Wait(): logDebug.Println( "Routine, sequential sender, stopped for", peer.String()) return diff --git a/src/timers.go b/src/timers.go index f2fed30..f1ed9c5 100644 --- a/src/timers.go +++ b/src/timers.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/binary" "math/rand" - "sync" "sync/atomic" "time" ) @@ -182,7 +181,10 @@ func (peer *Peer) sendNewHandshake() error { return err } -func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) { +func (peer *Peer) RoutineTimerHandler() { + + defer peer.routines.stopping.Done() + device := peer.device logInfo := device.log.Info @@ -203,15 +205,20 @@ func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) { peer.timer.keepalivePersistent.Reset(duration) } - // signal that timers are reset + // signal synchronised setup complete - ready.Done() + peer.routines.starting.Done() // handle timer events for { select { + /* stopping */ + + case <-peer.routines.stop.Wait(): + return + /* timers */ // keep-alive @@ -312,9 +319,6 @@ func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) { /* signals */ - case <-peer.signal.stop.Wait(): - return - case <-peer.signal.handshakeBegin.Wait(): peer.signal.handshakeBegin.Disable() diff --git a/src/tun.go b/src/tun.go index 024f0f0..6259f33 100644 --- a/src/tun.go +++ b/src/tun.go @@ -45,14 +45,14 @@ func (device *Device) RoutineTUNEventReader() { } } - if event&TUNEventUp != 0 { + if event&TUNEventUp != 0 && !device.isUp.Get() { logInfo.Println("Interface set up") device.Up() } - if event&TUNEventDown != 0 { + if event&TUNEventDown != 0 && device.isUp.Get() { logInfo.Println("Interface set down") - device.Up() + device.Down() } } } diff --git a/src/uapi.go b/src/uapi.go index a67bff1..f66528c 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -133,13 +133,27 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { device.SetPrivateKey(sk) case "listen_port": + + // parse port number + port, err := strconv.ParseUint(value, 10, 16) if err != nil { logError.Println("Failed to parse listen_port:", err) return &IPCError{Code: ipcErrorInvalid} } + + // update port and rebind + + device.mutex.Lock() + device.net.mutex.Lock() + device.net.port = uint16(port) - if err := updateBind(device); err != nil { + err = unsafeUpdateBind(device) + + device.net.mutex.Unlock() + device.mutex.Unlock() + + if err != nil { logError.Println("Failed to set listen_port:", err) return &IPCError{Code: ipcErrorPortInUse} }