Ported remaining netns.sh

- Ported remaining netns.sh tests
- Begin work on generic implementation of bind interface
This commit is contained in:
Mathias Hall-Andersen 2017-11-17 17:25:45 +01:00
parent e1227d3af4
commit fa399a91d5
13 changed files with 194 additions and 28 deletions

View file

@ -15,6 +15,22 @@ type UDPBind interface {
Close() error Close() error
} }
/* An Endpoint maintains the source/destination caching for a peer
*
* dst : the remote address of a peer
* src : the local address from which datagrams originate going to the peer
*
*/
type UDPEndpoint interface {
ClearSrc() // clears the source address
ClearDst() // clears the destination address
SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations
DstIP() net.IP
SrcIP() net.IP
}
func parseEndpoint(s string) (*net.UDPAddr, error) { func parseEndpoint(s string) (*net.UDPAddr, error) {
// ensure that the host is an IP address // ensure that the host is an IP address

View file

@ -6,6 +6,41 @@ import (
"net" "net"
) )
/* This code is meant to be a temporary solution
* on platforms for which the sticky socket / source caching behavior
* has not yet been implemented.
*
* See conn_linux.go for an implementation on the linux platform.
*/
type Endpoint *net.UDPAddr
type NativeBind *net.UDPConn
func CreateUDPBind(port uint16) (UDPBind, uint16, error) {
// listen
addr := UDPAddr{
Port: int(port),
}
conn, err := net.ListenUDP("udp", &addr)
if err != nil {
return nil, 0, err
}
// retrieve port
laddr := conn.LocalAddr()
uaddr, _ = net.ResolveUDPAddr(
laddr.Network(),
laddr.String(),
)
return uaddr.Port
}
func (_ Endpoint) ClearSrc() {}
func SetMark(conn *net.UDPConn, value uint32) error { func SetMark(conn *net.UDPConn, value uint32) error {
return nil return nil
} }

View file

@ -168,7 +168,7 @@ func (end *Endpoint) DstIP() net.IP {
} }
} }
func (end *Endpoint) SrcToBytes() []byte { func (end *Endpoint) DstToBytes() []byte {
ptr := unsafe.Pointer(&end.src) ptr := unsafe.Pointer(&end.src)
arr := (*[unix.SizeofSockaddrInet6]byte)(ptr) arr := (*[unix.SizeofSockaddrInet6]byte)(ptr)
return arr[:] return arr[:]

View file

@ -1,7 +1,6 @@
package main package main
import ( import (
"net"
"testing" "testing"
) )
@ -25,7 +24,7 @@ func TestCookieMAC1(t *testing.T) {
// check mac1 // check mac1
src, _ := net.ResolveUDPAddr("udp", "192.168.13.37:4000") src := []byte{192, 168, 13, 37, 10, 10, 10}
checkMAC1 := func(msg []byte) { checkMAC1 := func(msg []byte) {
generator.AddMacs(msg) generator.AddMacs(msg)
@ -128,12 +127,12 @@ func TestCookieMAC1(t *testing.T) {
msg[5] ^= 0x20 msg[5] ^= 0x20
srcBad1, _ := net.ResolveUDPAddr("udp", "192.168.13.37:4001") srcBad1 := []byte{192, 168, 13, 37, 40, 01}
if checker.CheckMAC2(msg, srcBad1) { if checker.CheckMAC2(msg, srcBad1) {
t.Fatal("MAC2 generation/verification failed") t.Fatal("MAC2 generation/verification failed")
} }
srcBad2, _ := net.ResolveUDPAddr("udp", "192.168.13.38:4000") srcBad2 := []byte{192, 168, 13, 38, 40, 01}
if checker.CheckMAC2(msg, srcBad2) { if checker.CheckMAC2(msg, srcBad2) {
t.Fatal("MAC2 generation/verification failed") t.Fatal("MAC2 generation/verification failed")
} }

View file

@ -2,20 +2,25 @@ package main
import ( import (
"os" "os"
"os/exec"
) )
/* Daemonizes the process on linux /* Daemonizes the process on linux
* *
* This is done by spawning and releasing a copy with the --foreground flag * This is done by spawning and releasing a copy with the --foreground flag
*
* TODO: Use env variable to spawn in background
*/ */
func Daemonize(attr *os.ProcAttr) error { func Daemonize(attr *os.ProcAttr) error {
// I would like to use os.Executable,
// however this means dropping support for Go <1.8
path, err := exec.LookPath(os.Args[0])
if err != nil {
return err
}
argv := []string{os.Args[0], "--foreground"} argv := []string{os.Args[0], "--foreground"}
argv = append(argv, os.Args[1:]...) argv = append(argv, os.Args[1:]...)
process, err := os.StartProcess( process, err := os.StartProcess(
argv[0], path,
argv, argv,
attr, attr,
) )

View file

@ -8,6 +8,7 @@ import (
) )
type Device struct { type Device struct {
closed AtomicBool // device is closed? (acting as guard)
log *Logger // collection of loggers for levels log *Logger // collection of loggers for levels
idCounter uint // for assigning debug ids to peers idCounter uint // for assigning debug ids to peers
fwMark uint32 fwMark uint32
@ -203,6 +204,9 @@ func (device *Device) RemoveAllPeers() {
} }
func (device *Device) Close() { func (device *Device) Close() {
if device.closed.Swap(true) {
return
}
device.log.Info.Println("Closing device") device.log.Info.Println("Closing device")
device.RemoveAllPeers() device.RemoveAllPeers()
close(device.signal.stop) close(device.signal.stop)

View file

@ -2,6 +2,7 @@ package main
import ( import (
"bytes" "bytes"
"os"
"testing" "testing"
) )
@ -15,6 +16,10 @@ type DummyTUN struct {
events chan TUNEvent events chan TUNEvent
} }
func (tun *DummyTUN) File() *os.File {
return nil
}
func (tun *DummyTUN) Name() string { func (tun *DummyTUN) Name() string {
return tun.name return tun.name
} }
@ -67,7 +72,8 @@ func randDevice(t *testing.T) *Device {
t.Fatal(err) t.Fatal(err)
} }
tun, _ := CreateDummyTUN("dummy") tun, _ := CreateDummyTUN("dummy")
device := NewDevice(tun, LogLevelError) logger := NewLogger(LogLevelError, "")
device := NewDevice(tun, logger)
device.SetPrivateKey(sk) device.SetPrivateKey(sk)
return device return device
} }

View file

@ -21,6 +21,14 @@ func (a *AtomicBool) Get() bool {
return atomic.LoadInt32(&a.flag) == AtomicTrue return atomic.LoadInt32(&a.flag) == AtomicTrue
} }
func (a *AtomicBool) Swap(val bool) bool {
flag := AtomicFalse
if val {
flag = AtomicTrue
}
return atomic.SwapInt32(&a.flag, flag) == AtomicTrue
}
func (a *AtomicBool) Set(val bool) { func (a *AtomicBool) Set(val bool) {
flag := AtomicFalse flag := AtomicFalse
if val { if val {

View file

@ -117,8 +117,8 @@ func TestNoiseHandshake(t *testing.T) {
var err error var err error
var out []byte var out []byte
var nonce [12]byte var nonce [12]byte
out = key1.send.aead.Seal(out, nonce[:], testMsg, nil) out = key1.send.Seal(out, nonce[:], testMsg, nil)
out, err = key2.receive.aead.Open(out[:0], nonce[:], out, nil) out, err = key2.receive.Open(out[:0], nonce[:], out, nil)
assertNil(t, err) assertNil(t, err)
assertEqual(t, out, testMsg) assertEqual(t, out, testMsg)
}() }()
@ -128,8 +128,8 @@ func TestNoiseHandshake(t *testing.T) {
var err error var err error
var out []byte var out []byte
var nonce [12]byte var nonce [12]byte
out = key2.send.aead.Seal(out, nonce[:], testMsg, nil) out = key2.send.Seal(out, nonce[:], testMsg, nil)
out, err = key1.receive.aead.Open(out[:0], nonce[:], out, nil) out, err = key1.receive.Open(out[:0], nonce[:], out, nil)
assertNil(t, err) assertNil(t, err)
assertEqual(t, out, testMsg) assertEqual(t, out, testMsg)
}() }()

View file

@ -311,7 +311,10 @@ func (device *Device) RoutineHandshake() {
return return
} }
srcBytes := elem.endpoint.SrcToBytes() // endpoints destination address is the source of the datagram
srcBytes := elem.endpoint.DstToBytes()
if device.IsUnderLoad() { if device.IsUnderLoad() {
// verify MAC2 field // verify MAC2 field
@ -320,8 +323,12 @@ func (device *Device) RoutineHandshake() {
// construct cookie reply // construct cookie reply
logDebug.Println("Sending cookie reply to:", elem.endpoint.SrcToString()) logDebug.Println(
sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type" "Sending cookie reply to:",
elem.endpoint.DstToString(),
)
sender := binary.LittleEndian.Uint32(elem.packet[4:8])
reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes) reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes)
if err != nil { if err != nil {
logError.Println("Failed to create cookie reply:", err) logError.Println("Failed to create cookie reply:", err)
@ -555,8 +562,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.routingTable.LookupIPv4(src) != peer { if device.routingTable.LookupIPv4(src) != peer {
logInfo.Println(src) logInfo.Println(
logInfo.Println("Packet with unallowed source IPv4 from", peer.String()) "IPv4 packet with unallowed source address from",
peer.String(),
)
continue continue
} }
@ -581,8 +590,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.routingTable.LookupIPv6(src) != peer { if device.routingTable.LookupIPv6(src) != peer {
logInfo.Println(src) logInfo.Println(
logInfo.Println("Packet with unallowed source IPv6 from", peer.String()) "IPv6 packet with unallowed source address from",
peer.String(),
)
continue continue
} }
@ -591,7 +602,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
continue continue
} }
// write to tun // write to tun device
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
_, err := device.tun.device.Write(elem.packet) _, err := device.tun.device.Write(elem.packet)

View file

@ -20,6 +20,14 @@
# wireguard peers in $ns1 and $ns2. Note that $ns0 is the endpoint for the wg1 # wireguard peers in $ns1 and $ns2. Note that $ns0 is the endpoint for the wg1
# interfaces in $ns1 and $ns2. See https://www.wireguard.com/netns/ for further # interfaces in $ns1 and $ns2. See https://www.wireguard.com/netns/ for further
# details on how this is accomplished. # details on how this is accomplished.
# This code is ported to the WireGuard-Go directly from the kernel project.
#
# Please ensure that you have installed the newest version of the WireGuard
# tools from the WireGuard project and before running these tests as:
#
# ./netns.sh <path to wireguard-go>
set -e set -e
exec 3>&1 exec 3>&1
@ -27,7 +35,7 @@ export WG_HIDE_KEYS=never
netns0="wg-test-$$-0" netns0="wg-test-$$-0"
netns1="wg-test-$$-1" netns1="wg-test-$$-1"
netns2="wg-test-$$-2" netns2="wg-test-$$-2"
program="../wireguard-go" program=$1
export LOG_LEVEL="info" export LOG_LEVEL="info"
pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; } pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; }
@ -349,4 +357,68 @@ ip1 link del veth1
ip1 link del wg1 ip1 link del wg1
ip2 link del wg2 ip2 link del wg2
echo "done" # Test that Netlink/IPC is working properly by doing things that usually cause split responses
n0 $program wg0
sleep 5
config=( "[Interface]" "PrivateKey=$(wg genkey)" "[Peer]" "PublicKey=$(wg genkey)" )
for a in {1..255}; do
for b in {0..255}; do
config+=( "AllowedIPs=$a.$b.0.0/16,$a::$b/128" )
done
done
n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
i=0
for ip in $(n0 wg show wg0 allowed-ips); do
((++i))
done
((i == 255*256*2+1))
ip0 link del wg0
n0 $program wg0
config=( "[Interface]" "PrivateKey=$(wg genkey)" )
for a in {1..40}; do
config+=( "[Peer]" "PublicKey=$(wg genkey)" )
for b in {1..52}; do
config+=( "AllowedIPs=$a.$b.0.0/16" )
done
done
n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
i=0
while read -r line; do
j=0
for ip in $line; do
((++j))
done
((j == 53))
((++i))
done < <(n0 wg show wg0 allowed-ips)
((i == 40))
ip0 link del wg0
n0 $program wg0
config=( )
for i in {1..29}; do
config+=( "[Peer]" "PublicKey=$(wg genkey)" )
done
config+=( "[Peer]" "PublicKey=$(wg genkey)" "AllowedIPs=255.2.3.4/32,abcd::255/128" )
n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
n0 wg showconf wg0 > /dev/null
ip0 link del wg0
! n0 wg show doesnotexist || false
declare -A objects
while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do
[[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ [0-9]+)\ .*(created|destroyed).* ]] || continue
objects["${BASH_REMATCH[1]}"]+="${BASH_REMATCH[2]}"
done < /dev/kmsg
alldeleted=1
for object in "${!objects[@]}"; do
if [[ ${objects["$object"]} != *createddestroyed ]]; then
echo "Error: $object: merely ${objects["$object"]}" >&3
alldeleted=0
fi
done
[[ $alldeleted -eq 1 ]]
pretty "" "Objects that were created were also destroyed."

View file

@ -57,7 +57,6 @@ type NativeTun struct {
} }
func (tun *NativeTun) File() *os.File { func (tun *NativeTun) File() *os.File {
println(tun.fd.Name())
return tun.fd return tun.fd
} }

View file

@ -145,11 +145,22 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
} }
case "fwmark": case "fwmark":
fwmark, err := strconv.ParseUint(value, 10, 32)
// parse fwmark field
fwmark, err := func() (uint32, error) {
if value == "" {
return 0, nil
}
mark, err := strconv.ParseUint(value, 10, 32)
return uint32(mark), err
}()
if err != nil { if err != nil {
logError.Println("Invalid fwmark", err) logError.Println("Invalid fwmark", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
device.net.mutex.Lock() device.net.mutex.Lock()
device.net.fwmark = uint32(fwmark) device.net.fwmark = uint32(fwmark)
device.net.mutex.Unlock() device.net.mutex.Unlock()