Allows passing UAPI fd to service
This commit is contained in:
		
							parent
							
								
									88801529fd
								
							
						
					
					
						commit
						e1227d3af4
					
				
					 3 changed files with 111 additions and 67 deletions
				
			
		
							
								
								
									
										59
									
								
								src/main.go
									
									
									
									
									
								
							
							
						
						
									
										59
									
								
								src/main.go
									
									
									
									
									
								
							|  | @ -9,7 +9,8 @@ import ( | |||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	EnvWGTunFD = "WG_TUN_FD" | ||||
| 	ENV_WG_TUN_FD  = "WG_TUN_FD" | ||||
| 	ENV_WG_UAPI_FD = "WG_UAPI_FD" | ||||
| ) | ||||
| 
 | ||||
| func printUsage() { | ||||
|  | @ -65,46 +66,69 @@ func main() { | |||
| 		logLevel, | ||||
| 		fmt.Sprintf("(%s) ", interfaceName), | ||||
| 	) | ||||
| 
 | ||||
| 	logger.Debug.Println("Debug log enabled") | ||||
| 
 | ||||
| 	// open TUN device
 | ||||
| 	// open TUN device (or use supplied fd)
 | ||||
| 
 | ||||
| 	tun, err := func() (TUNDevice, error) { | ||||
| 		tunFdStr := os.Getenv(EnvWGTunFD) | ||||
| 		tunFdStr := os.Getenv(ENV_WG_TUN_FD) | ||||
| 		if tunFdStr == "" { | ||||
| 			return CreateTUN(interfaceName) | ||||
| 		} | ||||
| 
 | ||||
| 		// construct tun device from supplied FD
 | ||||
| 		// construct tun device from supplied fd
 | ||||
| 
 | ||||
| 		fd, err := strconv.ParseUint(tunFdStr, 10, 32) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 
 | ||||
| 		file := os.NewFile(uintptr(fd), "/dev/net/tun") | ||||
| 		file := os.NewFile(uintptr(fd), "") | ||||
| 		return CreateTUNFromFile(interfaceName, file) | ||||
| 	}() | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		logger.Error.Println("Failed to create TUN device:", err) | ||||
| 		os.Exit(ExitSetupFailed) | ||||
| 	} | ||||
| 
 | ||||
| 	// open UAPI file (or use supplied fd)
 | ||||
| 
 | ||||
| 	fileUAPI, err := func() (*os.File, error) { | ||||
| 		uapiFdStr := os.Getenv(ENV_WG_UAPI_FD) | ||||
| 		if uapiFdStr == "" { | ||||
| 			return UAPIOpen(interfaceName) | ||||
| 		} | ||||
| 
 | ||||
| 		// use supplied fd
 | ||||
| 
 | ||||
| 		fd, err := strconv.ParseUint(uapiFdStr, 10, 32) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 
 | ||||
| 		return os.NewFile(uintptr(fd), ""), nil | ||||
| 	}() | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		logger.Error.Println("UAPI listen error:", err) | ||||
| 		os.Exit(ExitSetupFailed) | ||||
| 		return | ||||
| 	} | ||||
| 	// daemonize the process
 | ||||
| 
 | ||||
