// Review summary of this patch. These lines sit ahead of the first
// "diff --git" header and belong to no hunk.
//
// Purpose: replace the device-wide encryption channel with a ref-counted
// queue whose channel is closed once every holder has released its reference.
// The encryption workers then exit when that channel is closed and drained,
// instead of selecting on device.signals.stop.
//
// device/device.go
//   * Device.queue.encryption changes from chan *QueueOutboundElement to
//     *encryptionQueue (channel field c plus sync.WaitGroup field wg, used as
//     the reference count).
//   * newEncryptionQueue makes c with capacity QueueOutboundSize, takes the
//     initial reference with wg.Add(1), and starts a goroutine that calls
//     close(q.c) after wg.Wait() returns.
//   * NewDevice: the per-CPU device.state.stopping.Add drops from 3 to 2; the
//     matching stopping.Done() in RoutineEncryption is removed in send.go.
//   * FlushPacketQueues: the case that drained device.queue.encryption is
//     removed.
//   * Close: calls device.queue.encryption.wg.Done() to release the initial
//     reference, immediately before close(device.signals.stop).
//
// device/device_test.go
//   * genChannelTUNs is replaced by genTestPair, which returns a testPair
//     ([2]testPeer; each testPeer holds tun, dev and ip).
//   * testPair.Send builds a tuntest.Ping, writes it to p1.tun.Outbound and
//     waits on p0.tun.Inbound, a 5 second timer, or done. A mismatch or
//     timeout is reported with t.Error unless done is already closed.
//   * TestTwoDevicePing is rewritten on top of Send (with done == nil).
//   * TestConcurrencySafety is added: a goroutine alternates Ping/Pong sends
//     until done is closed; the test waits for warmupIters (10) completed
//     sends and then closes done.
//
// device/send.go
//   * Peer.RoutineNonce calls wg.Add(1) on entry and wg.Done() in its
//     deferred cleanup.
//   * Device.RoutineEncryption becomes
//     `for elem := range device.queue.encryption.c`; the deferred drain loop
//     and the device.signals.stop case are removed.
//   * Peer.RoutineSequentialSender: elem.Lock() is added before the
//     IsDropped check in the receive branch shown in the last hunk.
//
// NOTE(review): RoutineNonce takes its reference from inside the routine
// body. If the routine is launched with `go` (call site not in this patch;
// confirm), Close's wg.Done() can bring the counter to zero first. The Add
// then races with the Wait in newEncryptionQueue, and a later send through
// addToOutboundAndEncryptionQueues could hit a closed channel. Taking the
// reference in the caller, before the routine starts, would avoid this.
//
// NOTE(review): Close now calls wg.Done() unconditionally in the visible
// hunk. A second Close would drive the WaitGroup counter negative and panic
// unless Close already returns early on repeat calls; that guard is not
// visible here, so confirm.
//
// NOTE(review): TestConcurrencySafety does not wait for its sender goroutine
// after close(done). Send re-checks done before t.Error, but the test can
// still return between that check and the t.Error call, and the goroutine
// may stay blocked on p1.tun.Outbound. Consider joining the goroutine before
// the test returns.
diff --git a/device/device.go b/device/device.go index 9e2d001..d9367e5 100644 --- a/device/device.go +++ b/device/device.go @@ -74,7 +74,7 @@ type Device struct { } queue struct { - encryption chan *QueueOutboundElement + encryption *encryptionQueue decryption chan *QueueInboundElement handshake chan QueueHandshakeElement } @@ -89,6 +89,31 @@ type Device struct { } } +// An encryptionQueue is a channel of QueueOutboundElements awaiting encryption. +// An encryptionQueue is ref-counted using its wg field. +// An encryptionQueue created with newEncryptionQueue has one reference. +// Every additional writer must call wg.Add(1). +// Every completed writer must call wg.Done(). +// When no further writers will be added, +// call wg.Done to remove the initial reference. +// When the refcount hits 0, the queue's channel is closed. +type encryptionQueue struct { + c chan *QueueOutboundElement + wg sync.WaitGroup +} + +func newEncryptionQueue() *encryptionQueue { + q := &encryptionQueue{ + c: make(chan *QueueOutboundElement, QueueOutboundSize), + } + q.wg.Add(1) + go func() { + q.wg.Wait() + close(q.c) + }() + return q +} + /* Converts the peer into a "zombie", which remains in the peer map, * but processes no packets and does not exists in the routing table. 
* @@ -280,7 +305,7 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device { // create queues device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) - device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize) + device.queue.encryption = newEncryptionQueue() device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize) // prepare signals @@ -297,7 +322,7 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device { cpus := runtime.NumCPU() device.state.stopping.Wait() for i := 0; i < cpus; i += 1 { - device.state.stopping.Add(3) + device.state.stopping.Add(2) // decryption and handshake go device.RoutineEncryption() go device.RoutineDecryption() go device.RoutineHandshake() @@ -346,10 +371,6 @@ func (device *Device) FlushPacketQueues() { if ok { elem.Drop() } - case elem, ok := <-device.queue.encryption: - if ok { - elem.Drop() - } case <-device.queue.handshake: default: return @@ -373,6 +394,10 @@ func (device *Device) Close() { device.isUp.Set(false) + // We kept a reference to the encryption queue, + // in case we started any new peers that might write to it. + // No new peers are coming; we are done with the encryption queue. + device.queue.encryption.wg.Done() close(device.signals.stop) device.state.stopping.Wait() diff --git a/device/device_test.go b/device/device_test.go index a89dcc2..65942ec 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -7,9 +7,11 @@ package device import ( "bytes" + "errors" "fmt" "io" "net" + "sync" "testing" "time" @@ -79,18 +81,74 @@ func genConfigs(t *testing.T) (cfgs [2]io.Reader) { return } -// genChannelTUNs creates a usable pair of ChannelTUNs for use in a test. -func genChannelTUNs(t *testing.T) (tun [2]*tuntest.ChannelTUN) { +// A testPair is a pair of testPeers. +type testPair [2]testPeer + +// A testPeer is a peer used for testing. 
+type testPeer struct { + tun *tuntest.ChannelTUN + dev *Device + ip net.IP +} + +type SendDirection bool + +const ( + Ping SendDirection = true + Pong SendDirection = false +) + +func (pair *testPair) Send(t *testing.T, ping SendDirection, done chan struct{}) { + t.Helper() + p0, p1 := pair[0], pair[1] + if !ping { + // pong is the new ping + p0, p1 = p1, p0 + } + msg := tuntest.Ping(p0.ip, p1.ip) + p1.tun.Outbound <- msg + timer := time.NewTimer(5 * time.Second) + defer timer.Stop() + var err error + select { + case msgRecv := <-p0.tun.Inbound: + if !bytes.Equal(msg, msgRecv) { + err = errors.New("ping did not transit correctly") + } + case <-timer.C: + err = errors.New("ping did not transit") + case <-done: + } + if err != nil { + // The error may have occurred because the test is done. + select { + case <-done: + return + default: + } + // Real error. + t.Error(err) + } +} + +// genTestPair creates a testPair. +func genTestPair(t *testing.T) (pair testPair) { const maxAttempts = 10 NextAttempt: for i := 0; i < maxAttempts; i++ { cfg := genConfigs(t) // Bring up a ChannelTun for each config. - for i := range tun { - tun[i] = tuntest.NewChannelTUN() - dev := NewDevice(tun[i].TUN(), NewLogger(LogLevelDebug, fmt.Sprintf("dev%d: ", i))) - dev.Up() - if err := dev.IpcSetOperation(cfg[i]); err != nil { + for i := range pair { + p := &pair[i] + p.tun = tuntest.NewChannelTUN() + if i == 0 { + p.ip = net.ParseIP("1.0.0.1") + } else { + p.ip = net.ParseIP("1.0.0.2") + } + p.dev = NewDevice(p.tun.TUN(), NewLogger(LogLevelDebug, fmt.Sprintf("dev%d: ", i))) + p.dev.Up() + if err := p.dev.IpcSetOperation(cfg[i]); err != nil { // genConfigs attempted to pick ports that were free. // There's a tiny window between genConfigs closing the port // and us opening it, during which another process could @@ -104,12 +162,12 @@ NextAttempt: // The device might still not be up, e.g. due to an error // in RoutineTUNEventReader's call to dev.Up that got swallowed. 
// Assume it's due to a transient error (port in use), and retry. - if !dev.isUp.Get() { - t.Logf("%v did not come up, trying again", dev) + if !p.dev.isUp.Get() { + t.Logf("device %d did not come up, trying again", i) continue NextAttempt } // The device is up. Close it when the test completes. - t.Cleanup(dev.Close) + t.Cleanup(p.dev.Close) } return // success } @@ -119,35 +177,49 @@ NextAttempt: } func TestTwoDevicePing(t *testing.T) { - tun := genChannelTUNs(t) - + pair := genTestPair(t) t.Run("ping 1.0.0.1", func(t *testing.T) { - msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2")) - tun[1].Outbound <- msg2to1 - select { - case msgRecv := <-tun[0].Inbound: - if !bytes.Equal(msg2to1, msgRecv) { - t.Error("ping did not transit correctly") - } - case <-time.After(5 * time.Second): - t.Error("ping did not transit") - } + pair.Send(t, Ping, nil) }) - t.Run("ping 1.0.0.2", func(t *testing.T) { - msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1")) - tun[0].Outbound <- msg1to2 - select { - case msgRecv := <-tun[1].Inbound: - if !bytes.Equal(msg1to2, msgRecv) { - t.Error("return ping did not transit correctly") - } - case <-time.After(5 * time.Second): - t.Error("return ping did not transit") - } + pair.Send(t, Pong, nil) }) } +// TestConcurrencySafety does other things concurrently with tunnel use. +// It is intended to be used with the race detector to catch data races. +func TestConcurrencySafety(t *testing.T) { + pair := genTestPair(t) + done := make(chan struct{}) + + const warmupIters = 10 + var warmup sync.WaitGroup + warmup.Add(warmupIters) + go func() { + // Send data continuously back and forth until we're done. + // Note that we may continue to attempt to send data + // even after done is closed. 
+ i := warmupIters + for ping := Ping; ; ping = !ping { + pair.Send(t, ping, done) + select { + case <-done: + return + default: + } + if i > 0 { + warmup.Done() + i-- + } + } + }() + warmup.Wait() + + // coming soon: more things here... + + close(done) +} + func assertNil(t *testing.T, err error) { if err != nil { t.Fatal(err) diff --git a/device/send.go b/device/send.go index 0801b71..1b16edd 100644 --- a/device/send.go +++ b/device/send.go @@ -352,6 +352,9 @@ func (peer *Peer) RoutineNonce() { device := peer.device logDebug := device.log.Debug + // We write to the encryption queue; keep it alive until we are done. + device.queue.encryption.wg.Add(1) + flush := func() { for { select { @@ -368,6 +371,7 @@ func (peer *Peer) RoutineNonce() { flush() logDebug.Println(peer, "- Routine: nonce worker - stopped") peer.queue.packetInNonceQueueIsAwaitingKey.Set(false) + device.queue.encryption.wg.Done() // no more writes from us peer.routines.stopping.Done() }() @@ -455,7 +459,7 @@ NextPacket: elem.Lock() // add to parallel and sequential queue - addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption, elem) + addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption.c, elem) } } } @@ -486,76 +490,46 @@ func (device *Device) RoutineEncryption() { logDebug := device.log.Debug - defer func() { - for { - select { - case elem, ok := <-device.queue.encryption: - if ok && !elem.IsDropped() { - elem.Drop() - device.PutMessageBuffer(elem.buffer) - elem.Unlock() - } - default: - goto out - } - } - out: - logDebug.Println("Routine: encryption worker - stopped") - device.state.stopping.Done() - }() - + defer logDebug.Println("Routine: encryption worker - stopped") logDebug.Println("Routine: encryption worker - started") - for { + for elem := range device.queue.encryption.c { - // fetch next element + // check if dropped - select { - case <-device.signals.stop: - return - - case elem, ok := <-device.queue.encryption: - - if !ok { - return - } 
- - // check if dropped - - if elem.IsDropped() { - continue - } - - // populate header fields - - header := elem.buffer[:MessageTransportHeaderSize] - - fieldType := header[0:4] - fieldReceiver := header[4:8] - fieldNonce := header[8:16] - - binary.LittleEndian.PutUint32(fieldType, MessageTransportType) - binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) - binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) - - // pad content to multiple of 16 - - paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu))) - for i := 0; i < paddingSize; i++ { - elem.packet = append(elem.packet, 0) - } - - // encrypt content and release to consumer - - binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) - elem.packet = elem.keypair.send.Seal( - header, - nonce[:], - elem.packet, - nil, - ) - elem.Unlock() + if elem.IsDropped() { + continue } + + // populate header fields + + header := elem.buffer[:MessageTransportHeaderSize] + + fieldType := header[0:4] + fieldReceiver := header[4:8] + fieldNonce := header[8:16] + + binary.LittleEndian.PutUint32(fieldType, MessageTransportType) + binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) + binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) + + // pad content to multiple of 16 + + paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu))) + for i := 0; i < paddingSize; i++ { + elem.packet = append(elem.packet, 0) + } + + // encrypt content and release to consumer + + binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) + elem.packet = elem.keypair.send.Seal( + header, + nonce[:], + elem.packet, + nil, + ) + elem.Unlock() } } @@ -576,6 +550,7 @@ func (peer *Peer) RoutineSequentialSender() { select { case elem, ok := <-peer.queue.outbound: if ok { + elem.Lock() if !elem.IsDropped() { device.PutMessageBuffer(elem.buffer) elem.Drop()