Properly close DummyTUN to avoid deadlock in TestNoiseHandshake

License: MIT
Signed-off-by: Filippo Valsorda <valsorda@google.com>
This commit is contained in:
Filippo Valsorda 2018-05-20 23:12:55 -04:00 committed by Jason A. Donenfeld
parent 1c666576d5
commit 7bdc5eb54e
3 changed files with 12 additions and 1 deletions

View file

@ -8,6 +8,7 @@ package main
import ( import (
"bytes" "bytes"
"errors"
"os" "os"
"testing" "testing"
) )
@ -40,6 +41,8 @@ func (tun *DummyTUN) Write(d []byte, offset int) (int, error) {
} }
func (tun *DummyTUN) Close() error { func (tun *DummyTUN) Close() error {
close(tun.events)
close(tun.packets)
return nil return nil
} }
@ -48,7 +51,10 @@ func (tun *DummyTUN) Events() chan TUNEvent {
} }
func (tun *DummyTUN) Read(d []byte, offset int) (int, error) { func (tun *DummyTUN) Read(d []byte, offset int) (int, error) {
t := <-tun.packets t, ok := <-tun.packets
if !ok {
return 0, errors.New("device closed")
}
copy(d[offset:], t) copy(d[offset:], t)
return len(t), nil return len(t), nil
} }
@ -57,6 +63,7 @@ func CreateDummyTUN(name string) (TUNDevice, error) {
var dummy DummyTUN var dummy DummyTUN
dummy.mtu = 0 dummy.mtu = 0
dummy.packets = make(chan []byte, 100) dummy.packets = make(chan []byte, 100)
dummy.events = make(chan TUNEvent, 10)
return &dummy, nil return &dummy, nil
} }

View file

@ -58,6 +58,7 @@ func TestNoiseHandshake(t *testing.T) {
packet := make([]byte, 0, 256) packet := make([]byte, 0, 256)
writer := bytes.NewBuffer(packet) writer := bytes.NewBuffer(packet)
err = binary.Write(writer, binary.LittleEndian, msg1) err = binary.Write(writer, binary.LittleEndian, msg1)
assertNil(t, err)
peer := dev2.ConsumeMessageInitiation(msg1) peer := dev2.ConsumeMessageInitiation(msg1)
if peer == nil { if peer == nil {
t.Fatal("handshake failed at initiation message") t.Fatal("handshake failed at initiation message")

3
tun.go
View file

@ -33,9 +33,11 @@ type TUNDevice interface {
func (device *Device) RoutineTUNEventReader() { func (device *Device) RoutineTUNEventReader() {
setUp := false setUp := false
logDebug := device.log.Debug
logInfo := device.log.Info logInfo := device.log.Info
logError := device.log.Error logError := device.log.Error
logDebug.Println("Routine: event worker - started")
device.state.starting.Done() device.state.starting.Done()
for event := range device.tun.device.Events() { for event := range device.tun.device.Events() {
@ -67,5 +69,6 @@ func (device *Device) RoutineTUNEventReader() {
} }
} }
logDebug.Println("Routine: event worker - stopped")
device.state.stopping.Done() device.state.stopping.Done()
} }