wireguard-go/src/config.go

297 lines
6.6 KiB
Go
Raw Normal View History

package main
import (
"bufio"
"errors"
"fmt"
"io"
2017-06-01 19:31:30 +00:00
"net"
"strconv"
"strings"
"time"
)
// #include <errno.h>
import "C"
// Error codes sent back to UAPI clients in the "errno=N" reply line.
// Every failure kind currently maps to the same value, EPROTO from
// <errno.h>, so clients cannot tell them apart.
// TODO: More fine grained?
const (
	ipcErrorNoKeyValue   = C.EPROTO
	ipcErrorNoPeer       = C.EPROTO
	ipcErrorInvalidKey   = C.EPROTO
	ipcErrorInvalidValue = C.EPROTO
)
// IPCError is the error returned by a failed UAPI set operation. Code is
// sent back to the client verbatim in the "errno=<Code>" reply line.
type IPCError struct {
	Code int
}

// Error implements the error interface, e.g. "IPC error: 71".
func (e *IPCError) Error() string {
	// Sprint inserts no space after a string operand, so this matches
	// "IPC error: %d".
	return fmt.Sprint("IPC error: ", e.Code)
}

// ErrorCode returns the numeric code to report to the UAPI client.
func (e *IPCError) ErrorCode() int {
	return e.Code
}
// ipcGetOperation writes the device's current configuration to socket as
// UAPI "key=value" lines, one per line. Interface settings come first:
// private_key (only when set) and listen_port (only when an address is
// configured). They are followed by one section per peer, each started by
// its public_key line.
//
// The lines are collected while holding device.mutex and each peer's mutex
// for reading. The locks are released before anything is written, so a
// slow client does not hold up configuration changes. persistent_keepalive_interval
// is reported in whole seconds, the unit ipcSetOperation accepts. The first
// write error is returned; the caller is responsible for flushing socket.
func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
	lines := make([]string, 0, 100)
	send := func(line string) {
		lines = append(lines, line)
	}

	// Snapshot the configuration under the read locks only.
	func() {
		device.mutex.RLock()
		defer device.mutex.RUnlock()

		if !device.privateKey.IsZero() {
			send("private_key=" + device.privateKey.ToHex())
		}
		if device.address != nil {
			send(fmt.Sprintf("listen_port=%d", device.address.Port))
		}

		for _, peer := range device.peers {
			func() {
				peer.mutex.RLock()
				defer peer.mutex.RUnlock()

				send("public_key=" + peer.handshake.remoteStatic.ToHex())
				send("preshared_key=" + peer.handshake.presharedKey.ToHex())
				if peer.endpoint != nil {
					send("endpoint=" + peer.endpoint.String())
				}
				send(fmt.Sprintf("tx_bytes=%d", peer.tx_bytes))
				send(fmt.Sprintf("rx_bytes=%d", peer.rx_bytes))
				// Stored as a time.Duration (nanoseconds) but configured in
				// seconds, so convert back before reporting.
				send(fmt.Sprintf("persistent_keepalive_interval=%d",
					int64(peer.persistentKeepaliveInterval/time.Second)))
				for _, ip := range device.routingTable.AllowedIPs(peer) {
					send("allowed_ip=" + ip.String())
				}
			}()
		}
	}()

	// Send the snapshot. Key material is never written to the debug log.
	for _, line := range lines {
		logged := line
		for _, secret := range []string{"private_key=", "preshared_key="} {
			if strings.HasPrefix(line, secret) {
				logged = secret + "(redacted)"
			}
		}
		device.log.Debug.Println("Response:", logged)

		if _, err := socket.WriteString(line + "\n"); err != nil {
			return err
		}
	}
	return nil
}
func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logger := device.log.Debug
scanner := bufio.NewScanner(socket)
var peer *Peer
for scanner.Scan() {
// Parse line
line := scanner.Text()
if line == "" {
return nil
}
parts := strings.Split(line, "=")
if len(parts) != 2 {
device.log.Debug.Println(parts)
return &IPCError{Code: ipcErrorNoKeyValue}
}
key := parts[0]
value := parts[1]
logger.Println("Key-value pair: (", key, ",", value, ")") // TODO: Remove, leaks private key to log
switch key {
/* Interface configuration */
case "private_key":
if value == "" {
device.mutex.Lock()
device.privateKey = NoisePrivateKey{}
device.mutex.Unlock()
} else {
device.mutex.Lock()
err := device.privateKey.FromHex(value)
device.mutex.Unlock()
if err != nil {
logger.Println("Failed to set private_key:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
}
case "listen_port":
var port int
_, err := fmt.Sscanf(value, "%d", &port)
if err != nil || port > (1<<16) || port < 0 {
logger.Println("Failed to set listen_port:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
device.mutex.Lock()
if device.address == nil {
device.address = &net.UDPAddr{}
}
device.address.Port = port
device.mutex.Unlock()
case "fwmark":
logger.Println("FWMark not handled yet")
case "public_key":
var pubKey NoisePublicKey
err := pubKey.FromHex(value)
if err != nil {
logger.Println("Failed to get peer by public_key:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
device.mutex.RLock()
found, ok := device.peers[pubKey]
device.mutex.RUnlock()
if ok {
peer = found
} else {
2017-06-24 13:34:17 +00:00
peer = device.NewPeer(pubKey)
}
if peer == nil {
panic(errors.New("bug: failed to find peer"))
}
case "replace_peers":
if value == "true" {
device.RemoveAllPeers()
} else {
logger.Println("Failed to set replace_peers, invalid value:", value)
return &IPCError{Code: ipcErrorInvalidValue}
2017-06-01 19:31:30 +00:00
}
default:
/* Peer configuration */
if peer == nil {
logger.Println("No peer referenced, before peer operation")
return &IPCError{Code: ipcErrorNoPeer}
}
switch key {
case "remove":
peer.mutex.Lock()
device.RemovePeer(peer.handshake.remoteStatic)
peer.mutex.Unlock()
logger.Println("Remove peer")
peer = nil
case "preshared_key":
2017-06-01 19:31:30 +00:00
err := func() error {
peer.mutex.Lock()
defer peer.mutex.Unlock()
2017-06-24 13:34:17 +00:00
return peer.handshake.presharedKey.FromHex(value)
}()
2017-06-01 19:31:30 +00:00
if err != nil {
logger.Println("Failed to set preshared_key:", err)
return &IPCError{Code: ipcErrorInvalidValue}
2017-06-01 19:31:30 +00:00
}
case "endpoint":
2017-06-01 19:31:30 +00:00
ip := net.ParseIP(value)
if ip == nil {
logger.Println("Failed to set endpoint:", value)
return &IPCError{Code: ipcErrorInvalidValue}
2017-06-01 19:31:30 +00:00
}
peer.mutex.Lock()
2017-06-24 13:34:17 +00:00
// peer.endpoint = ip FIX
2017-06-01 19:31:30 +00:00
peer.mutex.Unlock()
case "persistent_keepalive_interval":
secs, err := strconv.ParseInt(value, 10, 64)
if secs < 0 || err != nil {
logger.Println("Failed to set persistent_keepalive_interval:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
peer.mutex.Lock()
peer.persistentKeepaliveInterval = time.Duration(secs) * time.Second
peer.mutex.Unlock()
case "replace_allowed_ips":
if value == "true" {
device.routingTable.RemovePeer(peer)
} else {
logger.Println("Failed to set replace_allowed_ips, invalid value:", value)
return &IPCError{Code: ipcErrorInvalidValue}
}
case "allowed_ip":
_, network, err := net.ParseCIDR(value)
if err != nil {
logger.Println("Failed to set allowed_ip:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
ones, _ := network.Mask.Size()
logger.Println(network, ones, network.IP)
device.routingTable.Insert(network.IP, uint(ones), peer)
/* Invalid key */
default:
logger.Println("Invalid key:", key)
return &IPCError{Code: ipcErrorInvalidKey}
}
}
}
return nil
}
// ipcHandle serves a single UAPI request on socket and then closes it.
//
// The first line selects the operation: "set=1" or "get=1". The reply ends
// with an "errno=N" line followed by a blank line. An unknown operation is
// logged at Info level and the connection is closed without a reply.
// Buffered output is flushed before the socket is closed.
func ipcHandle(device *Device, socket net.Conn) {
	serve := func() {
		var conn io.ReadWriter = socket
		rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
		defer rw.Flush()

		op, readErr := rw.ReadString('\n')
		if readErr != nil {
			return
		}

		switch op {
		case "set=1\n":
			device.log.Debug.Println("Config, set operation")
			if setErr := ipcSetOperation(device, rw); setErr != nil {
				fmt.Fprintf(rw, "errno=%d\n\n", setErr.ErrorCode())
				return
			}
			fmt.Fprintf(rw, "errno=0\n\n")

		case "get=1\n":
			device.log.Debug.Println("Config, get operation")
			if getErr := ipcGetOperation(device, rw); getErr != nil {
				fmt.Fprintf(rw, "errno=1\n\n") // fix
				return
			}
			fmt.Fprintf(rw, "errno=0\n\n")

		default:
			device.log.Info.Println("Invalid UAPI operation:", op)
		}
	}

	serve()
	socket.Close()
}