tuntest: split out testing package
This code is useful to other packages writing tests. Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
This commit is contained in:
parent
85a45a9651
commit
1a1c3d0968
|
@ -8,15 +8,12 @@ package device
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun/tuntest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTwoDevicePing(t *testing.T) {
|
func TestTwoDevicePing(t *testing.T) {
|
||||||
|
@ -29,7 +26,7 @@ protocol_version=1
|
||||||
replace_allowed_ips=true
|
replace_allowed_ips=true
|
||||||
allowed_ip=1.0.0.2/32
|
allowed_ip=1.0.0.2/32
|
||||||
endpoint=127.0.0.1:53512`
|
endpoint=127.0.0.1:53512`
|
||||||
tun1 := NewChannelTUN()
|
tun1 := tuntest.NewChannelTUN()
|
||||||
dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: "))
|
dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: "))
|
||||||
dev1.Up()
|
dev1.Up()
|
||||||
defer dev1.Close()
|
defer dev1.Close()
|
||||||
|
@ -45,7 +42,7 @@ protocol_version=1
|
||||||
replace_allowed_ips=true
|
replace_allowed_ips=true
|
||||||
allowed_ip=1.0.0.1/32
|
allowed_ip=1.0.0.1/32
|
||||||
endpoint=127.0.0.1:53511`
|
endpoint=127.0.0.1:53511`
|
||||||
tun2 := NewChannelTUN()
|
tun2 := tuntest.NewChannelTUN()
|
||||||
dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: "))
|
dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: "))
|
||||||
dev2.Up()
|
dev2.Up()
|
||||||
defer dev2.Close()
|
defer dev2.Close()
|
||||||
|
@ -54,7 +51,7 @@ endpoint=127.0.0.1:53511`
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||||
msg2to1 := ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
|
msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
|
||||||
tun2.Outbound <- msg2to1
|
tun2.Outbound <- msg2to1
|
||||||
select {
|
select {
|
||||||
case msgRecv := <-tun1.Inbound:
|
case msgRecv := <-tun1.Inbound:
|
||||||
|
@ -67,7 +64,7 @@ endpoint=127.0.0.1:53511`
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("ping 1.0.0.2", func(t *testing.T) {
|
t.Run("ping 1.0.0.2", func(t *testing.T) {
|
||||||
msg1to2 := ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
|
msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
|
||||||
tun1.Outbound <- msg1to2
|
tun1.Outbound <- msg1to2
|
||||||
select {
|
select {
|
||||||
case msgRecv := <-tun2.Inbound:
|
case msgRecv := <-tun2.Inbound:
|
||||||
|
@ -80,139 +77,6 @@ endpoint=127.0.0.1:53511`
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func ping(dst, src net.IP) []byte {
|
|
||||||
localPort := uint16(1337)
|
|
||||||
seq := uint16(0)
|
|
||||||
|
|
||||||
payload := make([]byte, 4)
|
|
||||||
binary.BigEndian.PutUint16(payload[0:], localPort)
|
|
||||||
binary.BigEndian.PutUint16(payload[2:], seq)
|
|
||||||
|
|
||||||
return genICMPv4(payload, dst, src)
|
|
||||||
}
|
|
||||||
|
|
||||||
// checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
|
|
||||||
func checksum(buf []byte, initial uint16) uint16 {
|
|
||||||
v := uint32(initial)
|
|
||||||
for i := 0; i < len(buf)-1; i += 2 {
|
|
||||||
v += uint32(binary.BigEndian.Uint16(buf[i:]))
|
|
||||||
}
|
|
||||||
if len(buf)%2 == 1 {
|
|
||||||
v += uint32(buf[len(buf)-1]) << 8
|
|
||||||
}
|
|
||||||
for v > 0xffff {
|
|
||||||
v = (v >> 16) + (v & 0xffff)
|
|
||||||
}
|
|
||||||
return ^uint16(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
func genICMPv4(payload []byte, dst, src net.IP) []byte {
|
|
||||||
const (
|
|
||||||
icmpv4ProtocolNumber = 1
|
|
||||||
icmpv4Echo = 8
|
|
||||||
icmpv4ChecksumOffset = 2
|
|
||||||
icmpv4Size = 8
|
|
||||||
ipv4Size = 20
|
|
||||||
ipv4TotalLenOffset = 2
|
|
||||||
ipv4ChecksumOffset = 10
|
|
||||||
ttl = 65
|
|
||||||
)
|
|
||||||
|
|
||||||
hdr := make([]byte, ipv4Size+icmpv4Size)
|
|
||||||
|
|
||||||
ip := hdr[0:ipv4Size]
|
|
||||||
icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]
|
|
||||||
|
|
||||||
// https://tools.ietf.org/html/rfc792
|
|
||||||
icmpv4[0] = icmpv4Echo // type
|
|
||||||
icmpv4[1] = 0 // code
|
|
||||||
chksum := ^checksum(icmpv4, checksum(payload, 0))
|
|
||||||
binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
|
|
||||||
|
|
||||||
// https://tools.ietf.org/html/rfc760 section 3.1
|
|
||||||
length := uint16(len(hdr) + len(payload))
|
|
||||||
ip[0] = (4 << 4) | (ipv4Size / 4)
|
|
||||||
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
|
|
||||||
ip[8] = ttl
|
|
||||||
ip[9] = icmpv4ProtocolNumber
|
|
||||||
copy(ip[12:], src.To4())
|
|
||||||
copy(ip[16:], dst.To4())
|
|
||||||
chksum = ^checksum(ip[:], 0)
|
|
||||||
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
|
|
||||||
|
|
||||||
var v []byte
|
|
||||||
v = append(v, hdr...)
|
|
||||||
v = append(v, payload...)
|
|
||||||
return []byte(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(crawshaw): find a reusable home for this. package devicetest?
|
|
||||||
type ChannelTUN struct {
|
|
||||||
Inbound chan []byte // incoming packets, closed on TUN close
|
|
||||||
Outbound chan []byte // outbound packets, blocks forever on TUN close
|
|
||||||
|
|
||||||
closed chan struct{}
|
|
||||||
events chan tun.Event
|
|
||||||
tun chTun
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewChannelTUN() *ChannelTUN {
|
|
||||||
c := &ChannelTUN{
|
|
||||||
Inbound: make(chan []byte),
|
|
||||||
Outbound: make(chan []byte),
|
|
||||||
closed: make(chan struct{}),
|
|
||||||
events: make(chan tun.Event, 1),
|
|
||||||
}
|
|
||||||
c.tun.c = c
|
|
||||||
c.events <- tun.EventUp
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ChannelTUN) TUN() tun.Device {
|
|
||||||
return &c.tun
|
|
||||||
}
|
|
||||||
|
|
||||||
type chTun struct {
|
|
||||||
c *ChannelTUN
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *chTun) File() *os.File { return nil }
|
|
||||||
|
|
||||||
func (t *chTun) Read(data []byte, offset int) (int, error) {
|
|
||||||
select {
|
|
||||||
case <-t.c.closed:
|
|
||||||
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
|
|
||||||
case msg := <-t.c.Outbound:
|
|
||||||
return copy(data[offset:], msg), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write is called by the wireguard device to deliver a packet for routing.
|
|
||||||
func (t *chTun) Write(data []byte, offset int) (int, error) {
|
|
||||||
if offset == -1 {
|
|
||||||
close(t.c.closed)
|
|
||||||
close(t.c.events)
|
|
||||||
return 0, io.EOF
|
|
||||||
}
|
|
||||||
msg := make([]byte, len(data)-offset)
|
|
||||||
copy(msg, data[offset:])
|
|
||||||
select {
|
|
||||||
case <-t.c.closed:
|
|
||||||
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
|
|
||||||
case t.c.Inbound <- msg:
|
|
||||||
return len(data) - offset, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *chTun) Flush() error { return nil }
|
|
||||||
func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
|
|
||||||
func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
|
|
||||||
func (t *chTun) Events() chan tun.Event { return t.c.events }
|
|
||||||
func (t *chTun) Close() error {
|
|
||||||
t.Write(nil, -1)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertNil(t *testing.T, err error) {
|
func assertNil(t *testing.T, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
|
150
tun/tuntest/tuntest.go
Normal file
150
tun/tuntest/tuntest.go
Normal file
|
@ -0,0 +1,150 @@
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package tuntest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Ping(dst, src net.IP) []byte {
|
||||||
|
localPort := uint16(1337)
|
||||||
|
seq := uint16(0)
|
||||||
|
|
||||||
|
payload := make([]byte, 4)
|
||||||
|
binary.BigEndian.PutUint16(payload[0:], localPort)
|
||||||
|
binary.BigEndian.PutUint16(payload[2:], seq)
|
||||||
|
|
||||||
|
return genICMPv4(payload, dst, src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
|
||||||
|
func checksum(buf []byte, initial uint16) uint16 {
|
||||||
|
v := uint32(initial)
|
||||||
|
for i := 0; i < len(buf)-1; i += 2 {
|
||||||
|
v += uint32(binary.BigEndian.Uint16(buf[i:]))
|
||||||
|
}
|
||||||
|
if len(buf)%2 == 1 {
|
||||||
|
v += uint32(buf[len(buf)-1]) << 8
|
||||||
|
}
|
||||||
|
for v > 0xffff {
|
||||||
|
v = (v >> 16) + (v & 0xffff)
|
||||||
|
}
|
||||||
|
return ^uint16(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func genICMPv4(payload []byte, dst, src net.IP) []byte {
|
||||||
|
const (
|
||||||
|
icmpv4ProtocolNumber = 1
|
||||||
|
icmpv4Echo = 8
|
||||||
|
icmpv4ChecksumOffset = 2
|
||||||
|
icmpv4Size = 8
|
||||||
|
ipv4Size = 20
|
||||||
|
ipv4TotalLenOffset = 2
|
||||||
|
ipv4ChecksumOffset = 10
|
||||||
|
ttl = 65
|
||||||
|
)
|
||||||
|
|
||||||
|
hdr := make([]byte, ipv4Size+icmpv4Size)
|
||||||
|
|
||||||
|
ip := hdr[0:ipv4Size]
|
||||||
|
icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]
|
||||||
|
|
||||||
|
// https://tools.ietf.org/html/rfc792
|
||||||
|
icmpv4[0] = icmpv4Echo // type
|
||||||
|
icmpv4[1] = 0 // code
|
||||||
|
chksum := ^checksum(icmpv4, checksum(payload, 0))
|
||||||
|
binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
|
||||||
|
|
||||||
|
// https://tools.ietf.org/html/rfc760 section 3.1
|
||||||
|
length := uint16(len(hdr) + len(payload))
|
||||||
|
ip[0] = (4 << 4) | (ipv4Size / 4)
|
||||||
|
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
|
||||||
|
ip[8] = ttl
|
||||||
|
ip[9] = icmpv4ProtocolNumber
|
||||||
|
copy(ip[12:], src.To4())
|
||||||
|
copy(ip[16:], dst.To4())
|
||||||
|
chksum = ^checksum(ip[:], 0)
|
||||||
|
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
|
||||||
|
|
||||||
|
var v []byte
|
||||||
|
v = append(v, hdr...)
|
||||||
|
v = append(v, payload...)
|
||||||
|
return []byte(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(crawshaw): find a reusable home for this. package devicetest?
|
||||||
|
type ChannelTUN struct {
|
||||||
|
Inbound chan []byte // incoming packets, closed on TUN close
|
||||||
|
Outbound chan []byte // outbound packets, blocks forever on TUN close
|
||||||
|
|
||||||
|
closed chan struct{}
|
||||||
|
events chan tun.Event
|
||||||
|
tun chTun
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChannelTUN() *ChannelTUN {
|
||||||
|
c := &ChannelTUN{
|
||||||
|
Inbound: make(chan []byte),
|
||||||
|
Outbound: make(chan []byte),
|
||||||
|
closed: make(chan struct{}),
|
||||||
|
events: make(chan tun.Event, 1),
|
||||||
|
}
|
||||||
|
c.tun.c = c
|
||||||
|
c.events <- tun.EventUp
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChannelTUN) TUN() tun.Device {
|
||||||
|
return &c.tun
|
||||||
|
}
|
||||||
|
|
||||||
|
type chTun struct {
|
||||||
|
c *ChannelTUN
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *chTun) File() *os.File { return nil }
|
||||||
|
|
||||||
|
func (t *chTun) Read(data []byte, offset int) (int, error) {
|
||||||
|
select {
|
||||||
|
case <-t.c.closed:
|
||||||
|
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
|
||||||
|
case msg := <-t.c.Outbound:
|
||||||
|
return copy(data[offset:], msg), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write is called by the wireguard device to deliver a packet for routing.
|
||||||
|
func (t *chTun) Write(data []byte, offset int) (int, error) {
|
||||||
|
if offset == -1 {
|
||||||
|
close(t.c.closed)
|
||||||
|
close(t.c.events)
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
msg := make([]byte, len(data)-offset)
|
||||||
|
copy(msg, data[offset:])
|
||||||
|
select {
|
||||||
|
case <-t.c.closed:
|
||||||
|
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
|
||||||
|
case t.c.Inbound <- msg:
|
||||||
|
return len(data) - offset, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const DefaultMTU = 1420
|
||||||
|
|
||||||
|
func (t *chTun) Flush() error { return nil }
|
||||||
|
func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
|
||||||
|
func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
|
||||||
|
func (t *chTun) Events() chan tun.Event { return t.c.events }
|
||||||
|
func (t *chTun) Close() error {
|
||||||
|
t.Write(nil, -1)
|
||||||
|
return nil
|
||||||
|
}
|
Loading…
Reference in a new issue