Beginning work on TUN interface

And outbound routing

I am not entirely convinced the use of net.IP is a good idea,
since the internal representation of net.IP is a byte slice
and all constructor functions in "net" return 16 byte slices
(padded for IPv4), while the use in this project uses 4 byte slices.
Which may be confusing.
This commit is contained in:
Mathias Hall-Andersen 2017-06-04 21:48:15 +02:00
parent dbc3ee3e9d
commit 1868d15914
9 changed files with 290 additions and 62 deletions

View file

@ -7,6 +7,8 @@ import (
"io" "io"
"log" "log"
"net" "net"
"strconv"
"time"
) )
/* todo : use real error code /* todo : use real error code
@ -16,6 +18,7 @@ const (
ipcErrorNoPeer = 0 ipcErrorNoPeer = 0
ipcErrorNoKeyValue = 1 ipcErrorNoKeyValue = 1
ipcErrorInvalidKey = 2 ipcErrorInvalidKey = 2
ipcErrorInvalidValue = 2
ipcErrorInvalidPrivateKey = 3 ipcErrorInvalidPrivateKey = 3
ipcErrorInvalidPublicKey = 4 ipcErrorInvalidPublicKey = 4
ipcErrorInvalidPort = 5 ipcErrorInvalidPort = 5
@ -34,18 +37,16 @@ func (s *IPCError) ErrorCode() int {
return s.Code return s.Code
} }
// Writes the configuration to the socket
func ipcGetOperation(socket *bufio.ReadWriter, dev *Device) { func ipcGetOperation(socket *bufio.ReadWriter, dev *Device) {
} }
// Creates new config, from old and socket message func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
scanner := bufio.NewScanner(socket) scanner := bufio.NewScanner(socket)
dev.mutex.Lock() device.mutex.Lock()
defer dev.mutex.Unlock() defer device.mutex.Unlock()
for scanner.Scan() { for scanner.Scan() {
var key string var key string
@ -71,16 +72,16 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
case "private_key": case "private_key":
if value == "" { if value == "" {
dev.privateKey = NoisePrivateKey{} device.privateKey = NoisePrivateKey{}
} else { } else {
err := dev.privateKey.FromHex(value) err := device.privateKey.FromHex(value)
if err != nil { if err != nil {
return &IPCError{Code: ipcErrorInvalidPrivateKey} return &IPCError{Code: ipcErrorInvalidPrivateKey}
} }
} }
case "listen_port": case "listen_port":
_, err := fmt.Sscanf(value, "%ud", &dev.listenPort) _, err := fmt.Sscanf(value, "%ud", &device.listenPort)
if err != nil { if err != nil {
return &IPCError{Code: ipcErrorInvalidPort} return &IPCError{Code: ipcErrorInvalidPort}
} }
@ -94,7 +95,7 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
if err != nil { if err != nil {
return &IPCError{Code: ipcErrorInvalidPublicKey} return &IPCError{Code: ipcErrorInvalidPublicKey}
} }
found, ok := dev.peers[pubKey] found, ok := device.peers[pubKey]
if ok { if ok {
peer = found peer = found
} else { } else {
@ -102,14 +103,16 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
publicKey: pubKey, publicKey: pubKey,
} }
peer = newPeer peer = newPeer
dev.peers[pubKey] = newPeer device.peers[pubKey] = newPeer
} }
case "replace_peers": case "replace_peers":
if key == "true" { if key == "true" {
dev.RemoveAllPeers() device.RemoveAllPeers()
} else if key == "false" {
} else {
return &IPCError{Code: ipcErrorInvalidValue}
} }
// todo: else fail
default: default:
/* Peer configuration */ /* Peer configuration */
@ -122,7 +125,7 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
case "remove": case "remove":
peer.mutex.Lock() peer.mutex.Lock()
dev.RemovePeer(peer.publicKey) device.RemovePeer(peer.publicKey)
peer = nil peer = nil
case "preshared_key": case "preshared_key":
@ -145,15 +148,29 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
peer.mutex.Unlock() peer.mutex.Unlock()
case "persistent_keepalive_interval": case "persistent_keepalive_interval":
func() { secs, err := strconv.ParseInt(value, 10, 64)
peer.mutex.Lock() if secs < 0 || err != nil {
defer peer.mutex.Unlock() return &IPCError{Code: ipcErrorInvalidValue}
}() }
peer.mutex.Lock()
peer.persistentKeepaliveInterval = time.Duration(secs) * time.Second
peer.mutex.Unlock()
case "replace_allowed_ips": case "replace_allowed_ips":
// remove peer from trie if key == "true" {
device.routingTable.RemovePeer(peer)
} else if key == "false" {
} else {
return &IPCError{Code: ipcErrorInvalidValue}
}
case "allowed_ip": case "allowed_ip":
_, network, err := net.ParseCIDR(value)
if err != nil {
return &IPCError{Code: ipcErrorInvalidValue}
}
ones, _ := network.Mask.Size()
device.routingTable.Insert(network.IP, uint(ones), peer)
/* Invalid key */ /* Invalid key */

