From 88801529fd4097993f7c448b1c3eee0abc8cb51c Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Tue, 14 Nov 2017 18:26:28 +0100 Subject: [PATCH] Moved TUN device creation to pre-fork --- src/daemon_linux.go | 11 +---- src/device.go | 4 +- src/main.go | 104 +++++++++++++++++++++++++++++--------------- src/tests/netns.sh | 21 ++++----- src/tun.go | 2 + src/tun_linux.go | 28 ++++++++++++ 6 files changed, 111 insertions(+), 59 deletions(-) diff --git a/src/daemon_linux.go b/src/daemon_linux.go index 730f89e..8210f8b 100644 --- a/src/daemon_linux.go +++ b/src/daemon_linux.go @@ -11,18 +11,9 @@ import ( * TODO: Use env variable to spawn in background */ -func Daemonize() error { +func Daemonize(attr *os.ProcAttr) error { argv := []string{os.Args[0], "--foreground"} argv = append(argv, os.Args[1:]...) - attr := &os.ProcAttr{ - Dir: ".", - Env: os.Environ(), - Files: []*os.File{ - os.Stdin, - nil, - nil, - }, - } process, err := os.StartProcess( argv[0], argv, diff --git a/src/device.go b/src/device.go index 9422d49..429ee46 100644 --- a/src/device.go +++ b/src/device.go @@ -126,13 +126,13 @@ func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { device.pool.messageBuffers.Put(msg) } -func NewDevice(tun TUNDevice, logLevel int) *Device { +func NewDevice(tun TUNDevice, logger *Logger) *Device { device := new(Device) device.mutex.Lock() defer device.mutex.Unlock() - device.log = NewLogger(logLevel, "("+tun.Name()+") ") + device.log = logger device.peers = make(map[NoisePublicKey]*Peer) device.tun.device = tun diff --git a/src/main.go b/src/main.go index eb3c67f..3808c9c 100644 --- a/src/main.go +++ b/src/main.go @@ -2,10 +2,14 @@ package main import ( "fmt" - "log" "os" "os/signal" "runtime" + "strconv" +) + +const ( + EnvWGTunFD = "WG_TUN_FD" ) func printUsage() { @@ -43,28 +47,6 @@ func main() { interfaceName = os.Args[1] } - // daemonize the process - - if !foreground { - err := Daemonize() - if err != nil { - log.Println("Failed to daemonize:", err) - } - return - } - - // increase number of go workers (for Go <1.5) - - runtime.GOMAXPROCS(runtime.NumCPU()) - - // open TUN device - - tun, err := CreateTUN(interfaceName) - if err != nil { - log.Println("Failed to create tun device:", err) - return - } - // get log level (default: info) logLevel := func() int { @@ -79,22 +61,76 @@ func main() { return LogLevelInfo }() + logger := NewLogger( + logLevel, + fmt.Sprintf("(%s) ", interfaceName), + ) + logger.Debug.Println("Debug log enabled") + + // open TUN device + + tun, err := func() (TUNDevice, error) { + tunFdStr := os.Getenv(EnvWGTunFD) + if tunFdStr == "" { + return CreateTUN(interfaceName) + } + + // construct tun device from supplied FD + + fd, err := strconv.ParseUint(tunFdStr, 10, 32) + if err != nil { + return nil, err + } + + file := os.NewFile(uintptr(fd), "/dev/net/tun") + return CreateTUNFromFile(interfaceName, file) + }() + + if err != nil { + logger.Error.Println("Failed to create TUN device:", err) + } + + // daemonize the process + + if !foreground { + env := os.Environ() + _, ok := os.LookupEnv(EnvWGTunFD) + if !ok { + kvp := fmt.Sprintf("%s=3", EnvWGTunFD) + env = append(env, kvp) + } + attr := &os.ProcAttr{ + Files: []*os.File{ + nil, // stdin + nil, // stdout + nil, // stderr + tun.File(), + }, + Dir: ".", + Env: env, + } + err = Daemonize(attr) + if err != nil { + logger.Error.Println("Failed to daemonize:", err) + } + return + } + + // increase number of go workers (for Go <1.5) + + runtime.GOMAXPROCS(runtime.NumCPU()) + // create wireguard device - device := NewDevice(tun, logLevel) - - logInfo := device.log.Info - logError := device.log.Error - logDebug := device.log.Debug - - logInfo.Println("Device started") - logDebug.Println("Debug log enabled") + device := NewDevice(tun, logger) + logger.Info.Println("Device started") // start configuration lister uapi, err := NewUAPIListener(interfaceName) if err != nil { - logError.Fatal("UAPI listen error:", err) + logger.Error.Println("UAPI listen error:", err) + return } errs := make(chan error) @@ -112,7 +148,7 @@ func main() { } }() - logInfo.Println("UAPI listener started") + logger.Info.Println("UAPI listener started") // wait for program to terminate @@ -129,5 +165,5 @@ func main() { uapi.Close() - logInfo.Println("Closing") + logger.Info.Println("Shutting down") } diff --git a/src/tests/netns.sh b/src/tests/netns.sh index 9124b80..b5c2f9c 100755 --- a/src/tests/netns.sh +++ b/src/tests/netns.sh @@ -28,7 +28,7 @@ netns0="wg-test-$$-0" netns1="wg-test-$$-1" netns2="wg-test-$$-2" program="../wireguard-go" -export LOG_LEVEL="debug" +export LOG_LEVEL="info" pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; } pp() { pretty "" "$*"; "$@"; } @@ -72,13 +72,11 @@ pp ip netns add $netns2 ip0 link set up dev lo # ip0 link add dev wg1 type wireguard -n0 $program -f wg1 & -sleep 1 +n0 $program wg1 ip0 link set wg1 netns $netns1 # ip0 link add dev wg1 type wireguard -n0 $program -f wg2 & -sleep 1 +n0 $program wg2 ip0 link set wg2 netns $netns2 key1="$(pp wg genkey)" @@ -147,8 +145,6 @@ tests() { n1 iperf3 -Z -n 1G -b 0 -u -c fd00::2 } -echo "4" - [[ $(ip1 link show dev wg1) =~ mtu\ ([0-9]+) ]] && orig_mtu="${BASH_REMATCH[1]}" big_mtu=$(( 34816 - 1500 + $orig_mtu )) @@ -234,9 +230,8 @@ ip2 link del wg2 # ip1 link add dev wg1 type wireguard # ip2 link add dev wg1 type wireguard -n1 $program -f wg1 & -n2 $program -f wg2 & -sleep 5 +n1 $program wg1 +n2 $program wg2 configure_peers @@ -291,9 +286,8 @@ ip2 link del wg2 # ip1 link add dev wg1 type wireguard # ip2 link add dev wg1 type wireguard -n1 $program -f wg1 & -n2 $program -f wg2 & -sleep 5 +n1 $program wg1 +n2 $program wg2 configure_peers @@ -354,4 +348,5 @@ n2 ping -W 1 -c 1 192.168.241.1 ip1 link del veth1 ip1 link del wg1 ip2 link del wg2 + echo "done" diff --git a/src/tun.go b/src/tun.go index 9eed987..5bdac0e 100644 --- a/src/tun.go +++ b/src/tun.go @@ -1,6 +1,7 @@ package main import ( + "os" "sync/atomic" ) @@ -15,6 +16,7 @@ const ( ) type TUNDevice interface { + File() *os.File // returns the file descriptor of the device Read([]byte) (int, error) // read a packet from the device (without any additional headers) Write([]byte) (int, error) // writes a packet to the device (without any additional headers) MTU() (int, error) // returns the MTU of the device diff --git a/src/tun_linux.go b/src/tun_linux.go index accc6c6..ce6304c 100644 --- a/src/tun_linux.go +++ b/src/tun_linux.go @@ -56,6 +56,11 @@ type NativeTun struct { events chan TUNEvent // device related events } +func (tun *NativeTun) File() *os.File { + println(tun.fd.Name()) + return tun.fd +} + func (tun *NativeTun) RoutineNetlinkListener() { sock := int(C.bind_rtmgrp()) if sock < 0 { @@ -248,6 +253,29 @@ func (tun *NativeTun) Close() error { return nil } +func CreateTUNFromFile(name string, fd *os.File) (TUNDevice, error) { + device := &NativeTun{ + fd: fd, + name: name, + events: make(chan TUNEvent, 5), + errors: make(chan error, 5), + } + + // start event listener + + var err error + device.index, err = getIFIndex(device.name) + if err != nil { + return nil, err + } + + go device.RoutineNetlinkListener() + + // set default MTU + + return device, device.setMTU(DefaultMTU) +} + func CreateTUN(name string) (TUNDevice, error) { // open clone device