Completed get/set configuration

For debugging of "outbound flow"
Mostly, a few things still missing
This commit is contained in:
Mathias Hall-Andersen 2017-06-29 14:39:21 +02:00
parent 1f0976a26c
commit 7e185db141
6 changed files with 109 additions and 80 deletions

View file

@ -5,24 +5,22 @@ import (
"errors"
"fmt"
"io"
"log"
"net"
"strconv"
"strings"
"time"
)
/* TODO : use real error code
* Many of which will be the same
// #include <errno.h>
import "C"
/* TODO: More fine grained?
*/
const (
ipcErrorNoPeer = 0
ipcErrorNoKeyValue = 1
ipcErrorInvalidKey = 2
ipcErrorInvalidValue = 2
ipcErrorInvalidPrivateKey = 3
ipcErrorInvalidPublicKey = 4
ipcErrorInvalidPort = 5
ipcErrorInvalidIPAddress = 6
ipcErrorNoPeer = C.EPROTO
ipcErrorNoKeyValue = C.EPROTO
ipcErrorInvalidKey = C.EPROTO
ipcErrorInvalidValue = C.EPROTO
)
type IPCError struct {
@ -78,7 +76,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
// send lines
for _, line := range lines {
device.log.Debug.Println("config:", line)
device.log.Debug.Println("Response:", line)
_, err := socket.WriteString(line + "\n")
if err != nil {
return err
@ -89,29 +87,26 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
}
func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logger := device.log.Debug
scanner := bufio.NewScanner(socket)
device.mutex.Lock()
defer device.mutex.Unlock()
for scanner.Scan() {
var key string
var value string
var peer *Peer
for scanner.Scan() {
// Parse line
line := scanner.Text()
if line == "\n" {
break
if line == "" {
return nil
}
fmt.Println(line)
n, err := fmt.Sscanf(line, "%s=%s\n", &key, &value)
if n != 2 || err != nil {
fmt.Println(err, n)
parts := strings.Split(line, "=")
if len(parts) != 2 {
device.log.Debug.Println(parts)
return &IPCError{Code: ipcErrorNoKeyValue}
}
key := parts[0]
value := parts[1]
logger.Println("Key-value pair: (", key, ",", value, ")") // TODO: Remove, leaks private key to log
switch key {
@ -119,41 +114,60 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "private_key":
if value == "" {
device.mutex.Lock()
device.privateKey = NoisePrivateKey{}
device.mutex.Unlock()
} else {
device.mutex.Lock()
err := device.privateKey.FromHex(value)
device.mutex.Unlock()
if err != nil {
return &IPCError{Code: ipcErrorInvalidPrivateKey}
logger.Println("Failed to set private_key:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
}
case "listen_port":
_, err := fmt.Sscanf(value, "%ud", &device.address.Port)
if err != nil {
return &IPCError{Code: ipcErrorInvalidPort}
var port int
_, err := fmt.Sscanf(value, "%d", &port)
if err != nil || port > (1<<16) || port < 0 {
logger.Println("Failed to set listen_port:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
device.mutex.Lock()
if device.address == nil {
device.address = &net.UDPAddr{}
}
device.address.Port = port
device.mutex.Unlock()
case "fwmark":
panic(nil) // not handled yet
logger.Println("FWMark not handled yet")
case "public_key":
var pubKey NoisePublicKey
err := pubKey.FromHex(value)
if err != nil {
return &IPCError{Code: ipcErrorInvalidPublicKey}
logger.Println("Failed to get peer by public_key:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
device.mutex.RLock()
found, ok := device.peers[pubKey]
device.mutex.RUnlock()
if ok {
peer = found
} else {
peer = device.NewPeer(pubKey)
}
if peer == nil {
panic(errors.New("bug: failed to find peer"))
}
case "replace_peers":
if key == "true" {
if value == "true" {
device.RemoveAllPeers()
} else if key == "false" {
} else {
logger.Println("Failed to set replace_peers, invalid value:", value)
return &IPCError{Code: ipcErrorInvalidValue}
}
@ -161,6 +175,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
/* Peer configuration */
if peer == nil {
logger.Println("No peer referenced, before peer operation")
return &IPCError{Code: ipcErrorNoPeer}
}
@ -168,7 +183,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "remove":
peer.mutex.Lock()
// device.RemovePeer(peer.publicKey)
device.RemovePeer(peer.handshake.remoteStatic)
peer.mutex.Unlock()
logger.Println("Remove peer")
peer = nil
case "preshared_key":
@ -178,13 +195,15 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return peer.handshake.presharedKey.FromHex(value)
}()
if err != nil {
return &IPCError{Code: ipcErrorInvalidPublicKey}
logger.Println("Failed to set preshared_key:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
case "endpoint":
ip := net.ParseIP(value)
if ip == nil {
return &IPCError{Code: ipcErrorInvalidIPAddress}
logger.Println("Failed to set endpoint:", value)
return &IPCError{Code: ipcErrorInvalidValue}
}
peer.mutex.Lock()
// peer.endpoint = ip FIX
@ -193,6 +212,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "persistent_keepalive_interval":
secs, err := strconv.ParseInt(value, 10, 64)
if secs < 0 || err != nil {
logger.Println("Failed to set persistent_keepalive_interval:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
peer.mutex.Lock()
@ -200,24 +220,27 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
peer.mutex.Unlock()
case "replace_allowed_ips":
if key == "true" {
if value == "true" {
device.routingTable.RemovePeer(peer)
} else if key == "false" {
} else {
logger.Println("Failed to set replace_allowed_ips, invalid value:", value)
return &IPCError{Code: ipcErrorInvalidValue}
}
case "allowed_ip":
_, network, err := net.ParseCIDR(value)
if err != nil {
logger.Println("Failed to set allowed_ip:", err)
return &IPCError{Code: ipcErrorInvalidValue}
}
ones, _ := network.Mask.Size()
logger.Println(network, ones, network.IP)
device.routingTable.Insert(network.IP, uint(ones), peer)
/* Invalid key */
default:
logger.Println("Invalid key:", key)
return &IPCError{Code: ipcErrorInvalidKey}
}
}
@ -226,8 +249,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return nil
}
func ipcListen(device *Device, socket io.ReadWriter) error {
func ipcHandle(device *Device, socket net.Conn) {
func() {
buffered := func(s io.ReadWriter) *bufio.ReadWriter {
reader := bufio.NewReader(s)
writer := bufio.NewWriter(s)
@ -236,39 +260,37 @@ func ipcListen(device *Device, socket io.ReadWriter) error {
defer buffered.Flush()
for {
op, err := buffered.ReadString('\n')
if err != nil {
return err
return
}
log.Println(op)
switch op {
case "set=1\n":
device.log.Debug.Println("Config, set operation")
err := ipcSetOperation(device, buffered)
if err != nil {
fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
return err
} else {
fmt.Fprintf(buffered, "errno=0\n\n")
}
buffered.Flush()
break
case "get=1\n":
device.log.Debug.Println("Config, get operation")
err := ipcGetOperation(device, buffered)
if err != nil {
fmt.Fprintf(buffered, "errno=1\n\n") // fix
return err
} else {
fmt.Fprintf(buffered, "errno=0\n\n")
}
buffered.Flush()
break
case "\n":
default:
return errors.New("handle this please")
}
device.log.Info.Println("Invalid UAPI operation:", op)
}
}()
socket.Close()
}

View file

@ -81,10 +81,7 @@ func (device *Device) RemovePeer(key NoisePublicKey) {
peer.mutex.Lock()
device.routingTable.RemovePeer(peer)
delete(device.peers, key)
}
func (device *Device) RemoveAllAllowedIps(peer *Peer) {
peer.Close()
}
func (device *Device) RemoveAllPeers() {
@ -93,8 +90,7 @@ func (device *Device) RemoveAllPeers() {
for key, peer := range device.peers {
peer.mutex.Lock()
device.routingTable.RemovePeer(peer)
delete(device.peers, key)
peer.mutex.Unlock()
peer.Close()
}
}

View file

@ -1,21 +1,28 @@
package main
import (
"fmt"
"log"
"net"
"os"
)
/*
*
* TODO: Fix logging
/* TODO: Fix logging
* TODO: Fix daemon
*/
func main() {
if len(os.Args) != 2 {
return
}
deviceName := os.Args[1]
// Open TUN device
// TODO: Fix capabilities
tun, err := CreateTUN("test0")
tun, err := CreateTUN(deviceName)
log.Println(tun, err)
if err != nil {
return
@ -25,19 +32,17 @@ func main() {
// Start configuration lister
l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock")
socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", deviceName)
l, err := net.Listen("unix", socketPath)
if err != nil {
log.Fatal("listen error:", err)
}
for {
fd, err := l.Accept()
conn, err := l.Accept()
if err != nil {
log.Fatal("accept error:", err)
}
go func(conn net.Conn) {
err := ipcListen(device, conn)
log.Println(err)
}(fd)
go ipcHandle(device, conn)
}
}

View file

@ -16,9 +16,9 @@ func (table *RoutingTable) AllowedIPs(peer *Peer) []net.IPNet {
table.mutex.RLock()
defer table.mutex.RUnlock()
allowed := make([]net.IPNet, 10)
table.IPv4.AllowedIPs(peer, allowed)
table.IPv6.AllowedIPs(peer, allowed)
allowed := make([]net.IPNet, 0, 10)
allowed = table.IPv4.AllowedIPs(peer, allowed)
allowed = table.IPv6.AllowedIPs(peer, allowed)
return allowed
}

View file

@ -61,9 +61,11 @@ func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) {
* Obs. Single instance per TUN device
*/
func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
device.log.Debug.Println("Routine, TUN Reader: started")
for {
// read packet
device.log.Debug.Println("Read")
packet := make([]byte, 1<<16) // TODO: Fix & avoid dynamic allocation
size, err := tun.Read(packet)
if err != nil {
@ -76,8 +78,6 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
continue
}
device.log.Debug.Println("New packet on TUN:", packet) // TODO: Slow debugging, remove.
// lookup peer
var peer *Peer
@ -85,10 +85,12 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
case IPv4version:
dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.routingTable.LookupIPv4(dst)
device.log.Debug.Println("New IPv4 packet:", packet, dst)
case IPv6version:
dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.routingTable.LookupIPv6(dst)
device.log.Debug.Println("New IPv6 packet:", packet, dst)
default:
device.log.Debug.Println("Receieved packet with unknown IP version")
@ -97,7 +99,7 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
if peer == nil {
device.log.Debug.Println("No peer configured for IP")
return
continue
}
// insert into nonce/pre-handshake queue

View file

@ -195,7 +195,10 @@ func (node *Trie) Count() uint {
return l + r
}
func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) {
func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet {
if node == nil {
return results
}
if node.peer == p {
var mask net.IPNet
mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8)
@ -213,6 +216,7 @@ func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) {
}
results = append(results, mask)
}
node.child[0].AllowedIPs(p, results)
node.child[1].AllowedIPs(p, results)
results = node.child[0].AllowedIPs(p, results)
results = node.child[1].AllowedIPs(p, results)
return results
}