device: use channel close to shut down and drain decryption channel

This is similar to commit e1fa1cc556,
but for the decryption channel.

It is an alternative fix to f9f655567930a4cd78d40fa4ba0d58503335ae6a.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
This commit is contained in:
Josh Bleecher Snyder 2021-01-11 17:34:02 -08:00 committed by Jason A. Donenfeld
parent 675955de5d
commit 48c3b87eb8
2 changed files with 54 additions and 64 deletions

View file

@ -76,7 +76,7 @@ type Device struct {
queue struct { queue struct {
encryption *encryptionQueue encryption *encryptionQueue
decryption chan *QueueInboundElement decryption *decryptionQueue
handshake chan QueueHandshakeElement handshake chan QueueHandshakeElement
} }
@ -115,6 +115,24 @@ func newEncryptionQueue() *encryptionQueue {
return q return q
} }
// A decryptionQueue is similar to an encryptionQueue; see those docs.
type decryptionQueue struct {
c chan *QueueInboundElement
wg sync.WaitGroup
}
func newDecryptionQueue() *decryptionQueue {
q := &decryptionQueue{
c: make(chan *QueueInboundElement, QueueInboundSize),
}
q.wg.Add(1)
go func() {
q.wg.Wait()
close(q.c)
}()
return q
}
/* Converts the peer into a "zombie", which remains in the peer map, /* Converts the peer into a "zombie", which remains in the peer map,
* but processes no packets and does not exists in the routing table. * but processes no packets and does not exists in the routing table.
* *
@ -308,7 +326,7 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
device.queue.encryption = newEncryptionQueue() device.queue.encryption = newEncryptionQueue()
device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize) device.queue.decryption = newDecryptionQueue()
// prepare signals // prepare signals
@ -369,13 +387,6 @@ func (device *Device) RemoveAllPeers() {
func (device *Device) FlushPacketQueues() { func (device *Device) FlushPacketQueues() {
for { for {
select { select {
case elem, ok := <-device.queue.decryption:
if ok {
if !elem.IsDropped() {
elem.Drop()
device.PutMessageBuffer(elem.buffer)
}
}
case <-device.queue.handshake: case <-device.queue.handshake:
default: default:
return return
@ -399,10 +410,11 @@ func (device *Device) Close() {
device.isUp.Set(false) device.isUp.Set(false)
// We kept a reference to the encryption queue, // We kept a reference to the encryption and decryption queues,
// in case we started any new peers that might write to it. // in case we started any new peers that might write to them.
// No new peers are coming; we are done with the encryption queue. // No new peers are coming; we are done with these queues.
device.queue.encryption.wg.Done() device.queue.encryption.wg.Done()
device.queue.decryption.wg.Done()
close(device.signals.stop) close(device.signals.stop)
device.state.stopping.Wait() device.state.stopping.Wait()
@ -549,6 +561,7 @@ func (device *Device) BindUpdate() error {
// start receiving routines // start receiving routines
device.net.stopping.Add(2) device.net.stopping.Add(2)
device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)

View file

@ -109,6 +109,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
logDebug := device.log.Debug logDebug := device.log.Debug
defer func() { defer func() {
logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped") logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped")
device.queue.decryption.wg.Done()
device.net.stopping.Done() device.net.stopping.Done()
}() }()
@ -206,7 +207,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
peer.queue.RLock() peer.queue.RLock()
if peer.isRunning.Get() { if peer.isRunning.Get() {
if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) { if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption.c, elem) {
buffer = device.GetMessageBuffer() buffer = device.GetMessageBuffer()
} }
} else { } else {
@ -258,59 +259,35 @@ func (device *Device) RoutineDecryption() {
}() }()
logDebug.Println("Routine: decryption worker - started") logDebug.Println("Routine: decryption worker - started")
for { for elem := range device.queue.decryption.c {
select { // check if dropped
case <-device.signals.stop:
for {
select {
case elem, ok := <-device.queue.decryption:
if ok {
if !elem.IsDropped() {
elem.Drop()
device.PutMessageBuffer(elem.buffer)
}
elem.Unlock()
}
default:
return
}
}
case elem, ok := <-device.queue.decryption: if elem.IsDropped() {
continue
if !ok {
return
}
// check if dropped
if elem.IsDropped() {
continue
}
// split message into fields
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
content := elem.packet[MessageTransportOffsetContent:]
// decrypt and release to consumer
var err error
elem.counter = binary.LittleEndian.Uint64(counter)
// copy counter to nonce
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
elem.packet, err = elem.keypair.receive.Open(
content[:0],
nonce[:],
content,
nil,
)
if err != nil {
elem.Drop()
device.PutMessageBuffer(elem.buffer)
}
elem.Unlock()
} }
// split message into fields
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
content := elem.packet[MessageTransportOffsetContent:]
// decrypt and release to consumer
var err error
elem.counter = binary.LittleEndian.Uint64(counter)
// copy counter to nonce
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
elem.packet, err = elem.keypair.receive.Open(
content[:0],
nonce[:],
content,
nil,
)
if err != nil {
elem.Drop()
device.PutMessageBuffer(elem.buffer)
}
elem.Unlock()
} }
} }