| 	if !foreground { | ||||
| 		env := os.Environ() | ||||
| 		_, ok := os.LookupEnv(EnvWGTunFD) | ||||
| 		if !ok { | ||||
| 			kvp := fmt.Sprintf("%s=3", EnvWGTunFD) | ||||
| 			env = append(env, kvp) | ||||
| 		} | ||||
| 		env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD)) | ||||
| 		env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD)) | ||||
| 		attr := &os.ProcAttr{ | ||||
| 			Files: []*os.File{ | ||||
| 				nil, // stdin
 | ||||
| 				nil, // stdout
 | ||||
| 				nil, // stderr
 | ||||
| 				tun.File(), | ||||
| 				fileUAPI, | ||||
| 			}, | ||||
| 			Dir: ".", | ||||
| 			Env: env, | ||||
|  | @ -112,6 +136,7 @@ func main() { | |||
| 		err = Daemonize(attr) | ||||
| 		if err != nil { | ||||
| 			logger.Error.Println("Failed to daemonize:", err) | ||||
| 			os.Exit(ExitSetupFailed) | ||||
| 		} | ||||
| 		return | ||||
| 	} | ||||
|  | @ -123,20 +148,17 @@ func main() { | |||
| 	// create wireguard device
 | ||||
| 
 | ||||
| 	device := NewDevice(tun, logger) | ||||
| 
 | ||||
| 	logger.Info.Println("Device started") | ||||
| 
 | ||||
| 	// start configuration lister
 | ||||
| 
 | ||||
| 	uapi, err := NewUAPIListener(interfaceName) | ||||
| 	if err != nil { | ||||
| 		logger.Error.Println("UAPI listen error:", err) | ||||
| 		return | ||||
| 	} | ||||
| 	// start uapi listener
 | ||||
| 
 | ||||
| 	errs := make(chan error) | ||||
| 	term := make(chan os.Signal) | ||||
| 	wait := device.WaitChannel() | ||||
| 
 | ||||
| 	uapi, err := UAPIListen(interfaceName, fileUAPI) | ||||
| 
 | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			conn, err := uapi.Accept() | ||||
|  | @ -161,9 +183,10 @@ func main() { | |||
| 	case <-errs: | ||||
| 	} | ||||
| 
 | ||||
| 	// clean up UAPI bind
 | ||||
| 	// clean up
 | ||||
| 
 | ||||
| 	uapi.Close() | ||||
| 	device.Close() | ||||
| 
 | ||||
| 	logger.Info.Println("Shutting down") | ||||
| } | ||||
|  |  | |||
|  | @ -227,7 +227,7 @@ func (tun *NativeTun) MTU() (int, error) { | |||
| 
 | ||||
| 	val := binary.LittleEndian.Uint32(ifr[16:20]) | ||||
| 	if val >= (1 << 31) { | ||||
| 		return int(val-(1<<31)) - (1 << 31), nil | ||||
| 		return int(toInt32(val)), nil | ||||
| 	} | ||||
| 	return int(val), nil | ||||
| } | ||||
|  |  | |||
|  | @ -10,12 +10,12 @@ import ( | |||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	ipcErrorIO         = -int64(unix.EIO) | ||||
| 	ipcErrorProtocol   = -int64(unix.EPROTO) | ||||
| 	ipcErrorInvalid    = -int64(unix.EINVAL) | ||||
| 	ipcErrorPortInUse  = -int64(unix.EADDRINUSE) | ||||
| 	socketDirectory    = "/var/run/wireguard" | ||||
| 	socketName         = "%s.sock" | ||||
| 	ipcErrorIO        = -int64(unix.EIO) | ||||
| 	ipcErrorProtocol  = -int64(unix.EPROTO) | ||||
| 	ipcErrorInvalid   = -int64(unix.EINVAL) | ||||
| 	ipcErrorPortInUse = -int64(unix.EADDRINUSE) | ||||
| 	socketDirectory   = "/var/run/wireguard" | ||||
| 	socketName        = "%s.sock" | ||||
| ) | ||||
| 
 | ||||
| type UAPIListener struct { | ||||
|  | @ -50,49 +50,11 @@ func (l *UAPIListener) Addr() net.Addr { | |||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func connectUnixSocket(path string) (net.Listener, error) { | ||||
| func UAPIListen(name string, file *os.File) (net.Listener, error) { | ||||
| 
 | ||||
| 	// attempt inital connection
 | ||||
| 	// wrap file in listener
 | ||||
| 
 | ||||
| 	listener, err := net.Listen("unix", path) | ||||
| 	if err == nil { | ||||
| 		return listener, nil | ||||
| 	} | ||||
| 
 | ||||
| 	// check if active
 | ||||
| 
 | ||||
| 	_, err = net.Dial("unix", path) | ||||
| 	if err == nil { | ||||
| 		return nil, errors.New("Unix socket in use") | ||||
| 	} | ||||
| 
 | ||||
| 	// attempt cleanup
 | ||||
| 
 | ||||
| 	err = os.Remove(path) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return net.Listen("unix", path) | ||||
| } | ||||
| 
 | ||||
| func NewUAPIListener(name string) (net.Listener, error) { | ||||
| 
 | ||||
| 	// check if path exist
 | ||||
| 
 | ||||
| 	err := os.MkdirAll(socketDirectory, 077) | ||||
| 	if err != nil && !os.IsExist(err) { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	// open UNIX socket
 | ||||
| 
 | ||||
| 	socketPath := path.Join( | ||||
| 		socketDirectory, | ||||
| 		fmt.Sprintf(socketName, name), | ||||
| 	) | ||||
| 
 | ||||
| 	listener, err := connectUnixSocket(socketPath) | ||||
| 	listener, err := net.FileListener(file) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | @ -105,6 +67,11 @@ func NewUAPIListener(name string) (net.Listener, error) { | |||
| 
 | ||||
| 	// watch for deletion of socket
 | ||||
| 
 | ||||
| 	socketPath := path.Join( | ||||
| 		socketDirectory, | ||||
| 		fmt.Sprintf(socketName, name), | ||||
| 	) | ||||
| 
 | ||||
| 	uapi.inotifyFd, err = unix.InotifyInit() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
|  | @ -125,11 +92,12 @@ func NewUAPIListener(name string) (net.Listener, error) { | |||
| 	go func(l *UAPIListener) { | ||||
| 		var buff [4096]byte | ||||
| 		for { | ||||
| 			unix.Read(uapi.inotifyFd, buff[:]) | ||||
| 			// start with lstat to avoid race condition
 | ||||
| 			if _, err := os.Lstat(socketPath); os.IsNotExist(err) { | ||||
| 				l.connErr <- err | ||||
| 				return | ||||
| 			} | ||||
| 			unix.Read(uapi.inotifyFd, buff[:]) | ||||
| 		} | ||||
| 	}(uapi) | ||||
| 
 | ||||
|  | @ -148,3 +116,56 @@ func NewUAPIListener(name string) (net.Listener, error) { | |||
| 
 | ||||
| 	return uapi, nil | ||||
| } | ||||
| 
 | ||||
| func UAPIOpen(name string) (*os.File, error) { | ||||
| 
 | ||||
| 	// check if path exist
 | ||||
| 
 | ||||
| 	err := os.MkdirAll(socketDirectory, 0600) | ||||
| 	if err != nil && !os.IsExist(err) { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	// open UNIX socket
 | ||||
| 
 | ||||
| 	socketPath := path.Join( | ||||
| 		socketDirectory, | ||||
| 		fmt.Sprintf(socketName, name), | ||||
| 	) | ||||
| 
 | ||||
| 	addr, err := net.ResolveUnixAddr("unix", socketPath) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	listener, err := func() (*net.UnixListener, error) { | ||||
| 
 | ||||
| 		// initial connection attempt
 | ||||
| 
 | ||||
| 		listener, err := net.ListenUnix("unix", addr) | ||||
| 		if err == nil { | ||||
| 			return listener, nil | ||||
| 		} | ||||
| 
 | ||||
| 		// check if socket already active
 | ||||
| 
 | ||||
| 		_, err = net.Dial("unix", socketPath) | ||||
| 		if err == nil { | ||||
| 			return nil, errors.New("unix socket in use") | ||||
| 		} | ||||
| 
 | ||||
| 		// cleanup & attempt again
 | ||||
| 
 | ||||
| 		err = os.Remove(socketPath) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		return net.ListenUnix("unix", addr) | ||||
| 	}() | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return listener.File() | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue