Completed get/set configuration

For debugging of "outbound flow"
Mostly, a few things still missing
This commit is contained in:
Mathias Hall-Andersen 2017-06-29 14:39:21 +02:00
parent 1f0976a26c
commit 7e185db141
6 changed files with 109 additions and 80 deletions

View file

@ -5,24 +5,22 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"net" "net"
"strconv" "strconv"
"strings"
"time" "time"
) )
/* TODO : use real error code // #include <errno.h>
* Many of which will be the same import "C"
/* TODO: More fine grained?
*/ */
const ( const (
ipcErrorNoPeer = 0 ipcErrorNoPeer = C.EPROTO
ipcErrorNoKeyValue = 1 ipcErrorNoKeyValue = C.EPROTO
ipcErrorInvalidKey = 2 ipcErrorInvalidKey = C.EPROTO
ipcErrorInvalidValue = 2 ipcErrorInvalidValue = C.EPROTO
ipcErrorInvalidPrivateKey = 3
ipcErrorInvalidPublicKey = 4
ipcErrorInvalidPort = 5
ipcErrorInvalidIPAddress = 6
) )
type IPCError struct { type IPCError struct {
@ -78,7 +76,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
// send lines // send lines
for _, line := range lines { for _, line := range lines {
device.log.Debug.Println("config:", line) device.log.Debug.Println("Response:", line)
_, err := socket.WriteString(line + "\n") _, err := socket.WriteString(line + "\n")
if err != nil { if err != nil {
return err return err
@ -89,29 +87,26 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
} }
func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logger := device.log.Debug
scanner := bufio.NewScanner(socket) scanner := bufio.NewScanner(socket)
device.mutex.Lock() var peer *Peer
defer device.mutex.Unlock()
for scanner.Scan() { for scanner.Scan() {
var key string
var value string
var peer *Peer
// Parse line // Parse line
line := scanner.Text() line := scanner.Text()
if line == "\n" { if line == "" {
break return nil
} }
fmt.Println(line) parts := strings.Split(line, "=")
n, err := fmt.Sscanf(line, "%s=%s\n", &key, &value) if len(parts) != 2 {
if n != 2 || err != nil { device.log.Debug.Println(parts)
fmt.Println(err, n)
return &IPCError{Code: ipcErrorNoKeyValue} return &IPCError{Code: ipcErrorNoKeyValue}
} }
key := parts[0]
value := parts[1]
logger.Println("Key-value pair: (", key, ",", value, ")") // TODO: Remove, leaks private key to log
switch key { switch key {
@ -119,41 +114,60 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "private_key": case "private_key":
if value == "" { if value == "" {
device.mutex.Lock()
device.privateKey = NoisePrivateKey{} device.privateKey = NoisePrivateKey{}
device.mutex.Unlock()
} else { } else {
device.mutex.Lock()
err := device.privateKey.FromHex(value) err := device.privateKey.FromHex(value)
device.mutex.Unlock()
if err != nil { if err != nil {
return &IPCError{Code: ipcErrorInvalidPrivateKey} logger.Println("Failed to set private_key:", err)
return &IPCError{Code: ipcErrorInvalidValue}
} }
} }
case "listen_port": case "listen_port":
_, err := fmt.Sscanf(value, "%ud", &device.address.Port) var port int
if err != nil { _, err := fmt.Sscanf(value, "%d", &port)
return &IPCError{Code: ipcErrorInvalidPort} if err != nil || port > (1<<16) || port < 0 {
logger.Println("Failed to set listen_port:", err)
return &IPCError{Code: ipcErrorInvalidValue}
} }
device.mutex.Lock()
if device.address == nil {
device.address = &net.UDPAddr{}
}
device.address.Port = port
device.mutex.Unlock()
case "fwmark": case "fwmark":
panic(nil) // not handled yet logger.Println("FWMark not handled yet")
case "public_key": case "public_key":
var pubKey NoisePublicKey var pubKey NoisePublicKey
err := pubKey.FromHex(value) err := pubKey.FromHex(value)
if err != nil { if err != nil {
return &IPCError{Code: ipcErrorInvalidPublicKey} logger.Println("Failed to get peer by public_key:", err)
return &IPCError{Code: ipcErrorInvalidValue}
} }
device.mutex.RLock()
found, ok := device.peers[pubKey] found, ok := device.peers[pubKey]
device.mutex.RUnlock()
if ok { if ok {
peer = found peer = found
} else { } else {
peer = device.NewPeer(pubKey) peer = device.NewPeer(pubKey)
} }
if peer == nil {
panic(errors.New("bug: failed to find peer"))
}
case "replace_peers": case "replace_peers":
if key == "true" { if value == "true" {
device.RemoveAllPeers() device.RemoveAllPeers()
} else if key == "false" {
} else { } else {
logger.Println("Failed to set replace_peers, invalid value:", value)
return &IPCError{Code: ipcErrorInvalidValue} return &IPCError{Code: ipcErrorInvalidValue}
} }
@ -161,6 +175,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
/* Peer configuration */ /* Peer configuration */
if peer == nil { if peer == nil {
logger.Println("No peer referenced, before peer operation")
return &IPCError{Code: ipcErrorNoPeer} return &IPCError{Code: ipcErrorNoPeer}
} }
@ -168,7 +183,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "remove": case "remove":
peer.mutex.Lock() peer.mutex.Lock()
// device.RemovePeer(peer.publicKey) device.RemovePeer(peer.handshake.remoteStatic)
peer.mutex.Unlock()
logger.Println("Remove peer")
peer = nil peer = nil
case "preshared_key": case "preshared_key":
@ -178,13 +195,15 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return peer.handshake.presharedKey.FromHex(value) return peer.handshake.presharedKey.FromHex(value)
}() }()
if err != nil { if err != nil {
return &IPCError{Code: ipcErrorInvalidPublicKey} logger.Println("Failed to set preshared_key:", err)
return &IPCError{Code: ipcErrorInvalidValue}
} }
case "endpoint": case "endpoint":
ip := net.ParseIP(value) ip := net.ParseIP(value)
if ip == nil { if ip == nil {
return &IPCError{Code: ipcErrorInvalidIPAddress} logger.Println("Failed to set endpoint:", value)
return &IPCError{Code: ipcErrorInvalidValue}
} }
peer.mutex.Lock() peer.mutex.Lock()
// peer.endpoint = ip FIX // peer.endpoint = ip FIX
@ -193,6 +212,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "persistent_keepalive_interval": case "persistent_keepalive_interval":
secs, err := strconv.ParseInt(value, 10, 64) secs, err := strconv.ParseInt(value, 10, 64)
if secs < 0 || err != nil { if secs < 0 || err != nil {
logger.Println("Failed to set persistent_keepalive_interval:", err)
return &IPCError{Code: ipcErrorInvalidValue} return &IPCError{Code: ipcErrorInvalidValue}
} }
peer.mutex.Lock() peer.mutex.Lock()
@ -200,24 +220,27 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
peer.mutex.Unlock() peer.mutex.Unlock()
case "replace_allowed_ips": case "replace_allowed_ips":
if key == "true" { if value == "true" {
device.routingTable.RemovePeer(peer) device.routingTable.RemovePeer(peer)
} else if key == "false" {
} else { } else {
logger.Println("Failed to set replace_allowed_ips, invalid value:", value)
return &IPCError{Code: ipcErrorInvalidValue} return &IPCError{Code: ipcErrorInvalidValue}
} }
case "allowed_ip": case "allowed_ip":
_, network, err := net.ParseCIDR(value) _, network, err := net.ParseCIDR(value)
if err != nil { if err != nil {
logger.Println("Failed to set allowed_ip:", err)
return &IPCError{Code: ipcErrorInvalidValue} return &IPCError{Code: ipcErrorInvalidValue}
} }
ones, _ := network.Mask.Size() ones, _ := network.Mask.Size()
logger.Println(network, ones, network.IP)
device.routingTable.Insert(network.IP, uint(ones), peer) device.routingTable.Insert(network.IP, uint(ones), peer)
/* Invalid key */ /* Invalid key */
default: default:
logger.Println("Invalid key:", key)
return &IPCError{Code: ipcErrorInvalidKey} return &IPCError{Code: ipcErrorInvalidKey}
} }
} }
@ -226,49 +249,48 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return nil return nil
} }
func ipcListen(device *Device, socket io.ReadWriter) error { func ipcHandle(device *Device, socket net.Conn) {
buffered := func(s io.ReadWriter) *bufio.ReadWriter { func() {
reader := bufio.NewReader(s) buffered := func(s io.ReadWriter) *bufio.ReadWriter {
writer := bufio.NewWriter(s) reader := bufio.NewReader(s)
return bufio.NewReadWriter(reader, writer) writer := bufio.NewWriter(s)
}(socket) return bufio.NewReadWriter(reader, writer)
}(socket)
defer buffered.Flush() defer buffered.Flush()
for {
op, err := buffered.ReadString('\n') op, err := buffered.ReadString('\n')
if err != nil { if err != nil {
return err return
} }
log.Println(op)
switch op { switch op {
case "set=1\n": case "set=1\n":
device.log.Debug.Println("Config, set operation")
err := ipcSetOperation(device, buffered) err := ipcSetOperation(device, buffered)
if err != nil { if err != nil {
fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode()) fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
return err
} else { } else {
fmt.Fprintf(buffered, "errno=0\n\n") fmt.Fprintf(buffered, "errno=0\n\n")
} }
buffered.Flush() break
case "get=1\n": case "get=1\n":
device.log.Debug.Println("Config, get operation")
err := ipcGetOperation(device, buffered) err := ipcGetOperation(device, buffered)
if err != nil { if err != nil {
fmt.Fprintf(buffered, "errno=1\n\n") // fix fmt.Fprintf(buffered, "errno=1\n\n") // fix
return err
} else { } else {
fmt.Fprintf(buffered, "errno=0\n\n") fmt.Fprintf(buffered, "errno=0\n\n")
} }
buffered.Flush() break
case "\n":
default: default:
return errors.New("handle this please") device.log.Info.Println("Invalid UAPI operation:", op)
} }
} }()
socket.Close()
} }

View file

@ -81,10 +81,7 @@ func (device *Device) RemovePeer(key NoisePublicKey) {
peer.mutex.Lock() peer.mutex.Lock()
device.routingTable.RemovePeer(peer) device.routingTable.RemovePeer(peer)
delete(device.peers, key) delete(device.peers, key)
} peer.Close()
func (device *Device) RemoveAllAllowedIps(peer *Peer) {
} }
func (device *Device) RemoveAllPeers() { func (device *Device) RemoveAllPeers() {
@ -93,8 +90,7 @@ func (device *Device) RemoveAllPeers() {
for key, peer := range device.peers { for key, peer := range device.peers {
peer.mutex.Lock() peer.mutex.Lock()
device.routingTable.RemovePeer(peer)
delete(device.peers, key) delete(device.peers, key)
peer.mutex.Unlock() peer.Close()
} }
} }

View file

@ -1,21 +1,28 @@
package main package main
import ( import (
"fmt"
"log" "log"
"net" "net"
"os"
) )
/* /* TODO: Fix logging
* * TODO: Fix daemon
* TODO: Fix logging
*/ */
func main() { func main() {
if len(os.Args) != 2 {
return
}
deviceName := os.Args[1]
// Open TUN device // Open TUN device
// TODO: Fix capabilities // TODO: Fix capabilities
tun, err := CreateTUN("test0") tun, err := CreateTUN(deviceName)
log.Println(tun, err) log.Println(tun, err)
if err != nil { if err != nil {
return return
@ -25,19 +32,17 @@ func main() {
// Start configuration lister // Start configuration lister
l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock") socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", deviceName)
l, err := net.Listen("unix", socketPath)
if err != nil { if err != nil {
log.Fatal("listen error:", err) log.Fatal("listen error:", err)
} }
for { for {
fd, err := l.Accept() conn, err := l.Accept()
if err != nil { if err != nil {
log.Fatal("accept error:", err) log.Fatal("accept error:", err)
} }
go func(conn net.Conn) { go ipcHandle(device, conn)
err := ipcListen(device, conn)
log.Println(err)
}(fd)
} }
} }

View file

@ -16,9 +16,9 @@ func (table *RoutingTable) AllowedIPs(peer *Peer) []net.IPNet {
table.mutex.RLock() table.mutex.RLock()
defer table.mutex.RUnlock() defer table.mutex.RUnlock()
allowed := make([]net.IPNet, 10) allowed := make([]net.IPNet, 0, 10)
table.IPv4.AllowedIPs(peer, allowed) allowed = table.IPv4.AllowedIPs(peer, allowed)
table.IPv6.AllowedIPs(peer, allowed) allowed = table.IPv6.AllowedIPs(peer, allowed)
return allowed return allowed
} }

View file

@ -61,9 +61,11 @@ func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) {
* Obs. Single instance per TUN device * Obs. Single instance per TUN device
*/ */
func (device *Device) RoutineReadFromTUN(tun TUNDevice) { func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
device.log.Debug.Println("Routine, TUN Reader: started")
for { for {
// read packet // read packet
device.log.Debug.Println("Read")
packet := make([]byte, 1<<16) // TODO: Fix & avoid dynamic allocation packet := make([]byte, 1<<16) // TODO: Fix & avoid dynamic allocation
size, err := tun.Read(packet) size, err := tun.Read(packet)
if err != nil { if err != nil {
@ -76,8 +78,6 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
continue continue
} }
device.log.Debug.Println("New packet on TUN:", packet) // TODO: Slow debugging, remove.
// lookup peer // lookup peer
var peer *Peer var peer *Peer
@ -85,10 +85,12 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
case IPv4version: case IPv4version:
dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.routingTable.LookupIPv4(dst) peer = device.routingTable.LookupIPv4(dst)
device.log.Debug.Println("New IPv4 packet:", packet, dst)
case IPv6version: case IPv6version:
dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.routingTable.LookupIPv6(dst) peer = device.routingTable.LookupIPv6(dst)
device.log.Debug.Println("New IPv6 packet:", packet, dst)
default: default:
device.log.Debug.Println("Receieved packet with unknown IP version") device.log.Debug.Println("Receieved packet with unknown IP version")
@ -97,7 +99,7 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
if peer == nil { if peer == nil {
device.log.Debug.Println("No peer configured for IP") device.log.Debug.Println("No peer configured for IP")
return continue
} }
// insert into nonce/pre-handshake queue // insert into nonce/pre-handshake queue

View file

@ -195,7 +195,10 @@ func (node *Trie) Count() uint {
return l + r return l + r
} }
func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) { func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet {
if node == nil {
return results
}
if node.peer == p { if node.peer == p {
var mask net.IPNet var mask net.IPNet
mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8) mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8)
@ -213,6 +216,7 @@ func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) {
} }
results = append(results, mask) results = append(results, mask)
} }
node.child[0].AllowedIPs(p, results) results = node.child[0].AllowedIPs(p, results)
node.child[1].AllowedIPs(p, results) results = node.child[1].AllowedIPs(p, results)
return results
} }