diff --git a/device/uapi.go b/device/uapi.go index 7f50869..7d180bb 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -41,6 +41,8 @@ func ipcErrorf(code int64, msg string, args ...interface{}) *IPCError { return &IPCError{code: code, err: fmt.Errorf(msg, args...)} } +// IpcGetOperation implements the WireGuard configuration protocol "get" operation. +// See https://www.wireguard.com/xplatform/#configuration-protocol for details. func (device *Device) IpcGetOperation(w io.Writer) error { lines := make([]string, 0, 100) send := func(line string) { @@ -116,6 +118,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error { return nil } +// IpcSetOperation implements the WireGuard configuration protocol "set" operation. +// See https://www.wireguard.com/xplatform/#configuration-protocol for details. func (device *Device) IpcSetOperation(r io.Reader) (err error) { defer func() { if err != nil { @@ -123,20 +127,14 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { } }() - logDebug := device.log.Debug - - var peer *Peer - dummy := false - createdNewPeer := false + peer := new(ipcSetPeer) deviceConfig := true scanner := bufio.NewScanner(r) for scanner.Scan() { - - // parse line - line := scanner.Text() if line == "" { + // Blank line means terminate operation. return nil } parts := strings.Split(line, "=") @@ -146,245 +144,241 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { key := parts[0] value := parts[1] - /* device configuration */ - - if deviceConfig { - - switch key { - case "private_key": - var sk NoisePrivateKey - err := sk.FromMaybeZeroHex(value) - if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err) - } - logDebug.Println("UAPI: Updating private key") - device.SetPrivateKey(sk) - - case "listen_port": - - // parse port number - - port, err := strconv.ParseUint(value, 10, 16) - if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err) - } - - // update port and rebind - - logDebug.Println("UAPI: Updating listen port") - - device.net.Lock() - device.net.port = uint16(port) - device.net.Unlock() - - if err := device.BindUpdate(); err != nil { - return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err) - } - - case "fwmark": - - // parse fwmark field - - fwmark, err := func() (uint32, error) { - if value == "" { - return 0, nil - } - mark, err := strconv.ParseUint(value, 10, 32) - return uint32(mark), err - }() - - if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err) - } - - logDebug.Println("UAPI: Updating fwmark") - - if err := device.BindSetMark(uint32(fwmark)); err != nil { - return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err) - } - - case "public_key": - // switch to peer configuration - logDebug.Println("UAPI: Transition to peer configuration") + if key == "public_key" { + if deviceConfig { + device.log.Debug.Println("UAPI: Transition to peer configuration") deviceConfig = false - - case "replace_peers": - if value != "true" { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value) - } - logDebug.Println("UAPI: Removing all peers") - device.RemoveAllPeers() - - default: - return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) } + // Load/create the peer we are now configuring. + err := device.handlePublicKeyLine(peer, value) + if err != nil { + return err + } + continue } - /* peer configuration */ - - if !deviceConfig { - - switch key { - - case "public_key": - var publicKey NoisePublicKey - err := publicKey.FromHex(value) - if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err) - } - - // ignore peer with public key of device - - device.staticIdentity.RLock() - dummy = device.staticIdentity.publicKey.Equals(publicKey) - device.staticIdentity.RUnlock() - - if dummy { - peer = &Peer{} - } else { - peer = device.LookupPeer(publicKey) - } - - createdNewPeer = peer == nil - if createdNewPeer { - peer, err = device.NewPeer(publicKey) - if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err) - } - logDebug.Println(peer, "- UAPI: Created") - } - - case "update_only": - - // allow disabling of creation - - if value != "true" { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value) - } - if createdNewPeer && !dummy { - device.RemovePeer(peer.handshake.remoteStatic) - peer = &Peer{} - dummy = true - } - - case "remove": - - // remove currently selected peer from device - - if value != "true" { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value) - } - if !dummy { - logDebug.Println(peer, "- UAPI: Removing") - device.RemovePeer(peer.handshake.remoteStatic) - } - peer = &Peer{} - dummy = true - - case "preshared_key": - - // update PSK - - logDebug.Println(peer, "- UAPI: Updating preshared key") - - peer.handshake.mutex.Lock() - err := peer.handshake.presharedKey.FromHex(value) - peer.handshake.mutex.Unlock() - - if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err) - } - - case "endpoint": - - // set endpoint destination - - logDebug.Println(peer, "- UAPI: Updating endpoint") - - err := func() error { - peer.Lock() - defer peer.Unlock() - endpoint, err := conn.CreateEndpoint(value) - if err != nil { - return err - } - peer.endpoint = endpoint - return nil - }() - - if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err) - } - - case "persistent_keepalive_interval": - - // update persistent keepalive interval - - logDebug.Println(peer, "- UAPI: Updating persistent keepalive interval") - - secs, err := strconv.ParseUint(value, 10, 16) - if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err) - } - - old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs)) - - // send immediate keepalive if we're turning it on and before it wasn't on - - if old == 0 && secs != 0 { - if err != nil { - return ipcErrorf(ipc.IpcErrorIO, "failed to get tun device status: %w", err) - } - if device.isUp.Get() && !dummy { - peer.SendKeepalive() - } - } - - case "replace_allowed_ips": - - logDebug.Println(peer, "- UAPI: Removing all allowedips") - - if value != "true" { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value) - } - - if dummy { - continue - } - - device.allowedips.RemoveByPeer(peer) - - case "allowed_ip": - - logDebug.Println(peer, "- UAPI: Adding allowedip") - - _, network, err := net.ParseCIDR(value) - if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) - } - - if dummy { - continue - } - - ones, _ := network.Mask.Size() - device.allowedips.Insert(network.IP, uint(ones), peer) - - case "protocol_version": - - if value != "1" { - return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value) - } - - default: - return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key) - } + var err error + if deviceConfig { + err = device.handleDeviceLine(key, value) + } else { + err = device.handlePeerLine(peer, key, value) + } + if err != nil { + return err } } return scanner.Err() } +func (device *Device) handleDeviceLine(key, value string) error { + switch key { + case "private_key": + var sk NoisePrivateKey + err := sk.FromMaybeZeroHex(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err) + } + device.log.Debug.Println("UAPI: Updating private key") + device.SetPrivateKey(sk) + + case "listen_port": + port, err := strconv.ParseUint(value, 10, 16) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err) + } + + // update port and rebind + device.log.Debug.Println("UAPI: Updating listen port") + + device.net.Lock() + device.net.port = uint16(port) + device.net.Unlock() + + if err := device.BindUpdate(); err != nil { + return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err) + } + + case "fwmark": + // parse fwmark field + fwmark, err := func() (uint32, error) { + if value == "" { + return 0, nil + } + mark, err := strconv.ParseUint(value, 10, 32) + return uint32(mark), err + }() + + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err) + } + + device.log.Debug.Println("UAPI: Updating fwmark") + + if err := device.BindSetMark(uint32(fwmark)); err != nil { + return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err) + } + + case "replace_peers": + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value) + } + device.log.Debug.Println("UAPI: Removing all peers") + device.RemoveAllPeers() + + default: + return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) + } + + return nil +} + +// An ipcSetPeer is the current state of an IPC set operation on a peer. +type ipcSetPeer struct { + *Peer // Peer is the current peer being operated on + dummy bool // dummy reports whether this peer is a temporary, placeholder peer + created bool // new reports whether this is a newly created peer +} + +func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error { + // Load/create the peer we are configuring. + var publicKey NoisePublicKey + err := publicKey.FromHex(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err) + } + + // Ignore peer with the same public key as this device. + device.staticIdentity.RLock() + peer.dummy = device.staticIdentity.publicKey.Equals(publicKey) + device.staticIdentity.RUnlock() + + if peer.dummy { + peer.Peer = &Peer{} + } else { + peer.Peer = device.LookupPeer(publicKey) + } + + peer.created = peer.Peer == nil + if peer.created { + peer.Peer, err = device.NewPeer(publicKey) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err) + } + device.log.Debug.Println(peer, "- UAPI: Created") + } + return nil +} + +func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error { + switch key { + case "update_only": + // allow disabling of creation + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value) + } + if peer.created && !peer.dummy { + device.RemovePeer(peer.handshake.remoteStatic) + peer.Peer = &Peer{} + peer.dummy = true + } + + case "remove": + // remove currently selected peer from device + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value) + } + if !peer.dummy { + device.log.Debug.Println(peer, "- UAPI: Removing") + device.RemovePeer(peer.handshake.remoteStatic) + } + peer.Peer = &Peer{} + peer.dummy = true + + case "preshared_key": + device.log.Debug.Println(peer, "- UAPI: Updating preshared key") + + peer.handshake.mutex.Lock() + err := peer.handshake.presharedKey.FromHex(value) + peer.handshake.mutex.Unlock() + + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err) + } + + case "endpoint": + device.log.Debug.Println(peer, "- UAPI: Updating endpoint") + + err := func() error { + peer.Lock() + defer peer.Unlock() + endpoint, err := conn.CreateEndpoint(value) + if err != nil { + return err + } + peer.endpoint = endpoint + return nil + }() + + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err) + } + + case "persistent_keepalive_interval": + device.log.Debug.Println(peer, "- UAPI: Updating persistent keepalive interval") + + secs, err := strconv.ParseUint(value, 10, 16) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err) + } + + old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs)) + + // Send immediate keepalive if we're turning it on and before it wasn't on. + if old == 0 && secs != 0 { + if err != nil { + return ipcErrorf(ipc.IpcErrorIO, "failed to get tun device status: %w", err) + } + if device.isUp.Get() && !peer.dummy { + peer.SendKeepalive() + } + } + + case "replace_allowed_ips": + device.log.Debug.Println(peer, "- UAPI: Removing all allowedips") + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value) + } + if peer.dummy { + return nil + } + device.allowedips.RemoveByPeer(peer.Peer) + + case "allowed_ip": + device.log.Debug.Println(peer, "- UAPI: Adding allowedip") + + _, network, err := net.ParseCIDR(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) + } + if peer.dummy { + return nil + } + ones, _ := network.Mask.Size() + device.allowedips.Insert(network.IP, uint(ones), peer.Peer) + + case "protocol_version": + if value != "1" { + return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value) + } + + default: + return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key) + } + + return nil +} + func (device *Device) IpcGet() (string, error) { buf := new(strings.Builder) if err := device.IpcGetOperation(buf); err != nil {