diff --git a/device/channels.go b/device/channels.go index 4471477..8cd6aee 100644 --- a/device/channels.go +++ b/device/channels.go @@ -5,7 +5,10 @@ package device -import "sync" +import ( + "runtime" + "sync" +) // An outboundQueue is a channel of QueueOutboundElements awaiting encryption. // An outboundQueue is ref-counted using its wg field. @@ -67,3 +70,60 @@ func newHandshakeQueue() *handshakeQueue { }() return q } + +// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd. +// It is useful in cases in which is it hard to manage the lifetime of the channel. +// The returned channel must not be closed. Senders should signal shutdown using +// some other means, such as sending a sentinel nil values. +func newAutodrainingInboundQueue(device *Device) chan *QueueInboundElement { + type autodrainingInboundQueue struct { + c chan *QueueInboundElement + } + q := &autodrainingInboundQueue{ + c: make(chan *QueueInboundElement, QueueInboundSize), + } + runtime.SetFinalizer(q, func(q *autodrainingInboundQueue) { + for { + select { + case elem := <-q.c: + if elem == nil { + continue + } + device.PutMessageBuffer(elem.buffer) + device.PutInboundElement(elem) + default: + return + } + } + }) + return q.c +} + +// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd. +// It is useful in cases in which is it hard to manage the lifetime of the channel. +// The returned channel must not be closed. Senders should signal shutdown using +// some other means, such as sending a sentinel nil values. +// All sends to the channel must be best-effort, because there may be no receivers. +func newAutodrainingOutboundQueue(device *Device) chan *QueueOutboundElement { + type autodrainingOutboundQueue struct { + c chan *QueueOutboundElement + } + q := &autodrainingOutboundQueue{ + c: make(chan *QueueOutboundElement, QueueOutboundSize), + } + runtime.SetFinalizer(q, func(q *autodrainingOutboundQueue) { + for { + select { + case elem := <-q.c: + if elem == nil { + continue + } + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + default: + return + } + } + }) + return q.c +} diff --git a/device/peer.go b/device/peer.go index 3e4f4ec..49b9acb 100644 --- a/device/peer.go +++ b/device/peer.go @@ -51,8 +51,11 @@ type Peer struct { sentLastMinuteHandshake AtomicBool } + state struct { + mu sync.Mutex // protects against concurrent Start/Stop + } + queue struct { - sync.RWMutex staged chan *QueueOutboundElement // staged packets before a handshake is available outbound chan *QueueOutboundElement // sequential ordering of udp transmission inbound chan *QueueInboundElement // sequential ordering of tun writing @@ -158,8 +161,8 @@ func (peer *Peer) Start() { } // prevent simultaneous start/stop operations - peer.queue.Lock() - defer peer.queue.Unlock() + peer.state.mu.Lock() + defer peer.state.mu.Unlock() if peer.isRunning.Get() { return @@ -177,8 +180,8 @@ func (peer *Peer) Start() { peer.handshake.mutex.Unlock() // prepare queues - peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) - peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) + peer.queue.outbound = newAutodrainingOutboundQueue(device) + peer.queue.inbound = newAutodrainingInboundQueue(device) if peer.queue.staged == nil { peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize) } @@ -239,8 +242,8 @@ func (peer *Peer) ExpireCurrentKeypairs() { } func (peer *Peer) Stop() { - peer.queue.Lock() - defer peer.queue.Unlock() + peer.state.mu.Lock() + defer peer.state.mu.Unlock() if !peer.isRunning.Swap(false) { return @@ -249,9 +252,9 @@ func (peer *Peer) Stop() { peer.device.log.Verbosef("%v - Stopping...", peer) peer.timersStop() - - close(peer.queue.inbound) - close(peer.queue.outbound) + // Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit. + peer.queue.inbound <- nil + peer.queue.outbound <- nil peer.stopping.Wait() peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us diff --git a/device/receive.go b/device/receive.go index c6a28f7..7acb7d9 100644 --- a/device/receive.go +++ b/device/receive.go @@ -166,7 +166,6 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) { elem.Lock() // add to decryption queues - peer.queue.RLock() if peer.isRunning.Get() { peer.queue.inbound <- elem device.queue.decryption.c <- elem @@ -174,8 +173,6 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) { } else { device.PutInboundElement(elem) } - peer.queue.RUnlock() - continue // otherwise it is a fixed size & handshake related packet @@ -406,6 +403,9 @@ func (peer *Peer) RoutineSequentialReceiver() { device.log.Verbosef("%v - Routine: sequential receiver - started", peer) for elem := range peer.queue.inbound { + if elem == nil { + return + } var err error elem.Lock() if elem.packet == nil { diff --git a/device/send.go b/device/send.go index 982fec0..911ee5c 100644 --- a/device/send.go +++ b/device/send.go @@ -316,7 +316,6 @@ top: elem.Lock() // add to parallel and sequential queue - peer.queue.RLock() if peer.isRunning.Get() { peer.queue.outbound <- elem peer.device.queue.encryption.c <- elem @@ -324,7 +323,6 @@ top: peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } - peer.queue.RUnlock() default: return } @@ -413,6 +411,9 @@ func (peer *Peer) RoutineSequentialSender() { device.log.Verbosef("%v - Routine: sequential sender - started", peer) for elem := range peer.queue.outbound { + if elem == nil { + return + } elem.Lock() if !peer.isRunning.Get() { // peer has been stopped; return re-usable elems to the shared pool.