device: split IpcSetOperation into parts

The goal of this change is to make the structure
of IpcSetOperation easier to follow.

IpcSetOperation contains a small state machine:
It starts by configuring the device,
then shifts to configuring one peer at a time.

Having the code all in one giant method obscured that structure.
Split out the parts into helper functions and encapsulate the peer state.

This makes the overall structure more apparent.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
This commit is contained in:
Josh Bleecher Snyder 2021-01-15 14:32:34 -08:00
parent a029b942ae
commit 6252de0db9

View file

@ -41,6 +41,8 @@ func ipcErrorf(code int64, msg string, args ...interface{}) *IPCError {
return &IPCError{code: code, err: fmt.Errorf(msg, args...)} return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
} }
// IpcGetOperation implements the WireGuard configuration protocol "get" operation.
// See https://www.wireguard.com/xplatform/#configuration-protocol for details.
func (device *Device) IpcGetOperation(w io.Writer) error { func (device *Device) IpcGetOperation(w io.Writer) error {
lines := make([]string, 0, 100) lines := make([]string, 0, 100)
send := func(line string) { send := func(line string) {
@ -116,6 +118,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
return nil return nil
} }
// IpcSetOperation implements the WireGuard configuration protocol "set" operation.
// See https://www.wireguard.com/xplatform/#configuration-protocol for details.
func (device *Device) IpcSetOperation(r io.Reader) (err error) { func (device *Device) IpcSetOperation(r io.Reader) (err error) {
defer func() { defer func() {
if err != nil { if err != nil {
@ -123,20 +127,14 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
} }
}() }()
logDebug := device.log.Debug peer := new(ipcSetPeer)
var peer *Peer
dummy := false
createdNewPeer := false
deviceConfig := true deviceConfig := true
scanner := bufio.NewScanner(r) scanner := bufio.NewScanner(r)
for scanner.Scan() { for scanner.Scan() {
// parse line
line := scanner.Text() line := scanner.Text()
if line == "" { if line == "" {
// Blank line means terminate operation.
return nil return nil
} }
parts := strings.Split(line, "=") parts := strings.Split(line, "=")
@ -146,10 +144,34 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
key := parts[0] key := parts[0]
value := parts[1] value := parts[1]
/* device configuration */ if key == "public_key" {
if deviceConfig { if deviceConfig {
device.log.Debug.Println("UAPI: Transition to peer configuration")
deviceConfig = false
}
// Load/create the peer we are now configuring.
err := device.handlePublicKeyLine(peer, value)
if err != nil {
return err
}
continue
}
var err error
if deviceConfig {
err = device.handleDeviceLine(key, value)
} else {
err = device.handlePeerLine(peer, key, value)
}
if err != nil {
return err
}
}
return scanner.Err()
}
func (device *Device) handleDeviceLine(key, value string) error {
switch key { switch key {
case "private_key": case "private_key":
var sk NoisePrivateKey var sk NoisePrivateKey
@ -157,21 +179,17 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err)
} }
logDebug.Println("UAPI: Updating private key") device.log.Debug.Println("UAPI: Updating private key")
device.SetPrivateKey(sk) device.SetPrivateKey(sk)
case "listen_port": case "listen_port":
// parse port number
port, err := strconv.ParseUint(value, 10, 16) port, err := strconv.ParseUint(value, 10, 16)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err)
} }
// update port and rebind // update port and rebind
device.log.Debug.Println("UAPI: Updating listen port")
logDebug.Println("UAPI: Updating listen port")
device.net.Lock() device.net.Lock()
device.net.port = uint16(port) device.net.port = uint16(port)
@ -182,9 +200,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
} }
case "fwmark": case "fwmark":
// parse fwmark field // parse fwmark field
fwmark, err := func() (uint32, error) { fwmark, err := func() (uint32, error) {
if value == "" { if value == "" {
return 0, nil return 0, nil
@ -197,95 +213,90 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err)
} }
logDebug.Println("UAPI: Updating fwmark") device.log.Debug.Println("UAPI: Updating fwmark")
if err := device.BindSetMark(uint32(fwmark)); err != nil { if err := device.BindSetMark(uint32(fwmark)); err != nil {
return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err) return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err)
} }
case "public_key":
// switch to peer configuration
logDebug.Println("UAPI: Transition to peer configuration")
deviceConfig = false
case "replace_peers": case "replace_peers":
if value != "true" { if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value) return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
} }
logDebug.Println("UAPI: Removing all peers") device.log.Debug.Println("UAPI: Removing all peers")
device.RemoveAllPeers() device.RemoveAllPeers()
default: default:
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
} }
return nil
} }
/* peer configuration */ // An ipcSetPeer is the current state of an IPC set operation on a peer.
type ipcSetPeer struct {
*Peer // Peer is the current peer being operated on
dummy bool // dummy reports whether this peer is a temporary, placeholder peer
created bool // new reports whether this is a newly created peer
}
if !deviceConfig { func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error {
// Load/create the peer we are configuring.
switch key {
case "public_key":
var publicKey NoisePublicKey var publicKey NoisePublicKey
err := publicKey.FromHex(value) err := publicKey.FromHex(value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err)
} }
// ignore peer with public key of device // Ignore peer with the same public key as this device.
device.staticIdentity.RLock() device.staticIdentity.RLock()
dummy = device.staticIdentity.publicKey.Equals(publicKey) peer.dummy = device.staticIdentity.publicKey.Equals(publicKey)
device.staticIdentity.RUnlock() device.staticIdentity.RUnlock()
if dummy { if peer.dummy {
peer = &Peer{} peer.Peer = &Peer{}
} else { } else {
peer = device.LookupPeer(publicKey) peer.Peer = device.LookupPeer(publicKey)
} }
createdNewPeer = peer == nil peer.created = peer.Peer == nil
if createdNewPeer { if peer.created {
peer, err = device.NewPeer(publicKey) peer.Peer, err = device.NewPeer(publicKey)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err)
} }
logDebug.Println(peer, "- UAPI: Created") device.log.Debug.Println(peer, "- UAPI: Created")
}
return nil
} }
func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error {
switch key {
case "update_only": case "update_only":
// allow disabling of creation // allow disabling of creation
if value != "true" { if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value) return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
} }
if createdNewPeer && !dummy { if peer.created && !peer.dummy {
device.RemovePeer(peer.handshake.remoteStatic) device.RemovePeer(peer.handshake.remoteStatic)
peer = &Peer{} peer.Peer = &Peer{}
dummy = true peer.dummy = true
} }
case "remove": case "remove":
// remove currently selected peer from device // remove currently selected peer from device
if value != "true" { if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value) return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value)
} }
if !dummy { if !peer.dummy {
logDebug.Println(peer, "- UAPI: Removing") device.log.Debug.Println(peer, "- UAPI: Removing")
device.RemovePeer(peer.handshake.remoteStatic) device.RemovePeer(peer.handshake.remoteStatic)
} }
peer = &Peer{} peer.Peer = &Peer{}
dummy = true peer.dummy = true
case "preshared_key": case "preshared_key":
device.log.Debug.Println(peer, "- UAPI: Updating preshared key")
// update PSK
logDebug.Println(peer, "- UAPI: Updating preshared key")
peer.handshake.mutex.Lock() peer.handshake.mutex.Lock()
err := peer.handshake.presharedKey.FromHex(value) err := peer.handshake.presharedKey.FromHex(value)
@ -296,10 +307,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
} }
case "endpoint": case "endpoint":
device.log.Debug.Println(peer, "- UAPI: Updating endpoint")
// set endpoint destination
logDebug.Println(peer, "- UAPI: Updating endpoint")
err := func() error { err := func() error {
peer.Lock() peer.Lock()
@ -317,10 +325,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
} }
case "persistent_keepalive_interval": case "persistent_keepalive_interval":
device.log.Debug.Println(peer, "- UAPI: Updating persistent keepalive interval")
// update persistent keepalive interval
logDebug.Println(peer, "- UAPI: Updating persistent keepalive interval")
secs, err := strconv.ParseUint(value, 10, 16) secs, err := strconv.ParseUint(value, 10, 16)
if err != nil { if err != nil {
@ -329,49 +334,40 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs)) old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs))
// send immediate keepalive if we're turning it on and before it wasn't on // Send immediate keepalive if we're turning it on and before it wasn't on.
if old == 0 && secs != 0 { if old == 0 && secs != 0 {
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorIO, "failed to get tun device status: %w", err) return ipcErrorf(ipc.IpcErrorIO, "failed to get tun device status: %w", err)
} }
if device.isUp.Get() && !dummy { if device.isUp.Get() && !peer.dummy {
peer.SendKeepalive() peer.SendKeepalive()
} }
} }
case "replace_allowed_ips": case "replace_allowed_ips":
device.log.Debug.Println(peer, "- UAPI: Removing all allowedips")
logDebug.Println(peer, "- UAPI: Removing all allowedips")
if value != "true" { if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value) return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
} }
if peer.dummy {
if dummy { return nil
continue
} }
device.allowedips.RemoveByPeer(peer.Peer)
device.allowedips.RemoveByPeer(peer)
case "allowed_ip": case "allowed_ip":
device.log.Debug.Println(peer, "- UAPI: Adding allowedip")
logDebug.Println(peer, "- UAPI: Adding allowedip")
_, network, err := net.ParseCIDR(value) _, network, err := net.ParseCIDR(value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
} }
if peer.dummy {
if dummy { return nil
continue
} }
ones, _ := network.Mask.Size() ones, _ := network.Mask.Size()
device.allowedips.Insert(network.IP, uint(ones), peer) device.allowedips.Insert(network.IP, uint(ones), peer.Peer)
case "protocol_version": case "protocol_version":
if value != "1" { if value != "1" {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value) return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value)
} }
@ -379,10 +375,8 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
default: default:
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key) return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key)
} }
}
}
return scanner.Err() return nil
} }
func (device *Device) IpcGet() (string, error) { func (device *Device) IpcGet() (string, error) {