Beginning work on UAPI and routing table
This commit is contained in:
parent
6bd0b2fbe2
commit
1eebdf88a3
190
src/config.go
Normal file
190
src/config.go
Normal file
|
@ -0,0 +1,190 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* todo : use real error code
|
||||||
|
* Many of which will be the same
|
||||||
|
*/
|
||||||
|
const (
|
||||||
|
ipcErrorNoPeer = 0
|
||||||
|
ipcErrorNoKeyValue = 1
|
||||||
|
ipcErrorInvalidKey = 2
|
||||||
|
ipcErrorInvalidPrivateKey = 3
|
||||||
|
ipcErrorInvalidPublicKey = 4
|
||||||
|
ipcErrorInvalidPort = 5
|
||||||
|
)
|
||||||
|
|
||||||
|
type IPCError struct {
|
||||||
|
Code int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *IPCError) Error() string {
|
||||||
|
return fmt.Sprintf("IPC error: %d", s.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *IPCError) ErrorCode() int {
|
||||||
|
return s.Code
|
||||||
|
}
|
||||||
|
|
||||||
|
// Writes the configuration to the socket
|
||||||
|
func ipcGetOperation(socket *bufio.ReadWriter, dev *Device) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates new config, from old and socket message
|
||||||
|
func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(socket)
|
||||||
|
|
||||||
|
dev.mutex.Lock()
|
||||||
|
defer dev.mutex.Unlock()
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
var key string
|
||||||
|
var value string
|
||||||
|
var peer *Peer
|
||||||
|
|
||||||
|
// Parse line
|
||||||
|
|
||||||
|
line := scanner.Text()
|
||||||
|
if line == "\n" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
fmt.Println(line)
|
||||||
|
n, err := fmt.Sscanf(line, "%s=%s\n", &key, &value)
|
||||||
|
if n != 2 || err != nil {
|
||||||
|
fmt.Println(err, n)
|
||||||
|
return &IPCError{Code: ipcErrorNoKeyValue}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch key {
|
||||||
|
|
||||||
|
/* Interface configuration */
|
||||||
|
|
||||||
|
case "private_key":
|
||||||
|
if value == "" {
|
||||||
|
dev.privateKey = NoisePrivateKey{}
|
||||||
|
} else {
|
||||||
|
err := dev.privateKey.FromHex(value)
|
||||||
|
if err != nil {
|
||||||
|
return &IPCError{Code: ipcErrorInvalidPrivateKey}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case "listen_port":
|
||||||
|
_, err := fmt.Sscanf(value, "%ud", &dev.listenPort)
|
||||||
|
if err != nil {
|
||||||
|
return &IPCError{Code: ipcErrorInvalidPort}
|
||||||
|
}
|
||||||
|
|
||||||
|
case "fwmark":
|
||||||
|
panic(nil) // not handled yet
|
||||||
|
|
||||||
|
case "public_key":
|
||||||
|
var pubKey NoisePublicKey
|
||||||
|
err := pubKey.FromHex(value)
|
||||||
|
if err != nil {
|
||||||
|
return &IPCError{Code: ipcErrorInvalidPublicKey}
|
||||||
|
}
|
||||||
|
found, ok := dev.peers[pubKey]
|
||||||
|
if ok {
|
||||||
|
peer = found
|
||||||
|
} else {
|
||||||
|
newPeer := &Peer{
|
||||||
|
publicKey: pubKey,
|
||||||
|
}
|
||||||
|
peer = newPeer
|
||||||
|
dev.peers[pubKey] = newPeer
|
||||||
|
}
|
||||||
|
|
||||||
|
case "replace_peers":
|
||||||
|
|
||||||
|
default:
|
||||||
|
/* Peer configuration */
|
||||||
|
|
||||||
|
if peer == nil {
|
||||||
|
return &IPCError{Code: ipcErrorNoPeer}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch key {
|
||||||
|
|
||||||
|
case "remove":
|
||||||
|
peer.mutex.Lock()
|
||||||
|
|
||||||
|
peer = nil
|
||||||
|
|
||||||
|
case "preshared_key":
|
||||||
|
func() {
|
||||||
|
peer.mutex.Lock()
|
||||||
|
defer peer.mutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
case "endpoint":
|
||||||
|
func() {
|
||||||
|
peer.mutex.Lock()
|
||||||
|
defer peer.mutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
case "persistent_keepalive_interval":
|
||||||
|
func() {
|
||||||
|
peer.mutex.Lock()
|
||||||
|
defer peer.mutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
case "replace_allowed_ips":
|
||||||
|
// remove peer from trie
|
||||||
|
|
||||||
|
case "allowed_ip":
|
||||||
|
|
||||||
|
/* Invalid key */
|
||||||
|
|
||||||
|
default:
|
||||||
|
return &IPCError{Code: ipcErrorInvalidKey}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ipcListen(dev *Device, socket io.ReadWriter) error {
|
||||||
|
|
||||||
|
buffered := func(s io.ReadWriter) *bufio.ReadWriter {
|
||||||
|
reader := bufio.NewReader(s)
|
||||||
|
writer := bufio.NewWriter(s)
|
||||||
|
return bufio.NewReadWriter(reader, writer)
|
||||||
|
}(socket)
|
||||||
|
|
||||||
|
for {
|
||||||
|
op, err := buffered.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Println(op)
|
||||||
|
|
||||||
|
switch op {
|
||||||
|
|
||||||
|
case "set=1\n":
|
||||||
|
err := ipcSetOperation(dev, buffered)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(buffered, "errno=%d\n", err.ErrorCode())
|
||||||
|
return err
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(buffered, "errno=0\n")
|
||||||
|
}
|
||||||
|
buffered.Flush()
|
||||||
|
|
||||||
|
case "get=1\n":
|
||||||
|
|
||||||
|
default:
|
||||||
|
return errors.New("handle this please")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
14
src/device.go
Normal file
14
src/device.go
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Device struct {
|
||||||
|
mutex sync.RWMutex
|
||||||
|
peers map[NoisePublicKey]*Peer
|
||||||
|
privateKey NoisePrivateKey
|
||||||
|
publicKey NoisePublicKey
|
||||||
|
fwMark uint32
|
||||||
|
listenPort uint16
|
||||||
|
}
|
28
src/main.go
Normal file
28
src/main.go
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("listen error:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
fd, err := l.Accept()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("accept error:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var dev Device
|
||||||
|
go func(conn net.Conn) {
|
||||||
|
err := ipcListen(&dev, conn)
|
||||||
|
fmt.Println(err)
|
||||||
|
}(fd)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
8
src/misc.go
Normal file
8
src/misc.go
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
func min(a uint, b uint) uint {
|
||||||
|
if a > b {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
return a
|
||||||
|
}
|
51
src/noise.go
Normal file
51
src/noise.go
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
NoisePublicKeySize = 32
|
||||||
|
NoisePrivateKeySize = 32
|
||||||
|
NoiseSymmetricKeySize = 32
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
NoisePublicKey [NoisePublicKeySize]byte
|
||||||
|
NoisePrivateKey [NoisePrivateKeySize]byte
|
||||||
|
NoiseSymmetricKey [NoiseSymmetricKeySize]byte
|
||||||
|
NoiseNonce uint64 // padded to 12-bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
func (key *NoisePrivateKey) FromHex(s string) error {
|
||||||
|
slice, err := hex.DecodeString(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(slice) != NoisePrivateKeySize {
|
||||||
|
return errors.New("Invalid length of hex string for curve25519 point")
|
||||||
|
}
|
||||||
|
copy(key[:], slice)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (key *NoisePrivateKey) ToHex() string {
|
||||||
|
return hex.EncodeToString(key[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (key *NoisePublicKey) FromHex(s string) error {
|
||||||
|
slice, err := hex.DecodeString(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(slice) != NoisePublicKeySize {
|
||||||
|
return errors.New("Invalid length of hex string for curve25519 scalar")
|
||||||
|
}
|
||||||
|
copy(key[:], slice)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (key *NoisePublicKey) ToHex() string {
|
||||||
|
return hex.EncodeToString(key[:])
|
||||||
|
}
|
18
src/peer.go
Normal file
18
src/peer.go
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type KeyPair struct {
|
||||||
|
recieveKey NoiseSymmetricKey
|
||||||
|
recieveNonce NoiseNonce
|
||||||
|
sendKey NoiseSymmetricKey
|
||||||
|
sendNonce NoiseNonce
|
||||||
|
}
|
||||||
|
|
||||||
|
type Peer struct {
|
||||||
|
mutex sync.RWMutex
|
||||||
|
publicKey NoisePublicKey
|
||||||
|
presharedKey NoiseSymmetricKey
|
||||||
|
}
|
154
src/trie.go
Normal file
154
src/trie.go
Normal file
|
@ -0,0 +1,154 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
/* Syncronization must be done seperatly
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
type Trie struct {
|
||||||
|
cidr uint
|
||||||
|
child [2]*Trie
|
||||||
|
bits []byte
|
||||||
|
peer *Peer
|
||||||
|
|
||||||
|
// Index of "branching" bit
|
||||||
|
// bit_at_shift
|
||||||
|
bit_at_byte uint
|
||||||
|
bit_at_shift uint
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Finds length of matching prefix
|
||||||
|
* Maybe there is a faster way
|
||||||
|
*
|
||||||
|
* Assumption: len(s1) == len(s2)
|
||||||
|
*/
|
||||||
|
func commonBits(s1 []byte, s2 []byte) uint {
|
||||||
|
var i uint
|
||||||
|
size := uint(len(s1))
|
||||||
|
for i = 0; i < size; i += 1 {
|
||||||
|
v := s1[i] ^ s2[i]
|
||||||
|
if v != 0 {
|
||||||
|
v >>= 1
|
||||||
|
if v == 0 {
|
||||||
|
return i*8 + 7
|
||||||
|
}
|
||||||
|
|
||||||
|
v >>= 1
|
||||||
|
if v == 0 {
|
||||||
|
return i*8 + 6
|
||||||
|
}
|
||||||
|
|
||||||
|
v >>= 1
|
||||||
|
if v == 0 {
|
||||||
|
return i*8 + 5
|
||||||
|
}
|
||||||
|
|
||||||
|
v >>= 1
|
||||||
|
if v == 0 {
|
||||||
|
return i*8 + 4
|
||||||
|
}
|
||||||
|
|
||||||
|
v >>= 1
|
||||||
|
if v == 0 {
|
||||||
|
return i*8 + 3
|
||||||
|
}
|
||||||
|
|
||||||
|
v >>= 1
|
||||||
|
if v == 0 {
|
||||||
|
return i*8 + 2
|
||||||
|
}
|
||||||
|
|
||||||
|
v >>= 1
|
||||||
|
if v == 0 {
|
||||||
|
return i*8 + 1
|
||||||
|
}
|
||||||
|
return i * 8
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return i * 8
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *Trie) RemovePeer(p *Peer) *Trie {
|
||||||
|
if node == nil {
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
// Walk recursivly
|
||||||
|
|
||||||
|
node.child[0] = node.child[0].RemovePeer(p)
|
||||||
|
node.child[1] = node.child[1].RemovePeer(p)
|
||||||
|
|
||||||
|
if node.peer != p {
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove peer & merge
|
||||||
|
|
||||||
|
node.peer = nil
|
||||||
|
if node.child[0] == nil {
|
||||||
|
return node.child[1]
|
||||||
|
}
|
||||||
|
return node.child[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie {
|
||||||
|
if node == nil {
|
||||||
|
return &Trie{
|
||||||
|
bits: key,
|
||||||
|
peer: peer,
|
||||||
|
cidr: cidr,
|
||||||
|
bit_at_byte: cidr / 8,
|
||||||
|
bit_at_shift: 7 - (cidr % 8),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Traverse deeper
|
||||||
|
|
||||||
|
common := commonBits(node.bits, key)
|
||||||
|
if node.cidr <= cidr && common >= node.cidr {
|
||||||
|
// Check if match the t.bits[:t.cidr] exactly
|
||||||
|
if node.cidr == cidr {
|
||||||
|
node.peer = peer
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
// Go to child
|
||||||
|
bit := (key[node.bit_at_byte] >> node.bit_at_shift) & 1
|
||||||
|
node.child[bit] = node.child[bit].Insert(key, cidr, peer)
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split node
|
||||||
|
|
||||||
|
fmt.Println("new", common)
|
||||||
|
|
||||||
|
newNode := &Trie{
|
||||||
|
bits: key,
|
||||||
|
peer: peer,
|
||||||
|
cidr: cidr,
|
||||||
|
bit_at_byte: cidr / 8,
|
||||||
|
bit_at_shift: 7 - (cidr % 8),
|
||||||
|
}
|
||||||
|
|
||||||
|
cidr = min(cidr, common)
|
||||||
|
node.cidr = cidr
|
||||||
|
node.bit_at_byte = cidr / 8
|
||||||
|
node.bit_at_shift = 7 - (cidr % 8)
|
||||||
|
|
||||||
|
// bval := node.bits[node.bit_at_byte] >> node.bit_at_shift // todo : remember index
|
||||||
|
// Work in progress
|
||||||
|
node.child[0] = newNode
|
||||||
|
node.child[1] = newNode
|
||||||
|
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Trie) Lookup(key []byte) *Peer {
|
||||||
|
if t == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
66
src/trie_test.go
Normal file
66
src/trie_test.go
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testPairCommonBits struct {
|
||||||
|
s1 []byte
|
||||||
|
s2 []byte
|
||||||
|
match uint
|
||||||
|
}
|
||||||
|
|
||||||
|
type testPairTrieInsert struct {
|
||||||
|
key []byte
|
||||||
|
cidr uint
|
||||||
|
peer *Peer
|
||||||
|
}
|
||||||
|
|
||||||
|
func printTrie(t *testing.T, p *Trie) {
|
||||||
|
if p == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Log(p)
|
||||||
|
printTrie(t, p.child[0])
|
||||||
|
printTrie(t, p.child[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCommonBits(t *testing.T) {
|
||||||
|
|
||||||
|
tests := []testPairCommonBits{
|
||||||
|
{s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
|
||||||
|
{s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
|
||||||
|
{s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31},
|
||||||
|
{s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15},
|
||||||
|
{s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range tests {
|
||||||
|
v := commonBits(p.s1, p.s2)
|
||||||
|
if v != p.match {
|
||||||
|
t.Error(
|
||||||
|
"For slice", p.s1, p.s2,
|
||||||
|
"expected match", p.match,
|
||||||
|
"got", v,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTrieInsertV4(t *testing.T) {
|
||||||
|
var trie *Trie
|
||||||
|
|
||||||
|
peer1 := Peer{}
|
||||||
|
peer2 := Peer{}
|
||||||
|
|
||||||
|
tests := []testPairTrieInsert{
|
||||||
|
{key: []byte{192, 168, 1, 1}, cidr: 24, peer: &peer1},
|
||||||
|
{key: []byte{192, 169, 1, 1}, cidr: 24, peer: &peer2},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range tests {
|
||||||
|
trie = trie.Insert(p.key, p.cidr, p.peer)
|
||||||
|
printTrie(t, trie)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in a new issue