device: expand IPCError
Expand IPCError to contain a wrapped error, and add a helper to make constructing such errors easier. Add a defer-based "log on returned error" to IpcSetOperation. This lets us simplify all of the error return paths. Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
This commit is contained in:
parent
db3fa1409c
commit
a029b942ae
|
@ -21,15 +21,24 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type IPCError struct {
|
type IPCError struct {
|
||||||
int64
|
code int64 // error code
|
||||||
|
err error // underlying/wrapped error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s IPCError) Error() string {
|
func (s IPCError) Error() string {
|
||||||
return fmt.Sprintf("IPC error: %d", s.int64)
|
return fmt.Sprintf("IPC error %d: %v", s.code, s.err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s IPCError) Unwrap() error {
|
||||||
|
return s.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s IPCError) ErrorCode() int64 {
|
func (s IPCError) ErrorCode() int64 {
|
||||||
return s.int64
|
return s.code
|
||||||
|
}
|
||||||
|
|
||||||
|
func ipcErrorf(code int64, msg string, args ...interface{}) *IPCError {
|
||||||
|
return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) IpcGetOperation(w io.Writer) error {
|
func (device *Device) IpcGetOperation(w io.Writer) error {
|
||||||
|
@ -100,24 +109,28 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
_, err := io.WriteString(w, line+"\n")
|
_, err := io.WriteString(w, line+"\n")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &IPCError{ipc.IpcErrorIO}
|
return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) IpcSetOperation(r io.Reader) error {
|
func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
||||||
scanner := bufio.NewScanner(r)
|
defer func() {
|
||||||
logError := device.log.Error
|
if err != nil {
|
||||||
|
device.log.Error.Println(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
logDebug := device.log.Debug
|
logDebug := device.log.Debug
|
||||||
|
|
||||||
var peer *Peer
|
var peer *Peer
|
||||||
|
|
||||||
dummy := false
|
dummy := false
|
||||||
createdNewPeer := false
|
createdNewPeer := false
|
||||||
deviceConfig := true
|
deviceConfig := true
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(r)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
|
|
||||||
// parse line
|
// parse line
|
||||||
|
@ -128,7 +141,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
|
||||||
}
|
}
|
||||||
parts := strings.Split(line, "=")
|
parts := strings.Split(line, "=")
|
||||||
if len(parts) != 2 {
|
if len(parts) != 2 {
|
||||||
return &IPCError{ipc.IpcErrorProtocol}
|
return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q, found %d =-separated parts, want 2", line, len(parts))
|
||||||
}
|
}
|
||||||
key := parts[0]
|
key := parts[0]
|
||||||
value := parts[1]
|
value := parts[1]
|
||||||
|
@ -142,8 +155,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
|
||||||
var sk NoisePrivateKey
|
var sk NoisePrivateKey
|
||||||
err := sk.FromMaybeZeroHex(value)
|
err := sk.FromMaybeZeroHex(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to set private_key:", err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err)
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
}
|
||||||
logDebug.Println("UAPI: Updating private key")
|
logDebug.Println("UAPI: Updating private key")
|
||||||
device.SetPrivateKey(sk)
|
device.SetPrivateKey(sk)
|
||||||
|
@ -154,8 +166,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
|
||||||
|
|
||||||
port, err := strconv.ParseUint(value, 10, 16)
|
port, err := strconv.ParseUint(value, 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to parse listen_port:", err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err)
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// update port and rebind
|
// update port and rebind
|
||||||
|
@ -167,8 +178,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
|
||||||
device.net.Unlock()
|
device.net.Unlock()
|
||||||
|
|
||||||
if err := device.BindUpdate(); err != nil {
|
if err := device.BindUpdate(); err != nil {
|
||||||
logError.Println("Failed to set listen_port:", err)
|
return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err)
|
||||||
return &IPCError{ipc.IpcErrorPortInUse}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case "fwmark":
|
case "fwmark":
|
||||||
|
@ -184,15 +194,13 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Invalid fwmark", err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err)
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
logDebug.Println("UAPI: Updating fwmark")
|
logDebug.Println("UAPI: Updating fwmark")
|
||||||
|
|
||||||
if err := device.BindSetMark(uint32(fwmark)); err != nil {
|
if err := device.BindSetMark(uint32(fwmark)); err != nil {
|
||||||
logError.Println("Failed to update fwmark:", err)
|
return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err)
|
||||||
return &IPCError{ipc.IpcErrorPortInUse}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case "public_key":
|
case "public_key":
|
||||||
|
@ -202,15 +210,13 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
|
||||||
|
|
||||||
case "replace_peers":
|
case "replace_peers":
|
||||||
if value != "true" {
|
if value != "true" {
|
||||||
logError.Println("Failed to set replace_peers, invalid value:", value)
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
}
|
||||||
logDebug.Println("UAPI: Removing all peers")
|
logDebug.Println("UAPI: Removing all peers")
|
||||||
device.RemoveAllPeers()
|
device.RemoveAllPeers()
|
||||||
|
|
||||||
default:
|
default:
|
||||||
logError.Println("Invalid UAPI device key:", key)
|
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -224,8 +230,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
|
||||||
var publicKey NoisePublicKey
|
var publicKey NoisePublicKey
|
||||||
err := publicKey.FromHex(value)
|
err := publicKey.FromHex(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to get peer by public key:", err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err)
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ignore peer with public key of device
|
// ignore peer with public key of device
|
||||||
|
@ -244,8 +249,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
|
||||||
if createdNewPeer {
|
if createdNewPeer {
|
||||||
peer, err = device.NewPeer(publicKey)
|
peer, err = device.NewPeer(publicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to create new peer:", err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err)
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
}
|
||||||
logDebug.Println(peer, "- UAPI: Created")
|
logDebug.Println(peer, "- UAPI: Created")
|
||||||
}
|
}
|
||||||
|
@ -255,8 +259,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
|
||||||
// allow disabling of creation
|
// allow disabling of creation
|
||||||
|
|
||||||
if value != "true" {
|
if value != "true" {
|
||||||
logError.Println("Failed to set update only, invalid value:", value)
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
}
|
||||||
if createdNewPeer && !dummy {
|
if createdNewPeer && !dummy {
|
||||||
device.RemovePeer(peer.handshake.remoteStatic)
|
device.RemovePeer(peer.handshake.remoteStatic)
|
||||||
|
@ -269,8 +272,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
|
||||||
// remove currently selected peer from device
|
// remove currently selected peer from device
|
||||||
|
|
||||||
if value != "true" {
|
if value != "true" {
|
||||||
logError.Println("Failed to set remove, invalid value:", value)
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value)
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
}
|
||||||
if !dummy {
|
if !dummy {
|
||||||
logDebug.Println(peer, "- UAPI: Removing")
|
logDebug.Println(peer, "- UAPI: Removing")
|
||||||
|
@ -290,8 +292,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
|
||||||
peer.handshake.mutex.Unlock()
|
peer.handshake.mutex.Unlock()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to set preshared key:", err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err)
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case "endpoint":
|
case "endpoint":
|
||||||
|
@ -312,8 +313,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to set endpoint:", err, ":", value)
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case "persistent_keepalive_interval":
|
case "persistent_keepalive_interval":
|
||||||
|
@ -324,8 +324,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
|
||||||
|
|
||||||
secs, err := strconv.ParseUint(value, 10, 16)
|
secs, err := strconv.ParseUint(value, 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to set persistent keepalive interval:", err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs))
|
old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs))
|
||||||
|
@ -334,8 +333,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
|
||||||
|
|
||||||
if old == 0 && secs != 0 {
|
if old == 0 && secs != 0 {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to get tun device status:", err)
|
return ipcErrorf(ipc.IpcErrorIO, "failed to get tun device status: %w", err)
|
||||||
return &IPCError{ipc.IpcErrorIO}
|
|
||||||
}
|
}
|
||||||
if device.isUp.Get() && !dummy {
|
if device.isUp.Get() && !dummy {
|
||||||
peer.SendKeepalive()
|
peer.SendKeepalive()
|
||||||
|
@ -347,8 +345,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
|
||||||
logDebug.Println(peer, "- UAPI: Removing all allowedips")
|
logDebug.Println(peer, "- UAPI: Removing all allowedips")
|
||||||
|
|
||||||
if value != "true" {
|
if value != "true" {
|
||||||
logError.Println("Failed to replace allowedips, invalid value:", value)
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if dummy {
|
if dummy {
|
||||||
|
@ -363,8 +360,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
|
||||||
|
|
||||||
_, network, err := net.ParseCIDR(value)
|
_, network, err := net.ParseCIDR(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to set allowed ip:", err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if dummy {
|
if dummy {
|
||||||
|
@ -377,13 +373,11 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
|
||||||
case "protocol_version":
|
case "protocol_version":
|
||||||
|
|
||||||
if value != "1" {
|
if value != "1" {
|
||||||
logError.Println("Invalid protocol version:", value)
|
return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value)
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
logError.Println("Invalid UAPI peer key:", key)
|
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key)
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -431,16 +425,14 @@ func (device *Device) IpcHandle(socket net.Conn) {
|
||||||
err = device.IpcSetOperation(buffered.Reader)
|
err = device.IpcSetOperation(buffered.Reader)
|
||||||
if err != nil && !errors.As(err, &status) {
|
if err != nil && !errors.As(err, &status) {
|
||||||
// should never happen
|
// should never happen
|
||||||
device.log.Error.Println("Invalid UAPI error:", err)
|
status = ipcErrorf(1, "invalid UAPI error: %w", err)
|
||||||
status = &IPCError{1}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case "get=1\n":
|
case "get=1\n":
|
||||||
err = device.IpcGetOperation(buffered.Writer)
|
err = device.IpcGetOperation(buffered.Writer)
|
||||||
if err != nil && !errors.As(err, &status) {
|
if err != nil && !errors.As(err, &status) {
|
||||||
// should never happen
|
// should never happen
|
||||||
device.log.Error.Println("Invalid UAPI error:", err)
|
status = ipcErrorf(1, "invalid UAPI error: %w", err)
|
||||||
status = &IPCError{1}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
|
Loading…
Reference in a new issue