diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go new file mode 100644 index 0000000..ad8fa05 --- /dev/null +++ b/conn/bindtest/bindtest.go @@ -0,0 +1,136 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + */ + +package bindtest + +import ( + "fmt" + "math/rand" + "net" + "os" + "strconv" + + "golang.zx2c4.com/wireguard/conn" +) + +type ChannelBind struct { + rx4, tx4 *chan []byte + rx6, tx6 *chan []byte + closeSignal chan bool + source4, source6 ChannelEndpoint + target4, target6 ChannelEndpoint +} + +type ChannelEndpoint uint16 + +var _ conn.Bind = (*ChannelBind)(nil) +var _ conn.Endpoint = (*ChannelEndpoint)(nil) + +func NewChannelBinds() [2]conn.Bind { + arx4 := make(chan []byte, 8192) + brx4 := make(chan []byte, 8192) + arx6 := make(chan []byte, 8192) + brx6 := make(chan []byte, 8192) + var binds [2]ChannelBind + binds[0].rx4 = &arx4 + binds[0].tx4 = &brx4 + binds[1].rx4 = &brx4 + binds[1].tx4 = &arx4 + binds[0].rx6 = &arx6 + binds[0].tx6 = &brx6 + binds[1].rx6 = &brx6 + binds[1].tx6 = &arx6 + binds[0].target4 = ChannelEndpoint(1) + binds[1].target4 = ChannelEndpoint(2) + binds[0].target6 = ChannelEndpoint(3) + binds[1].target6 = ChannelEndpoint(4) + binds[0].source4 = binds[1].target4 + binds[0].source6 = binds[1].target6 + binds[1].source4 = binds[0].target4 + binds[1].source6 = binds[0].target6 + return [2]conn.Bind{&binds[0], &binds[1]} +} + +func (c ChannelEndpoint) ClearSrc() {} + +func (c ChannelEndpoint) SrcToString() string { return "" } + +func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) } + +func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} } + +func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) } + +func (c ChannelEndpoint) SrcIP() net.IP { return nil } + +func (c *ChannelBind) Open(port uint16) (actualPort uint16, err error) { + c.closeSignal = make(chan bool) + if rand.Uint32()&1 == 0 { + return uint16(c.source4), nil + } else { + return uint16(c.source6), nil + } +} + +func (c *ChannelBind) Close() error { + if c.closeSignal != nil { + select { + case <-c.closeSignal: + default: + close(c.closeSignal) + } + } + return nil +} + +func (c *ChannelBind) SetMark(mark uint32) error { return nil } + +func (c *ChannelBind) ReceiveIPv6(b []byte) (n int, ep conn.Endpoint, err error) { + select { + case <-c.closeSignal: + return 0, nil, net.ErrClosed + case rx := <-*c.rx6: + return copy(b, rx), c.target6, nil + } +} + +func (c *ChannelBind) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) { + select { + case <-c.closeSignal: + return 0, nil, net.ErrClosed + case rx := <-*c.rx4: + return copy(b, rx), c.target4, nil + } +} + +func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error { + select { + case <-c.closeSignal: + return net.ErrClosed + default: + bc := make([]byte, len(b)) + copy(bc, b) + if ep.(ChannelEndpoint) == c.target4 { + *c.tx4 <- bc + } else if ep.(ChannelEndpoint) == c.target6 { + *c.tx6 <- bc + } else { + return os.ErrInvalid + } + } + return nil +} + +func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) { + _, port, err := net.SplitHostPort(s) + if err != nil { + return nil, err + } + i, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return nil, err + } + return ChannelEndpoint(i), nil +} diff --git a/device/device_test.go b/device/device_test.go index 1716f92..29daeb9 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -8,7 +8,6 @@ package device import ( "bytes" "encoding/hex" - "errors" "fmt" "io" "math/rand" @@ -17,11 +16,11 @@ import ( "runtime/pprof" "sync" "sync/atomic" - "syscall" "testing" "time" "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/conn/bindtest" "golang.zx2c4.com/wireguard/tun/tuntest" ) @@ -148,8 +147,14 @@ func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{} } // genTestPair creates a testPair. -func genTestPair(tb testing.TB) (pair testPair) { +func genTestPair(tb testing.TB, realSocket bool) (pair testPair) { cfg, endpointCfg := genConfigs(tb) + var binds [2]conn.Bind + if realSocket { + binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind() + } else { + binds = bindtest.NewChannelBinds() + } // Bring up a ChannelTun for each config. for i := range pair { p := &pair[i] @@ -159,7 +164,7 @@ func genTestPair(tb testing.TB) (pair testPair) { if _, ok := tb.(*testing.B); ok && !testing.Verbose() { level = LogLevelError } - p.dev = NewDevice(p.tun.TUN(), conn.NewDefaultBind(), NewLogger(level, fmt.Sprintf("dev%d: ", i))) + p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i))) if err := p.dev.IpcSet(cfg[i]); err != nil { tb.Errorf("failed to configure device %d: %v", i, err) p.dev.Close() @@ -187,7 +192,7 @@ func genTestPair(tb testing.TB) (pair testPair) { func TestTwoDevicePing(t *testing.T) { goroutineLeakCheck(t) - pair := genTestPair(t) + pair := genTestPair(t, true) t.Run("ping 1.0.0.1", func(t *testing.T) { pair.Send(t, Ping, nil) }) @@ -198,11 +203,11 @@ func TestTwoDevicePing(t *testing.T) { func TestUpDown(t *testing.T) { goroutineLeakCheck(t) - const itrials = 20 - const otrials = 1 + const itrials = 50 + const otrials = 10 for n := 0; n < otrials; n++ { - pair := genTestPair(t) + pair := genTestPair(t, false) for i := range pair { for k := range pair[i].dev.peers.keyMap { pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:]))) @@ -214,17 +219,8 @@ func TestUpDown(t *testing.T) { go func(d *Device) { defer wg.Done() for i := 0; i < itrials; i++ { - start := time.Now() - for { - if err := d.Up(); err != nil { - if errors.Is(err, syscall.EADDRINUSE) && time.Now().Sub(start) < time.Second*4 { - // Some other test process is racing with us, so try again. - time.Sleep(time.Millisecond * 10) - continue - } - t.Errorf("failed up bring up device: %v", err) - } - break + if err := d.Up(); err != nil { + t.Errorf("failed up bring up device: %v", err) } time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1))))) if err := d.Down(); err != nil { @@ -245,7 +241,7 @@ func TestUpDown(t *testing.T) { // TestConcurrencySafety does other things concurrently with tunnel use. // It is intended to be used with the race detector to catch data races. func TestConcurrencySafety(t *testing.T) { - pair := genTestPair(t) + pair := genTestPair(t, true) done := make(chan struct{}) const warmupIters = 10 @@ -315,7 +311,7 @@ func TestConcurrencySafety(t *testing.T) { } func BenchmarkLatency(b *testing.B) { - pair := genTestPair(b) + pair := genTestPair(b, true) // Establish a connection. pair.Send(b, Ping, nil) @@ -329,7 +325,7 @@ func BenchmarkLatency(b *testing.B) { } func BenchmarkThroughput(b *testing.B) { - pair := genTestPair(b) + pair := genTestPair(b, true) // Establish a connection. pair.Send(b, Ping, nil) @@ -373,7 +369,7 @@ func BenchmarkThroughput(b *testing.B) { } func BenchmarkUAPIGet(b *testing.B) { - pair := genTestPair(b) + pair := genTestPair(b, true) pair.Send(b, Ping, nil) pair.Send(b, Pong, nil) b.ReportAllocs() diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go index 80ccdf9..92aa9d8 100644 --- a/tun/tuntest/tuntest.go +++ b/tun/tuntest/tuntest.go @@ -79,7 +79,6 @@ func genICMPv4(payload []byte, dst, src net.IP) []byte { return pkt } -// TODO(crawshaw): find a reusable home for this. package devicetest? type ChannelTUN struct { Inbound chan []byte // incoming packets, closed on TUN close Outbound chan []byte // outbound packets, blocks forever on TUN close