Beginning work on UAPI and routing table
This commit is contained in:
parent
6bd0b2fbe2
commit
1eebdf88a3
190
src/config.go
Normal file
190
src/config.go
Normal file
|
@ -0,0 +1,190 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
)
|
||||
|
||||
/* todo : use real error code
|
||||
* Many of which will be the same
|
||||
*/
|
||||
const (
|
||||
ipcErrorNoPeer = 0
|
||||
ipcErrorNoKeyValue = 1
|
||||
ipcErrorInvalidKey = 2
|
||||
ipcErrorInvalidPrivateKey = 3
|
||||
ipcErrorInvalidPublicKey = 4
|
||||
ipcErrorInvalidPort = 5
|
||||
)
|
||||
|
||||
type IPCError struct {
|
||||
Code int
|
||||
}
|
||||
|
||||
func (s *IPCError) Error() string {
|
||||
return fmt.Sprintf("IPC error: %d", s.Code)
|
||||
}
|
||||
|
||||
func (s *IPCError) ErrorCode() int {
|
||||
return s.Code
|
||||
}
|
||||
|
||||
// Writes the configuration to the socket
|
||||
func ipcGetOperation(socket *bufio.ReadWriter, dev *Device) {
|
||||
|
||||
}
|
||||
|
||||
// Creates new config, from old and socket message
|
||||
func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
|
||||
|
||||
scanner := bufio.NewScanner(socket)
|
||||
|
||||
dev.mutex.Lock()
|
||||
defer dev.mutex.Unlock()
|
||||
|
||||
for scanner.Scan() {
|
||||
var key string
|
||||
var value string
|
||||
var peer *Peer
|
||||
|
||||
// Parse line
|
||||
|
||||
line := scanner.Text()
|
||||
if line == "\n" {
|
||||
break
|
||||
}
|
||||
fmt.Println(line)
|
||||
n, err := fmt.Sscanf(line, "%s=%s\n", &key, &value)
|
||||
if n != 2 || err != nil {
|
||||
fmt.Println(err, n)
|
||||
return &IPCError{Code: ipcErrorNoKeyValue}
|
||||
}
|
||||
|
||||
switch key {
|
||||
|
||||
/* Interface configuration */
|
||||
|
||||
case "private_key":
|
||||
if value == "" {
|
||||
dev.privateKey = NoisePrivateKey{}
|
||||
} else {
|
||||
err := dev.privateKey.FromHex(value)
|
||||
if err != nil {
|
||||
return &IPCError{Code: ipcErrorInvalidPrivateKey}
|
||||
}
|
||||
}
|
||||
|
||||
case "listen_port":
|
||||
_, err := fmt.Sscanf(value, "%ud", &dev.listenPort)
|
||||
if err != nil {
|
||||
return &IPCError{Code: ipcErrorInvalidPort}
|
||||
}
|
||||
|
||||
case "fwmark":
|
||||
panic(nil) // not handled yet
|
||||
|
||||
case "public_key":
|
||||
var pubKey NoisePublicKey
|
||||
err := pubKey.FromHex(value)
|
||||
if err != nil {
|
||||
return &IPCError{Code: ipcErrorInvalidPublicKey}
|
||||
}
|
||||
found, ok := dev.peers[pubKey]
|
||||
if ok {
|
||||
peer = found
|
||||
} else {
|
||||
newPeer := &Peer{
|
||||
publicKey: pubKey,
|
||||
}
|
||||
peer = newPeer
|
||||
dev.peers[pubKey] = newPeer
|
||||
}
|
||||
|
||||
case "replace_peers":
|
||||
|
||||
default:
|
||||
/* Peer configuration */
|
||||
|
||||
if peer == nil {
|
||||
return &IPCError{Code: ipcErrorNoPeer}
|
||||
}
|
||||
|
||||
switch key {
|
||||
|
||||
case "remove":
|
||||
peer.mutex.Lock()
|
||||
|
||||
peer = nil
|
||||
|
||||
case "preshared_key":
|
||||
func() {
|
||||
peer.mutex.Lock()
|
||||
defer peer.mutex.Unlock()
|
||||
}()
|
||||
|
||||
case "endpoint":
|
||||
func() {
|
||||
peer.mutex.Lock()
|
||||
defer peer.mutex.Unlock()
|
||||
}()
|
||||
|
||||
case "persistent_keepalive_interval":
|
||||
func() {
|
||||
peer.mutex.Lock()
|
||||
defer peer.mutex.Unlock()
|
||||
}()
|
||||
|
||||
case "replace_allowed_ips":
|
||||
// remove peer from trie
|
||||
|
||||
case "allowed_ip":
|
||||
|
||||
/* Invalid key */
|
||||
|
||||
default:
|
||||
return &IPCError{Code: ipcErrorInvalidKey}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ipcListen(dev *Device, socket io.ReadWriter) error {
|
||||
|
||||
buffered := func(s io.ReadWriter) *bufio.ReadWriter {
|
||||
reader := bufio.NewReader(s)
|
||||
writer := bufio.NewWriter(s)
|
||||
return bufio.NewReadWriter(reader, writer)
|
||||
}(socket)
|
||||
|
||||
for {
|
||||
op, err := buffered.ReadString('\n')
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Println(op)
|
||||
|
||||
switch op {
|
||||
|
||||
case "set=1\n":
|
||||
err := ipcSetOperation(dev, buffered)
|
||||
if err != nil {
|
||||
fmt.Fprintf(buffered, "errno=%d\n", err.ErrorCode())
|
||||
return err
|
||||
} else {
|
||||
fmt.Fprintf(buffered, "errno=0\n")
|
||||
}
|
||||
buffered.Flush()
|
||||
|
||||
case "get=1\n":
|
||||
|
||||
default:
|
||||
return errors.New("handle this please")
|
||||
}
|
||||
}
|
||||
|
||||
}
|
14
src/device.go
Normal file
14
src/device.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Device struct {
|
||||
mutex sync.RWMutex
|
||||
peers map[NoisePublicKey]*Peer
|
||||
privateKey NoisePrivateKey
|
||||
publicKey NoisePublicKey
|
||||
fwMark uint32
|
||||
listenPort uint16
|
||||
}
|
28
src/main.go
Normal file
28
src/main.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
)
|
||||
|
||||
func main() {
|
||||
l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock")
|
||||
if err != nil {
|
||||
log.Fatal("listen error:", err)
|
||||
}
|
||||
|
||||
for {
|
||||
fd, err := l.Accept()
|
||||
if err != nil {
|
||||
log.Fatal("accept error:", err)
|
||||
}
|
||||
|
||||
var dev Device
|
||||
go func(conn net.Conn) {
|
||||
err := ipcListen(&dev, conn)
|
||||
fmt.Println(err)
|
||||
}(fd)
|
||||
}
|
||||
|
||||
}
|
8
src/misc.go
Normal file
8
src/misc.go
Normal file
|
@ -0,0 +1,8 @@
|
|||
package main
|
||||
|
||||
func min(a uint, b uint) uint {
|
||||
if a > b {
|
||||
return b
|
||||
}
|
||||
return a
|
||||
}
|
51
src/noise.go
Normal file
51
src/noise.go
Normal file
|
@ -0,0 +1,51 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
)
|
||||
|
||||
const (
|
||||
NoisePublicKeySize = 32
|
||||
NoisePrivateKeySize = 32
|
||||
NoiseSymmetricKeySize = 32
|
||||
)
|
||||
|
||||
type (
|
||||
NoisePublicKey [NoisePublicKeySize]byte
|
||||
NoisePrivateKey [NoisePrivateKeySize]byte
|
||||
NoiseSymmetricKey [NoiseSymmetricKeySize]byte
|
||||
NoiseNonce uint64 // padded to 12-bytes
|
||||
)
|
||||
|
||||
func (key *NoisePrivateKey) FromHex(s string) error {
|
||||
slice, err := hex.DecodeString(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(slice) != NoisePrivateKeySize {
|
||||
return errors.New("Invalid length of hex string for curve25519 point")
|
||||
}
|
||||
copy(key[:], slice)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (key *NoisePrivateKey) ToHex() string {
|
||||
return hex.EncodeToString(key[:])
|
||||
}
|
||||
|
||||
func (key *NoisePublicKey) FromHex(s string) error {
|
||||
slice, err := hex.DecodeString(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(slice) != NoisePublicKeySize {
|
||||
return errors.New("Invalid length of hex string for curve25519 scalar")
|
||||
}
|
||||
copy(key[:], slice)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (key *NoisePublicKey) ToHex() string {
|
||||
return hex.EncodeToString(key[:])
|
||||
}
|
18
src/peer.go
Normal file
18
src/peer.go
Normal file
|
@ -0,0 +1,18 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type KeyPair struct {
|
||||
recieveKey NoiseSymmetricKey
|
||||
recieveNonce NoiseNonce
|
||||
sendKey NoiseSymmetricKey
|
||||
sendNonce NoiseNonce
|
||||
}
|
||||
|
||||
type Peer struct {
|
||||
mutex sync.RWMutex
|
||||
publicKey NoisePublicKey
|
||||
presharedKey NoiseSymmetricKey
|
||||
}
|
154
src/trie.go
Normal file
154
src/trie.go
Normal file
|
@ -0,0 +1,154 @@
|
|||
package main
|
||||
|
||||
import "fmt"
|
||||
|
||||
/* Syncronization must be done seperatly
|
||||
*
|
||||
*/
|
||||
|
||||
type Trie struct {
|
||||
cidr uint
|
||||
child [2]*Trie
|
||||
bits []byte
|
||||
peer *Peer
|
||||
|
||||
// Index of "branching" bit
|
||||
// bit_at_shift
|
||||
bit_at_byte uint
|
||||
bit_at_shift uint
|
||||
}
|
||||
|
||||
/* Finds length of matching prefix
|
||||
* Maybe there is a faster way
|
||||
*
|
||||
* Assumption: len(s1) == len(s2)
|
||||
*/
|
||||
func commonBits(s1 []byte, s2 []byte) uint {
|
||||
var i uint
|
||||
size := uint(len(s1))
|
||||
for i = 0; i < size; i += 1 {
|
||||
v := s1[i] ^ s2[i]
|
||||
if v != 0 {
|
||||
v >>= 1
|
||||
if v == 0 {
|
||||
return i*8 + 7
|
||||
}
|
||||
|
||||
v >>= 1
|
||||
if v == 0 {
|
||||
return i*8 + 6
|
||||
}
|
||||
|
||||
v >>= 1
|
||||
if v == 0 {
|
||||
return i*8 + 5
|
||||
}
|
||||
|
||||
v >>= 1
|
||||
if v == 0 {
|
||||
return i*8 + 4
|
||||
}
|
||||
|
||||
v >>= 1
|
||||
if v == 0 {
|
||||
return i*8 + 3
|
||||
}
|
||||
|
||||
v >>= 1
|
||||
if v == 0 {
|
||||
return i*8 + 2
|
||||
}
|
||||
|
||||
v >>= 1
|
||||
if v == 0 {
|
||||
return i*8 + 1
|
||||
}
|
||||
return i * 8
|
||||
}
|
||||
}
|
||||
return i * 8
|
||||
}
|
||||
|
||||
func (node *Trie) RemovePeer(p *Peer) *Trie {
|
||||
if node == nil {
|
||||
return node
|
||||
}
|
||||
|
||||
// Walk recursivly
|
||||
|
||||
node.child[0] = node.child[0].RemovePeer(p)
|
||||
node.child[1] = node.child[1].RemovePeer(p)
|
||||
|
||||
if node.peer != p {
|
||||
return node
|
||||
}
|
||||
|
||||
// Remove peer & merge
|
||||
|
||||
node.peer = nil
|
||||
if node.child[0] == nil {
|
||||
return node.child[1]
|
||||
}
|
||||
return node.child[0]
|
||||
}
|
||||
|
||||
func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie {
|
||||
if node == nil {
|
||||
return &Trie{
|
||||
bits: key,
|
||||
peer: peer,
|
||||
cidr: cidr,
|
||||
bit_at_byte: cidr / 8,
|
||||
bit_at_shift: 7 - (cidr % 8),
|
||||
}
|
||||
}
|
||||
|
||||
// Traverse deeper
|
||||
|
||||
common := commonBits(node.bits, key)
|
||||
if node.cidr <= cidr && common >= node.cidr {
|
||||
// Check if match the t.bits[:t.cidr] exactly
|
||||
if node.cidr == cidr {
|
||||
node.peer = peer
|
||||
return node
|
||||
}
|
||||
|
||||
// Go to child
|
||||
bit := (key[node.bit_at_byte] >> node.bit_at_shift) & 1
|
||||
node.child[bit] = node.child[bit].Insert(key, cidr, peer)
|
||||
return node
|
||||
}
|
||||
|
||||
// Split node
|
||||
|
||||
fmt.Println("new", common)
|
||||
|
||||
newNode := &Trie{
|
||||
bits: key,
|
||||
peer: peer,
|
||||
cidr: cidr,
|
||||
bit_at_byte: cidr / 8,
|
||||
bit_at_shift: 7 - (cidr % 8),
|
||||
}
|
||||
|
||||
cidr = min(cidr, common)
|
||||
node.cidr = cidr
|
||||
node.bit_at_byte = cidr / 8
|
||||
node.bit_at_shift = 7 - (cidr % 8)
|
||||
|
||||
// bval := node.bits[node.bit_at_byte] >> node.bit_at_shift // todo : remember index
|
||||
// Work in progress
|
||||
node.child[0] = newNode
|
||||
node.child[1] = newNode
|
||||
|
||||
return node
|
||||
}
|
||||
|
||||
func (t *Trie) Lookup(key []byte) *Peer {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
66
src/trie_test.go
Normal file
66
src/trie_test.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
type testPairCommonBits struct {
|
||||
s1 []byte
|
||||
s2 []byte
|
||||
match uint
|
||||
}
|
||||
|
||||
type testPairTrieInsert struct {
|
||||
key []byte
|
||||
cidr uint
|
||||
peer *Peer
|
||||
}
|
||||
|
||||
func printTrie(t *testing.T, p *Trie) {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
t.Log(p)
|
||||
printTrie(t, p.child[0])
|
||||
printTrie(t, p.child[1])
|
||||
}
|
||||
|
||||
func TestCommonBits(t *testing.T) {
|
||||
|
||||
tests := []testPairCommonBits{
|
||||
{s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
|
||||
{s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
|
||||
{s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31},
|
||||
{s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15},
|
||||
{s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0},
|
||||
}
|
||||
|
||||
for _, p := range tests {
|
||||
v := commonBits(p.s1, p.s2)
|
||||
if v != p.match {
|
||||
t.Error(
|
||||
"For slice", p.s1, p.s2,
|
||||
"expected match", p.match,
|
||||
"got", v,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrieInsertV4(t *testing.T) {
|
||||
var trie *Trie
|
||||
|
||||
peer1 := Peer{}
|
||||
peer2 := Peer{}
|
||||
|
||||
tests := []testPairTrieInsert{
|
||||
{key: []byte{192, 168, 1, 1}, cidr: 24, peer: &peer1},
|
||||
{key: []byte{192, 169, 1, 1}, cidr: 24, peer: &peer2},
|
||||
}
|
||||
|
||||
for _, p := range tests {
|
||||
trie = trie.Insert(p.key, p.cidr, p.peer)
|
||||
printTrie(t, trie)
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in a new issue