device: split IpcSetOperation into parts
The goal of this change is to make the structure of IpcSetOperation easier to follow. IpcSetOperation contains a small state machine: It starts by configuring the device, then shifts to configuring one peer at a time. Having the code all in one giant method obscured that structure. Split out the parts into helper functions and encapsulate the peer state. This makes the overall structure more apparent. Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
This commit is contained in:
parent
a029b942ae
commit
6252de0db9
166
device/uapi.go
166
device/uapi.go
|
@ -41,6 +41,8 @@ func ipcErrorf(code int64, msg string, args ...interface{}) *IPCError {
|
|||
return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
|
||||
}
|
||||
|
||||
// IpcGetOperation implements the WireGuard configuration protocol "get" operation.
|
||||
// See https://www.wireguard.com/xplatform/#configuration-protocol for details.
|
||||
func (device *Device) IpcGetOperation(w io.Writer) error {
|
||||
lines := make([]string, 0, 100)
|
||||
send := func(line string) {
|
||||
|
@ -116,6 +118,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// IpcSetOperation implements the WireGuard configuration protocol "set" operation.
|
||||
// See https://www.wireguard.com/xplatform/#configuration-protocol for details.
|
||||
func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
|
@ -123,20 +127,14 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
}
|
||||
}()
|
||||
|
||||
logDebug := device.log.Debug
|
||||
|
||||
var peer *Peer
|
||||
dummy := false
|
||||
createdNewPeer := false
|
||||
peer := new(ipcSetPeer)
|
||||
deviceConfig := true
|
||||
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
|
||||
// parse line
|
||||
|
||||
line := scanner.Text()
|
||||
if line == "" {
|
||||
// Blank line means terminate operation.
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(line, "=")
|
||||
|
@ -146,10 +144,34 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
key := parts[0]
|
||||
value := parts[1]
|
||||
|
||||
/* device configuration */
|
||||
|
||||
if key == "public_key" {
|
||||
if deviceConfig {
|
||||
device.log.Debug.Println("UAPI: Transition to peer configuration")
|
||||
deviceConfig = false
|
||||
}
|
||||
// Load/create the peer we are now configuring.
|
||||
err := device.handlePublicKeyLine(peer, value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
var err error
|
||||
if deviceConfig {
|
||||
err = device.handleDeviceLine(key, value)
|
||||
} else {
|
||||
err = device.handlePeerLine(peer, key, value)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func (device *Device) handleDeviceLine(key, value string) error {
|
||||
switch key {
|
||||
case "private_key":
|
||||
var sk NoisePrivateKey
|
||||
|
@ -157,21 +179,17 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err)
|
||||
}
|
||||
logDebug.Println("UAPI: Updating private key")
|
||||
device.log.Debug.Println("UAPI: Updating private key")
|
||||
device.SetPrivateKey(sk)
|
||||
|
||||
case "listen_port":
|
||||
|
||||
// parse port number
|
||||
|
||||
port, err := strconv.ParseUint(value, 10, 16)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err)
|
||||
}
|
||||
|
||||
// update port and rebind
|
||||
|
||||
logDebug.Println("UAPI: Updating listen port")
|
||||
device.log.Debug.Println("UAPI: Updating listen port")
|
||||
|
||||
device.net.Lock()
|
||||
device.net.port = uint16(port)
|
||||
|
@ -182,9 +200,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
}
|
||||
|
||||
case "fwmark":
|
||||
|
||||
// parse fwmark field
|
||||
|
||||
fwmark, err := func() (uint32, error) {
|
||||
if value == "" {
|
||||
return 0, nil
|
||||
|
@ -197,95 +213,90 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err)
|
||||
}
|
||||
|
||||
logDebug.Println("UAPI: Updating fwmark")
|
||||
device.log.Debug.Println("UAPI: Updating fwmark")
|
||||
|
||||
if err := device.BindSetMark(uint32(fwmark)); err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err)
|
||||
}
|
||||
|
||||
case "public_key":
|
||||
// switch to peer configuration
|
||||
logDebug.Println("UAPI: Transition to peer configuration")
|
||||
deviceConfig = false
|
||||
|
||||
case "replace_peers":
|
||||
if value != "true" {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
|
||||
}
|
||||
logDebug.Println("UAPI: Removing all peers")
|
||||
device.log.Debug.Println("UAPI: Removing all peers")
|
||||
device.RemoveAllPeers()
|
||||
|
||||
default:
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
|
||||
}
|
||||
}
|
||||
|
||||
/* peer configuration */
|
||||
return nil
|
||||
}
|
||||
|
||||
if !deviceConfig {
|
||||
// An ipcSetPeer is the current state of an IPC set operation on a peer.
|
||||
type ipcSetPeer struct {
|
||||
*Peer // Peer is the current peer being operated on
|
||||
dummy bool // dummy reports whether this peer is a temporary, placeholder peer
|
||||
created bool // new reports whether this is a newly created peer
|
||||
}
|
||||
|
||||
switch key {
|
||||
|
||||
case "public_key":
|
||||
func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error {
|
||||
// Load/create the peer we are configuring.
|
||||
var publicKey NoisePublicKey
|
||||
err := publicKey.FromHex(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err)
|
||||
}
|
||||
|
||||
// ignore peer with public key of device
|
||||
|
||||
// Ignore peer with the same public key as this device.
|
||||
device.staticIdentity.RLock()
|
||||
dummy = device.staticIdentity.publicKey.Equals(publicKey)
|
||||
peer.dummy = device.staticIdentity.publicKey.Equals(publicKey)
|
||||
device.staticIdentity.RUnlock()
|
||||
|
||||
if dummy {
|
||||
peer = &Peer{}
|
||||
if peer.dummy {
|
||||
peer.Peer = &Peer{}
|
||||
} else {
|
||||
peer = device.LookupPeer(publicKey)
|
||||
peer.Peer = device.LookupPeer(publicKey)
|
||||
}
|
||||
|
||||
createdNewPeer = peer == nil
|
||||
if createdNewPeer {
|
||||
peer, err = device.NewPeer(publicKey)
|
||||
peer.created = peer.Peer == nil
|
||||
if peer.created {
|
||||
peer.Peer, err = device.NewPeer(publicKey)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err)
|
||||
}
|
||||
logDebug.Println(peer, "- UAPI: Created")
|
||||
device.log.Debug.Println(peer, "- UAPI: Created")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error {
|
||||
switch key {
|
||||
case "update_only":
|
||||
|
||||
// allow disabling of creation
|
||||
|
||||
if value != "true" {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
|
||||
}
|
||||
if createdNewPeer && !dummy {
|
||||
if peer.created && !peer.dummy {
|
||||
device.RemovePeer(peer.handshake.remoteStatic)
|
||||
peer = &Peer{}
|
||||
dummy = true
|
||||
peer.Peer = &Peer{}
|
||||
peer.dummy = true
|
||||
}
|
||||
|
||||
case "remove":
|
||||
|
||||
// remove currently selected peer from device
|
||||
|
||||
if value != "true" {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value)
|
||||
}
|
||||
if !dummy {
|
||||
logDebug.Println(peer, "- UAPI: Removing")
|
||||
if !peer.dummy {
|
||||
device.log.Debug.Println(peer, "- UAPI: Removing")
|
||||
device.RemovePeer(peer.handshake.remoteStatic)
|
||||
}
|
||||
peer = &Peer{}
|
||||
dummy = true
|
||||
peer.Peer = &Peer{}
|
||||
peer.dummy = true
|
||||
|
||||
case "preshared_key":
|
||||
|
||||
// update PSK
|
||||
|
||||
logDebug.Println(peer, "- UAPI: Updating preshared key")
|
||||
device.log.Debug.Println(peer, "- UAPI: Updating preshared key")
|
||||
|
||||
peer.handshake.mutex.Lock()
|
||||
err := peer.handshake.presharedKey.FromHex(value)
|
||||
|
@ -296,10 +307,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
}
|
||||
|
||||
case "endpoint":
|
||||
|
||||
// set endpoint destination
|
||||
|
||||
logDebug.Println(peer, "- UAPI: Updating endpoint")
|
||||
device.log.Debug.Println(peer, "- UAPI: Updating endpoint")
|
||||
|
||||
err := func() error {
|
||||
peer.Lock()
|
||||
|
@ -317,10 +325,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
}
|
||||
|
||||
case "persistent_keepalive_interval":
|
||||
|
||||
// update persistent keepalive interval
|
||||
|
||||
logDebug.Println(peer, "- UAPI: Updating persistent keepalive interval")
|
||||
device.log.Debug.Println(peer, "- UAPI: Updating persistent keepalive interval")
|
||||
|
||||
secs, err := strconv.ParseUint(value, 10, 16)
|
||||
if err != nil {
|
||||
|
@ -329,49 +334,40 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
|
||||
old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs))
|
||||
|
||||
// send immediate keepalive if we're turning it on and before it wasn't on
|
||||
|
||||
// Send immediate keepalive if we're turning it on and before it wasn't on.
|
||||
if old == 0 && secs != 0 {
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorIO, "failed to get tun device status: %w", err)
|
||||
}
|
||||
if device.isUp.Get() && !dummy {
|
||||
if device.isUp.Get() && !peer.dummy {
|
||||
peer.SendKeepalive()
|
||||
}
|
||||
}
|
||||
|
||||
case "replace_allowed_ips":
|
||||
|
||||
logDebug.Println(peer, "- UAPI: Removing all allowedips")
|
||||
|
||||
device.log.Debug.Println(peer, "- UAPI: Removing all allowedips")
|
||||
if value != "true" {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
|
||||
}
|
||||
|
||||
if dummy {
|
||||
continue
|
||||
if peer.dummy {
|
||||
return nil
|
||||
}
|
||||
|
||||
device.allowedips.RemoveByPeer(peer)
|
||||
device.allowedips.RemoveByPeer(peer.Peer)
|
||||
|
||||
case "allowed_ip":
|
||||
|
||||
logDebug.Println(peer, "- UAPI: Adding allowedip")
|
||||
device.log.Debug.Println(peer, "- UAPI: Adding allowedip")
|
||||
|
||||
_, network, err := net.ParseCIDR(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
|
||||
}
|
||||
|
||||
if dummy {
|
||||
continue
|
||||
if peer.dummy {
|
||||
return nil
|
||||
}
|
||||
|
||||
ones, _ := network.Mask.Size()
|
||||
device.allowedips.Insert(network.IP, uint(ones), peer)
|
||||
device.allowedips.Insert(network.IP, uint(ones), peer.Peer)
|
||||
|
||||
case "protocol_version":
|
||||
|
||||
if value != "1" {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value)
|
||||
}
|
||||
|
@ -379,10 +375,8 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|||
default:
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (device *Device) IpcGet() (string, error) {
|
||||
|
|
Loading…
Reference in a new issue