device: split IpcSetOperation into parts

The goal of this change is to make the structure
of IpcSetOperation easier to follow.

IpcSetOperation contains a small state machine:
It starts by configuring the device,
then shifts to configuring one peer at a time.

Having the code all in one giant method obscured that structure.
Split out the parts into helper functions and encapsulate the peer state.

This makes the overall structure more apparent.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
This commit is contained in:
Josh Bleecher Snyder 2021-01-15 14:32:34 -08:00
parent a029b942ae
commit 6252de0db9

View file

@ -41,6 +41,8 @@ func ipcErrorf(code int64, msg string, args ...interface{}) *IPCError {
return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
}
// IpcGetOperation implements the WireGuard configuration protocol "get" operation.
// See https://www.wireguard.com/xplatform/#configuration-protocol for details.
func (device *Device) IpcGetOperation(w io.Writer) error {
lines := make([]string, 0, 100)
send := func(line string) {
@ -116,6 +118,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
return nil
}
// IpcSetOperation implements the WireGuard configuration protocol "set" operation.
// See https://www.wireguard.com/xplatform/#configuration-protocol for details.
func (device *Device) IpcSetOperation(r io.Reader) (err error) {
defer func() {
if err != nil {
@ -123,20 +127,14 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
}
}()
logDebug := device.log.Debug
var peer *Peer
dummy := false
createdNewPeer := false
peer := new(ipcSetPeer)
deviceConfig := true
scanner := bufio.NewScanner(r)
for scanner.Scan() {
// parse line
line := scanner.Text()
if line == "" {
// Blank line means terminate operation.
return nil
}
parts := strings.Split(line, "=")
@ -146,10 +144,34 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
key := parts[0]
value := parts[1]
/* device configuration */
if key == "public_key" {
if deviceConfig {
device.log.Debug.Println("UAPI: Transition to peer configuration")
deviceConfig = false
}
// Load/create the peer we are now configuring.
err := device.handlePublicKeyLine(peer, value)
if err != nil {
return err
}
continue
}
var err error
if deviceConfig {
err = device.handleDeviceLine(key, value)
} else {
err = device.handlePeerLine(peer, key, value)
}
if err != nil {
return err
}
}
return scanner.Err()
}
func (device *Device) handleDeviceLine(key, value string) error {
switch key {
case "private_key":
var sk NoisePrivateKey
@ -157,21 +179,17 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err)
}
logDebug.Println("UAPI: Updating private key")
device.log.Debug.Println("UAPI: Updating private key")
device.SetPrivateKey(sk)
case "listen_port":
// parse port number
port, err := strconv.ParseUint(value, 10, 16)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err)
}
// update port and rebind
logDebug.Println("UAPI: Updating listen port")
device.log.Debug.Println("UAPI: Updating listen port")
device.net.Lock()
device.net.port = uint16(port)
@ -182,9 +200,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
}
case "fwmark":
// parse fwmark field
fwmark, err := func() (uint32, error) {
if value == "" {
return 0, nil
@ -197,95 +213,90 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err)
}
logDebug.Println("UAPI: Updating fwmark")
device.log.Debug.Println("UAPI: Updating fwmark")
if err := device.BindSetMark(uint32(fwmark)); err != nil {
return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err)
}
case "public_key":
// switch to peer configuration
logDebug.Println("UAPI: Transition to peer configuration")
deviceConfig = false
case "replace_peers":
if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
}
logDebug.Println("UAPI: Removing all peers")
device.log.Debug.Println("UAPI: Removing all peers")
device.RemoveAllPeers()
default:
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
}
}
/* peer configuration */
return nil
}
if !deviceConfig {
// An ipcSetPeer is the current state of an IPC set operation on a peer.
type ipcSetPeer struct {
*Peer // Peer is the current peer being operated on
dummy bool // dummy reports whether this peer is a temporary, placeholder peer
created bool // new reports whether this is a newly created peer
}
switch key {
case "public_key":
func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error {
// Load/create the peer we are configuring.
var publicKey NoisePublicKey
err := publicKey.FromHex(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err)
}
// ignore peer with public key of device
// Ignore peer with the same public key as this device.
device.staticIdentity.RLock()
dummy = device.staticIdentity.publicKey.Equals(publicKey)
peer.dummy = device.staticIdentity.publicKey.Equals(publicKey)
device.staticIdentity.RUnlock()
if dummy {
peer = &Peer{}
if peer.dummy {
peer.Peer = &Peer{}
} else {
peer = device.LookupPeer(publicKey)
peer.Peer = device.LookupPeer(publicKey)
}
createdNewPeer = peer == nil
if createdNewPeer {
peer, err = device.NewPeer(publicKey)
peer.created = peer.Peer == nil
if peer.created {
peer.Peer, err = device.NewPeer(publicKey)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err)
}
logDebug.Println(peer, "- UAPI: Created")
device.log.Debug.Println(peer, "- UAPI: Created")
}
return nil
}
func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error {
switch key {
case "update_only":
// allow disabling of creation
if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
}
if createdNewPeer && !dummy {
if peer.created && !peer.dummy {
device.RemovePeer(peer.handshake.remoteStatic)
peer = &Peer{}
dummy = true
peer.Peer = &Peer{}
peer.dummy = true
}
case "remove":
// remove currently selected peer from device
if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value)
}
if !dummy {
logDebug.Println(peer, "- UAPI: Removing")
if !peer.dummy {
device.log.Debug.Println(peer, "- UAPI: Removing")
device.RemovePeer(peer.handshake.remoteStatic)
}
peer = &Peer{}
dummy = true
peer.Peer = &Peer{}
peer.dummy = true
case "preshared_key":
// update PSK
logDebug.Println(peer, "- UAPI: Updating preshared key")
device.log.Debug.Println(peer, "- UAPI: Updating preshared key")
peer.handshake.mutex.Lock()
err := peer.handshake.presharedKey.FromHex(value)
@ -296,10 +307,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
}
case "endpoint":
// set endpoint destination
logDebug.Println(peer, "- UAPI: Updating endpoint")
device.log.Debug.Println(peer, "- UAPI: Updating endpoint")
err := func() error {
peer.Lock()
@ -317,10 +325,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
}
case "persistent_keepalive_interval":
// update persistent keepalive interval
logDebug.Println(peer, "- UAPI: Updating persistent keepalive interval")
device.log.Debug.Println(peer, "- UAPI: Updating persistent keepalive interval")
secs, err := strconv.ParseUint(value, 10, 16)
if err != nil {
@ -329,49 +334,40 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs))
// send immediate keepalive if we're turning it on and before it wasn't on
// Send immediate keepalive if we're turning it on and before it wasn't on.
if old == 0 && secs != 0 {
if err != nil {
return ipcErrorf(ipc.IpcErrorIO, "failed to get tun device status: %w", err)
}
if device.isUp.Get() && !dummy {
if device.isUp.Get() && !peer.dummy {
peer.SendKeepalive()
}
}
case "replace_allowed_ips":
logDebug.Println(peer, "- UAPI: Removing all allowedips")
device.log.Debug.Println(peer, "- UAPI: Removing all allowedips")
if value != "true" {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
}
if dummy {
continue
if peer.dummy {
return nil
}
device.allowedips.RemoveByPeer(peer)
device.allowedips.RemoveByPeer(peer.Peer)
case "allowed_ip":
logDebug.Println(peer, "- UAPI: Adding allowedip")
device.log.Debug.Println(peer, "- UAPI: Adding allowedip")
_, network, err := net.ParseCIDR(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
}
if dummy {
continue
if peer.dummy {
return nil
}
ones, _ := network.Mask.Size()
device.allowedips.Insert(network.IP, uint(ones), peer)
device.allowedips.Insert(network.IP, uint(ones), peer.Peer)
case "protocol_version":
if value != "1" {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value)
}
@ -379,10 +375,8 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
default:
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key)
}
}
}
return scanner.Err()
return nil
}
func (device *Device) IpcGet() (string, error) {