Allows passing UAPI fd to service

This commit is contained in:
Mathias Hall-Andersen 2017-11-17 14:36:08 +01:00
parent 88801529fd
commit e1227d3af4
3 changed files with 111 additions and 67 deletions

View file

@ -9,7 +9,8 @@ import (
) )
const ( const (
EnvWGTunFD = "WG_TUN_FD" ENV_WG_TUN_FD = "WG_TUN_FD"
ENV_WG_UAPI_FD = "WG_UAPI_FD"
) )
func printUsage() { func printUsage() {
@ -65,46 +66,69 @@ func main() {
logLevel, logLevel,
fmt.Sprintf("(%s) ", interfaceName), fmt.Sprintf("(%s) ", interfaceName),
) )
logger.Debug.Println("Debug log enabled") logger.Debug.Println("Debug log enabled")
// open TUN device // open TUN device (or use supplied fd)
tun, err := func() (TUNDevice, error) { tun, err := func() (TUNDevice, error) {
tunFdStr := os.Getenv(EnvWGTunFD) tunFdStr := os.Getenv(ENV_WG_TUN_FD)
if tunFdStr == "" { if tunFdStr == "" {
return CreateTUN(interfaceName) return CreateTUN(interfaceName)
} }
// construct tun device from supplied FD // construct tun device from supplied fd
fd, err := strconv.ParseUint(tunFdStr, 10, 32) fd, err := strconv.ParseUint(tunFdStr, 10, 32)
if err != nil { if err != nil {
return nil, err return nil, err
} }
file := os.NewFile(uintptr(fd), "/dev/net/tun") file := os.NewFile(uintptr(fd), "")
return CreateTUNFromFile(interfaceName, file) return CreateTUNFromFile(interfaceName, file)
}() }()
if err != nil { if err != nil {
logger.Error.Println("Failed to create TUN device:", err) logger.Error.Println("Failed to create TUN device:", err)
os.Exit(ExitSetupFailed)
} }
// open UAPI file (or use supplied fd)
fileUAPI, err := func() (*os.File, error) {
uapiFdStr := os.Getenv(ENV_WG_UAPI_FD)
if uapiFdStr == "" {
return UAPIOpen(interfaceName)
}
// use supplied fd
fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
if err != nil {
return nil, err
}
return os.NewFile(uintptr(fd), ""), nil
}()
if err != nil {
logger.Error.Println("UAPI listen error:", err)
os.Exit(ExitSetupFailed)
return
}
// daemonize the process // daemonize the process
if !foreground { if !foreground {
env := os.Environ() env := os.Environ()
_, ok := os.LookupEnv(EnvWGTunFD) env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD))
if !ok { env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD))
kvp := fmt.Sprintf("%s=3", EnvWGTunFD)
env = append(env, kvp)
}
attr := &os.ProcAttr{ attr := &os.ProcAttr{
Files: []*os.File{ Files: []*os.File{
nil, // stdin nil, // stdin
nil, // stdout nil, // stdout
nil, // stderr nil, // stderr
tun.File(), tun.File(),
fileUAPI,
}, },
Dir: ".", Dir: ".",
Env: env, Env: env,
@ -112,6 +136,7 @@ func main() {
err = Daemonize(attr) err = Daemonize(attr)
if err != nil { if err != nil {
logger.Error.Println("Failed to daemonize:", err) logger.Error.Println("Failed to daemonize:", err)
os.Exit(ExitSetupFailed)
} }
return return
} }
@ -123,20 +148,17 @@ func main() {
// create wireguard device // create wireguard device
device := NewDevice(tun, logger) device := NewDevice(tun, logger)
logger.Info.Println("Device started") logger.Info.Println("Device started")
// start configuration lister // start uapi listener
uapi, err := NewUAPIListener(interfaceName)
if err != nil {
logger.Error.Println("UAPI listen error:", err)
return
}
errs := make(chan error) errs := make(chan error)
term := make(chan os.Signal) term := make(chan os.Signal)
wait := device.WaitChannel() wait := device.WaitChannel()
uapi, err := UAPIListen(interfaceName, fileUAPI)
go func() { go func() {
for { for {
conn, err := uapi.Accept() conn, err := uapi.Accept()
@ -161,9 +183,10 @@ func main() {
case <-errs: case <-errs:
} }
// clean up UAPI bind // clean up
uapi.Close() uapi.Close()
device.Close()
logger.Info.Println("Shutting down") logger.Info.Println("Shutting down")
} }

View file

@ -227,7 +227,7 @@ func (tun *NativeTun) MTU() (int, error) {
val := binary.LittleEndian.Uint32(ifr[16:20]) val := binary.LittleEndian.Uint32(ifr[16:20])
if val >= (1 << 31) { if val >= (1 << 31) {
return int(val-(1<<31)) - (1 << 31), nil return int(toInt32(val)), nil
} }
return int(val), nil return int(val), nil
} }

View file

@ -50,49 +50,11 @@ func (l *UAPIListener) Addr() net.Addr {
return nil return nil
} }
func connectUnixSocket(path string) (net.Listener, error) { func UAPIListen(name string, file *os.File) (net.Listener, error) {
// attempt inital connection // wrap file in listener
listener, err := net.Listen("unix", path) listener, err := net.FileListener(file)
if err == nil {
return listener, nil
}
// check if active
_, err = net.Dial("unix", path)
if err == nil {
return nil, errors.New("Unix socket in use")
}
// attempt cleanup
err = os.Remove(path)
if err != nil {
return nil, err
}
return net.Listen("unix", path)
}
func NewUAPIListener(name string) (net.Listener, error) {
// check if path exist
err := os.MkdirAll(socketDirectory, 077)
if err != nil && !os.IsExist(err) {
return nil, err
}
// open UNIX socket
socketPath := path.Join(
socketDirectory,
fmt.Sprintf(socketName, name),
)
listener, err := connectUnixSocket(socketPath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -105,6 +67,11 @@ func NewUAPIListener(name string) (net.Listener, error) {
// watch for deletion of socket // watch for deletion of socket
socketPath := path.Join(
socketDirectory,
fmt.Sprintf(socketName, name),
)
uapi.inotifyFd, err = unix.InotifyInit() uapi.inotifyFd, err = unix.InotifyInit()
if err != nil { if err != nil {
return nil, err return nil, err
@ -125,11 +92,12 @@ func NewUAPIListener(name string) (net.Listener, error) {
go func(l *UAPIListener) { go func(l *UAPIListener) {
var buff [4096]byte var buff [4096]byte
for { for {
unix.Read(uapi.inotifyFd, buff[:]) // start with lstat to avoid race condition
if _, err := os.Lstat(socketPath); os.IsNotExist(err) { if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
l.connErr <- err l.connErr <- err
return return
} }
unix.Read(uapi.inotifyFd, buff[:])
} }
}(uapi) }(uapi)
@ -148,3 +116,56 @@ func NewUAPIListener(name string) (net.Listener, error) {
return uapi, nil return uapi, nil
} }
func UAPIOpen(name string) (*os.File, error) {
// check if path exist
err := os.MkdirAll(socketDirectory, 0600)
if err != nil && !os.IsExist(err) {
return nil, err
}
// open UNIX socket
socketPath := path.Join(
socketDirectory,
fmt.Sprintf(socketName, name),
)
addr, err := net.ResolveUnixAddr("unix", socketPath)
if err != nil {
return nil, err
}
listener, err := func() (*net.UnixListener, error) {
// initial connection attempt
listener, err := net.ListenUnix("unix", addr)
if err == nil {
return listener, nil
}
// check if socket already active
_, err = net.Dial("unix", socketPath)
if err == nil {
return nil, errors.New("unix socket in use")
}
// cleanup & attempt again
err = os.Remove(socketPath)
if err != nil {
return nil, err
}
return net.ListenUnix("unix", addr)
}()
if err != nil {
return nil, err
}
return listener.File()
}