17
src/ip.go Normal file
View file

@ -0,0 +1,17 @@
package main
import (
"net"
)
const (
IPv4version = 4
IPv4offsetSrc = 12
IPv4offsetDst = IPv4offsetSrc + net.IPv4len
)
const (
IPv6version = 6
IPv6offsetSrc = 8
IPv6offsetDst = IPv6offsetSrc + net.IPv6len
)

View file

@ -1,11 +1,33 @@
package main package main
import "fmt"
func main() {
fd, err := CreateTUN("test0")
fmt.Println(fd, err)
queue := make(chan []byte, 1000)
var device Device
go OutgoingRoutingWorker(&device, queue)
for {
tmp := make([]byte, 1<<16)
n, err := fd.Read(tmp)
if err != nil {
break
}
queue <- tmp[:n]
}
}
/*
import ( import (
"fmt" "fmt"
"log" "log"
"net" "net"
) )
func main() { func main() {
l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock") l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock")
if err != nil { if err != nil {
@ -24,5 +46,5 @@ func main() {
fmt.Println(err) fmt.Println(err)
}(fd) }(fd)
} }
} }
*/

View file

@ -3,6 +3,7 @@ package main
import ( import (
"net" "net"
"sync" "sync"
"time"
) )
type KeyPair struct { type KeyPair struct {
@ -13,8 +14,9 @@ type KeyPair struct {
} }
type Peer struct { type Peer struct {
mutex sync.RWMutex mutex sync.RWMutex
publicKey NoisePublicKey publicKey NoisePublicKey
presharedKey NoiseSymmetricKey presharedKey NoiseSymmetricKey
endpoint net.IP endpoint net.IP
persistentKeepaliveInterval time.Duration
} }

View file

@ -1,13 +1,12 @@
package main package main
import ( import (
"errors"
"fmt"
"net"
"sync" "sync"
) )
/* Thread-safe high level functions for cryptkey routing.
*
*/
type RoutingTable struct { type RoutingTable struct {
IPv4 *Trie IPv4 *Trie
IPv6 *Trie IPv6 *Trie
@ -20,3 +19,51 @@ func (table *RoutingTable) RemovePeer(peer *Peer) {
table.IPv4 = table.IPv4.RemovePeer(peer) table.IPv4 = table.IPv4.RemovePeer(peer)
table.IPv6 = table.IPv6.RemovePeer(peer) table.IPv6 = table.IPv6.RemovePeer(peer)
} }
func (table *RoutingTable) Insert(ip net.IP, cidr uint, peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
switch len(ip) {
case net.IPv6len:
table.IPv6 = table.IPv6.Insert(ip, cidr, peer)
case net.IPv4len:
table.IPv4 = table.IPv4.Insert(ip, cidr, peer)
default:
panic(errors.New("Inserting unknown address type"))
}
}
func (table *RoutingTable) LookupIPv4(address []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
return table.IPv4.Lookup(address)
}
func (table *RoutingTable) LookupIPv6(address []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
return table.IPv6.Lookup(address)
}
func OutgoingRoutingWorker(device *Device, queue chan []byte) {
for {
packet := <-queue
switch packet[0] >> 4 {
case IPv4version:
dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer := device.routingTable.LookupIPv4(dst)
fmt.Println("IPv4", peer)
case IPv6version:
dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer := device.routingTable.LookupIPv6(dst)
fmt.Println("IPv6", peer)
default:
// todo: log
fmt.Println("Unknown IP version")
}
}
}

View file

@ -1,5 +1,9 @@
package main package main
import (
"net"
)
/* Binary trie /* Binary trie
* *
* Syncronization done seperatly * Syncronization done seperatly
@ -22,13 +26,13 @@ type Trie struct {
/* Finds length of matching prefix /* Finds length of matching prefix
* Maybe there is a faster way * Maybe there is a faster way
* *
* Assumption: len(s1) == len(s2) * Assumption: len(ip1) == len(ip2)
*/ */
func commonBits(s1 []byte, s2 []byte) uint { func commonBits(ip1 net.IP, ip2 net.IP) uint {
var i uint var i uint
size := uint(len(s1)) size := uint(len(ip1))
for i = 0; i < size; i += 1 { for i = 0; i < size; i += 1 {
v := s1[i] ^ s2[i] v := ip1[i] ^ ip2[i]
if v != 0 { if v != 0 {
v >>= 1 v >>= 1
if v == 0 { if v == 0 {
@ -93,17 +97,17 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
return node.child[0] return node.child[0]
} }
func (node *Trie) choose(key []byte) byte { func (node *Trie) choose(ip net.IP) byte {
return (key[node.bit_at_byte] >> node.bit_at_shift) & 1 return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
} }
func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie { func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
// At leaf // At leaf
if node == nil { if node == nil {
return &Trie{ return &Trie{
bits: key, bits: ip,
peer: peer, peer: peer,
cidr: cidr, cidr: cidr,
bit_at_byte: cidr / 8, bit_at_byte: cidr / 8,
@ -113,21 +117,21 @@ func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie {
// Traverse deeper // Traverse deeper
common := commonBits(node.bits, key) common := commonBits(node.bits, ip)
if node.cidr <= cidr && common >= node.cidr { if node.cidr <= cidr && common >= node.cidr {
if node.cidr == cidr { if node.cidr == cidr {
node.peer = peer node.peer = peer
return node return node
} }
bit := node.choose(key) bit := node.choose(ip)
node.child[bit] = node.child[bit].Insert(key, cidr, peer) node.child[bit] = node.child[bit].Insert(ip, cidr, peer)
return node return node
} }
// Split node // Split node
newNode := &Trie{ newNode := &Trie{
bits: key, bits: ip,
peer: peer, peer: peer,
cidr: cidr, cidr: cidr,
bit_at_byte: cidr / 8, bit_at_byte: cidr / 8,
@ -147,31 +151,31 @@ func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie {
// Create new parent for node & newNode // Create new parent for node & newNode
parent := &Trie{ parent := &Trie{
bits: key, bits: ip,
peer: nil, peer: nil,
cidr: cidr, cidr: cidr,
bit_at_byte: cidr / 8, bit_at_byte: cidr / 8,
bit_at_shift: 7 - (cidr % 8), bit_at_shift: 7 - (cidr % 8),
} }
bit := parent.choose(key) bit := parent.choose(ip)
parent.child[bit] = newNode parent.child[bit] = newNode
parent.child[bit^1] = node parent.child[bit^1] = node
return parent return parent
} }
func (node *Trie) Lookup(key []byte) *Peer { func (node *Trie) Lookup(ip net.IP) *Peer {
var found *Peer var found *Peer
size := uint(len(key)) size := uint(len(ip))
for node != nil && commonBits(node.bits, key) >= node.cidr { for node != nil && commonBits(node.bits, ip) >= node.cidr {
if node.peer != nil { if node.peer != nil {
found = node.peer found = node.peer
} }
if node.bit_at_byte == size { if node.bit_at_byte == size {
break break
} }
bit := node.choose(key) bit := node.choose(ip)
node = node.child[bit] node = node.child[bit]
} }
return found return found

View file

@ -1,6 +1,8 @@
package main package main
import ( import (
"math/rand"
"net"
"testing" "testing"
) )
@ -55,6 +57,49 @@ func TestCommonBits(t *testing.T) {
} }
} }
func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
var trie *Trie
var peers []*Peer
rand.Seed(1)
const AddressLength = 4
for n := 0; n < peerNumber; n += 1 {
peers = append(peers, &Peer{})
}
for n := 0; n < addressNumber; n += 1 {
var addr [AddressLength]byte
rand.Read(addr[:])
cidr := uint(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % peerNumber
trie = trie.Insert(addr[:], cidr, peers[index])
}
for n := 0; n < b.N; n += 1 {
var addr [AddressLength]byte
rand.Read(addr[:])
trie.Lookup(addr[:])
}
}
func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) {
benchmarkTrie(100, 1000, net.IPv4len, b)
}
func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) {
benchmarkTrie(10, 10, net.IPv4len, b)
}
func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) {
benchmarkTrie(100, 1000, net.IPv6len, b)
}
func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) {
benchmarkTrie(10, 10, net.IPv6len, b)
}
/* Test ported from kernel implementation: /* Test ported from kernel implementation:
* selftest/routingtable.h * selftest/routingtable.h
*/ */
@ -91,10 +136,10 @@ func TestTrieIPv4(t *testing.T) {
insert(b, 192, 168, 4, 4, 32) insert(b, 192, 168, 4, 4, 32)
insert(c, 192, 168, 0, 0, 16) insert(c, 192, 168, 0, 0, 16)
insert(d, 192, 95, 5, 64, 27) insert(d, 192, 95, 5, 64, 27)
insert(c, 192, 95, 5, 65, 27) /* replaces previous entry, and maskself is required */ insert(c, 192, 95, 5, 65, 27)
insert(e, 0, 0, 0, 0, 0) insert(e, 0, 0, 0, 0, 0)
insert(g, 64, 15, 112, 0, 20) insert(g, 64, 15, 112, 0, 20)
insert(h, 64, 15, 123, 211, 25) /* maskself is required */ insert(h, 64, 15, 123, 211, 25)
insert(a, 10, 0, 0, 0, 25) insert(a, 10, 0, 0, 0, 25)
insert(b, 10, 0, 0, 128, 25) insert(b, 10, 0, 0, 128, 25)
insert(a, 10, 1, 0, 0, 30) insert(a, 10, 1, 0, 0, 30)
@ -186,20 +231,6 @@ func TestTrieIPv6(t *testing.T) {
} }
} }
/*
assertNEQ := func(peer *Peer, a, b, c, d uint32) {
var addr []byte
addr = append(addr, expand(a)...)
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
p := trie.Lookup(addr)
if p == peer {
t.Error("Assert NEQ failed")
}
}
*/
insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128) insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
insert(c, 0x26075300, 0x60006b00, 0, 0, 64) insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
insert(e, 0, 0, 0, 0, 0) insert(e, 0, 0, 0, 0, 0)

8
src/tun.go Normal file
View file

@ -0,0 +1,8 @@
package main
type TUN interface {
Read([]byte) (int, error)
Write([]byte) (int, error)
Name() string
MTU() uint
}

80
src/tun_linux.go Normal file
View file

@ -0,0 +1,80 @@
package main
import (
"encoding/binary"
"errors"
"os"
"strings"
"syscall"
"unsafe"
)
/* Platform dependent functions for interacting with
* TUN devices on linux systems
*
*/
const CloneDevicePath = "/dev/net/tun"
const (
IFF_NO_PI = 0x1000
IFF_TUN = 0x1
IFNAMSIZ = 0x10
TUNSETIFF = 0x400454CA
)
type NativeTun struct {
fd *os.File
name string
mtu uint
}
func (tun *NativeTun) Name() string {
return tun.name
}
func (tun *NativeTun) MTU() uint {
return tun.mtu
}
func (tun *NativeTun) Write(d []byte) (int, error) {
return tun.fd.Write(d)
}
func (tun *NativeTun) Read(d []byte) (int, error) {
return tun.fd.Read(d)
}
func CreateTUN(name string) (TUN, error) {
// Open clone device
fd, err := os.OpenFile(CloneDevicePath, os.O_RDWR, 0)
if err != nil {
return nil, err
}
// Prepare ifreq struct
var ifr [18]byte
var flags uint16 = IFF_TUN | IFF_NO_PI
nameBytes := []byte(name)
if len(nameBytes) >= IFNAMSIZ {
return nil, errors.New("Name size too long")
}
copy(ifr[:], nameBytes)
binary.LittleEndian.PutUint16(ifr[16:], flags)
// Create new device
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL,
uintptr(fd.Fd()), uintptr(TUNSETIFF),
uintptr(unsafe.Pointer(&ifr[0])))
if errno != 0 {
return nil, errors.New("Failed to create tun, ioctl call failed")
}
// Read name of interface
newName := string(ifr[:])
newName = newName[:strings.Index(newName, "\000")]
return &NativeTun{
fd: fd,
name: newName,
}, nil
}