Work on UAPI

Cross-platform API (get operation)
Handshake initiation creation process
Outbound packet flow
Fixes from code-review
This commit is contained in:
Mathias Hall-Andersen 2017-06-28 23:45:45 +02:00
parent 8236f3afa2
commit 1f0976a26c
18 changed files with 707 additions and 243 deletions

9
src/Makefile Normal file
View file

@ -0,0 +1,9 @@
BINARY=wireguard-go
build:
go build -o ${BINARY}
clean:
if [ -f ${BINARY} ]; then rm ${BINARY}; fi
.PHONY: clean

View file

@ -11,7 +11,7 @@ import (
"time"
)
/* todo : use real error code
/* TODO : use real error code
* Many of which will be the same
*/
const (
@ -37,8 +37,55 @@ func (s *IPCError) ErrorCode() int {
return s.Code
}
func ipcGetOperation(socket *bufio.ReadWriter, dev *Device) {
func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
device.mutex.RLock()
defer device.mutex.RUnlock()
// create lines
lines := make([]string, 0, 100)
send := func(line string) {
lines = append(lines, line)
}
if !device.privateKey.IsZero() {
send("private_key=" + device.privateKey.ToHex())
}
if device.address != nil {
send(fmt.Sprintf("listen_port=%d", device.address.Port))
}
for _, peer := range device.peers {
func() {
peer.mutex.RLock()
defer peer.mutex.RUnlock()
send("public_key=" + peer.handshake.remoteStatic.ToHex())
send("preshared_key=" + peer.handshake.presharedKey.ToHex())
if peer.endpoint != nil {
send("endpoint=" + peer.endpoint.String())
}
send(fmt.Sprintf("tx_bytes=%d", peer.tx_bytes))
send(fmt.Sprintf("rx_bytes=%d", peer.rx_bytes))
send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
for _, ip := range device.routingTable.AllowedIPs(peer) {
send("allowed_ip=" + ip.String())
}
}()
}
// send lines
for _, line := range lines {
device.log.Debug.Println("config:", line)
_, err := socket.WriteString(line + "\n")
if err != nil {
return err
}
}
return nil
}
func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
@ -179,7 +226,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return nil
}
func ipcListen(dev *Device, socket io.ReadWriter) error {
func ipcListen(device *Device, socket io.ReadWriter) error {
buffered := func(s io.ReadWriter) *bufio.ReadWriter {
reader := bufio.NewReader(s)
@ -187,6 +234,8 @@ func ipcListen(dev *Device, socket io.ReadWriter) error {
return bufio.NewReadWriter(reader, writer)
}(socket)
defer buffered.Flush()
for {
op, err := buffered.ReadString('\n')
if err != nil {
@ -197,17 +246,26 @@ func ipcListen(dev *Device, socket io.ReadWriter) error {
switch op {
case "set=1\n":
err := ipcSetOperation(dev, buffered)
err := ipcSetOperation(device, buffered)
if err != nil {
fmt.Fprintf(buffered, "errno=%d\n", err.ErrorCode())
fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
return err
} else {
fmt.Fprintf(buffered, "errno=0\n")
fmt.Fprintf(buffered, "errno=0\n\n")
}
buffered.Flush()
case "get=1\n":
err := ipcGetOperation(device, buffered)
if err != nil {
fmt.Fprintf(buffered, "errno=1\n\n") // fix
return err
} else {
fmt.Fprintf(buffered, "errno=0\n\n")
}
buffered.Flush()
case "\n":
default:
return errors.New("handle this please")
}

View file

@ -5,12 +5,17 @@ import (
)
const (
RekeyAfterMessage = (1 << 64) - (1 << 16) - 1
RekeyAfterTime = time.Second * 120
RekeyAttemptTime = time.Second * 90
RekeyTimeout = time.Second * 5
RejectAfterTime = time.Second * 180
RejectAfterMessage = (1 << 64) - (1 << 4) - 1
KeepaliveTimeout = time.Second * 10
CookieRefreshTime = time.Second * 2
RekeyAfterMessage = (1 << 64) - (1 << 16) - 1
RekeyAfterTime = time.Second * 120
RekeyAttemptTime = time.Second * 90
RekeyTimeout = time.Second * 5 // TODO: Exponential backoff
RejectAfterTime = time.Second * 180
RejectAfterMessage = (1 << 64) - (1 << 4) - 1
KeepaliveTimeout = time.Second * 10
CookieRefreshTime = time.Second * 2
MaxHandshakeAttempTime = time.Second * 90
)
const (
QueueOutboundSize = 1024
)

View file

@ -2,23 +2,26 @@ package main
import (
"net"
"runtime"
"sync"
)
type Device struct {
mtu int
fwMark uint32
address *net.UDPAddr // UDP source address
conn *net.UDPConn // UDP "connection"
mutex sync.RWMutex
privateKey NoisePrivateKey
publicKey NoisePublicKey
routingTable RoutingTable
indices IndexTable
log *Logger
queueWorkOutbound chan *OutboundWorkQueueElement
peers map[NoisePublicKey]*Peer
mac MacStateDevice
mtu int
fwMark uint32
address *net.UDPAddr // UDP source address
conn *net.UDPConn // UDP "connection"
mutex sync.RWMutex
privateKey NoisePrivateKey
publicKey NoisePublicKey
routingTable RoutingTable
indices IndexTable
log *Logger
queue struct {
encryption chan *QueueOutboundElement // parallel work queue
}
peers map[NoisePublicKey]*Peer
mac MacStateDevice
}
func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
@ -41,7 +44,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
}
}
func (device *Device) Init() {
func NewDevice(tun TUNDevice) *Device {
device := new(Device)
device.mutex.Lock()
defer device.mutex.Unlock()
@ -49,6 +54,14 @@ func (device *Device) Init() {
device.peers = make(map[NoisePublicKey]*Peer)
device.indices.Init()
device.routingTable.Reset()
// start workers
for i := 0; i < runtime.NumCPU(); i += 1 {
go device.RoutineEncryption()
}
go device.RoutineReadFromTUN(tun)
return device
}
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {

172
src/handshake.go Normal file
View file

@ -0,0 +1,172 @@
package main
import (
"bytes"
"encoding/binary"
"net"
"sync/atomic"
"time"
)
/* Sends a keep-alive if no packets queued for peer
*
* Used by initiator of handshake and with active keep-alive
*/
func (peer *Peer) SendKeepAlive() bool {
if len(peer.queue.nonce) == 0 {
select {
case peer.queue.nonce <- []byte{}:
return true
default:
return false
}
}
return true
}
func (peer *Peer) RoutineHandshakeInitiator() {
var ongoing bool
var begun time.Time
var attempts uint
var timeout time.Timer
device := peer.device
work := new(QueueOutboundElement)
buffer := make([]byte, 0, 1024)
queueHandshakeInitiation := func() error {
work.mutex.Lock()
defer work.mutex.Unlock()
// create initiation
msg, err := device.CreateMessageInitiation(peer)
if err != nil {
return err
}
// create "work" element
writer := bytes.NewBuffer(buffer[:0])
binary.Write(writer, binary.LittleEndian, &msg)
work.packet = writer.Bytes()
peer.mac.AddMacs(work.packet)
peer.InsertOutbound(work)
return nil
}
for {
select {
case <-peer.signal.stopInitiator:
return
case <-peer.signal.newHandshake:
if ongoing {
continue
}
// create handshake
err := queueHandshakeInitiation()
if err != nil {
device.log.Error.Println("Failed to create initiation message:", err)
}
// log when we began
begun = time.Now()
ongoing = true
attempts = 0
timeout.Reset(RekeyTimeout)
case <-peer.timer.sendKeepalive.C:
// active keep-alives
peer.SendKeepAlive()
case <-peer.timer.handshakeTimeout.C:
// check if we can stop trying
if time.Now().Sub(begun) > MaxHandshakeAttempTime {
peer.signal.flushNonceQueue <- true
peer.timer.sendKeepalive.Stop()
ongoing = false
continue
}
// otherwise, try again (exponental backoff)
attempts += 1
err := queueHandshakeInitiation()
if err != nil {
device.log.Error.Println("Failed to create initiation message:", err)
}
peer.timer.handshakeTimeout.Reset((1 << attempts) * RekeyTimeout)
}
}
}
/* Handles packets related to handshake
*
*
*/
func (device *Device) HandshakeWorker(queue chan struct {
msg []byte
msgType uint32
addr *net.UDPAddr
}) {
for {
elem := <-queue
switch elem.msgType {
case MessageInitiationType:
if len(elem.msg) != MessageInitiationSize {
continue
}
// check for cookie
var msg MessageInitiation
binary.Read(nil, binary.LittleEndian, &msg)
case MessageResponseType:
if len(elem.msg) != MessageResponseSize {
continue
}
// check for cookie
case MessageCookieReplyType:
case MessageTransportType:
}
}
}
func (device *Device) KeepKeyFresh(peer *Peer) {
send := func() bool {
peer.keyPairs.mutex.RLock()
defer peer.keyPairs.mutex.RUnlock()
kp := peer.keyPairs.current
if kp == nil {
return false
}
nonce := atomic.LoadUint64(&kp.sendNonce)
if nonce > RekeyAfterMessage {
return true
}
return kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime
}()
if send {
}
}

64
src/helper_test.go Normal file
View file

@ -0,0 +1,64 @@
package main
import (
"bytes"
"testing"
)
/* Helpers for writing unit tests
*/
type DummyTUN struct {
name string
mtu uint
packets chan []byte
}
func (tun *DummyTUN) Name() string {
return tun.name
}
func (tun *DummyTUN) MTU() uint {
return tun.mtu
}
func (tun *DummyTUN) Write(d []byte) (int, error) {
tun.packets <- d
return len(d), nil
}
func (tun *DummyTUN) Read(d []byte) (int, error) {
t := <-tun.packets
copy(d, t)
return len(t), nil
}
func CreateDummyTUN(name string) (TUNDevice, error) {
var dummy DummyTUN
dummy.mtu = 1024
dummy.packets = make(chan []byte, 100)
return &dummy, nil
}
func assertNil(t *testing.T, err error) {
if err != nil {
t.Fatal(err)
}
}
func assertEqual(t *testing.T, a []byte, b []byte) {
if bytes.Compare(a, b) != 0 {
t.Fatal(a, "!=", b)
}
}
func randDevice(t *testing.T) *Device {
sk, err := newPrivateKey()
if err != nil {
t.Fatal(err)
}
tun, _ := CreateDummyTUN("dummy")
device := NewDevice(tun)
device.SetPrivateKey(sk)
return device
}

View file

@ -5,9 +5,10 @@ import (
)
const (
IPv4version = 4
IPv4offsetSrc = 12
IPv4offsetDst = IPv4offsetSrc + net.IPv4len
IPv4version = 4
IPv4offsetSrc = 12
IPv4offsetDst = IPv4offsetSrc + net.IPv4len
IPv4headerSize = 20
)
const (

View file

@ -8,8 +8,8 @@ import (
)
func TestMAC1(t *testing.T) {
dev1 := newDevice(t)
dev2 := newDevice(t)
dev1 := randDevice(t)
dev2 := randDevice(t)
peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
@ -34,12 +34,10 @@ func TestMACs(t *testing.T) {
msg []byte,
receiver uint32,
) bool {
var device1 Device
device1.Init()
device1 := randDevice(t)
device1.SetPrivateKey(sk1)
var device2 Device
device2.Init()
device2 := randDevice(t)
device2.SetPrivateKey(sk2)
peer1 := device2.NewPeer(device1.privateKey.publicKey())

View file

@ -1,36 +1,30 @@
package main
import (
"fmt"
)
func main() {
fd, err := CreateTUN("test0")
fmt.Println(fd, err)
queue := make(chan []byte, 1000)
// var device Device
// go OutgoingRoutingWorker(&device, queue)
for {
tmp := make([]byte, 1<<16)
n, err := fd.Read(tmp)
if err != nil {
break
}
queue <- tmp[:n]
}
}
/*
import (
"fmt"
"log"
"net"
)
/*
*
* TODO: Fix logging
*/
func main() {
// Open TUN device
// TODO: Fix capabilities
tun, err := CreateTUN("test0")
log.Println(tun, err)
if err != nil {
return
}
device := NewDevice(tun)
// Start configuration lister
l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock")
if err != nil {
log.Fatal("listen error:", err)
@ -41,12 +35,9 @@ func main() {
if err != nil {
log.Fatal("accept error:", err)
}
var dev Device
go func(conn net.Conn) {
err := ipcListen(&dev, conn)
fmt.Println(err)
err := ipcListen(device, conn)
log.Println(err)
}(fd)
}
}
*/

View file

@ -77,7 +77,7 @@ type MessageCookieReply struct {
type Handshake struct {
state int
mutex sync.Mutex
mutex sync.RWMutex
hash [blake2s.Size]byte // hash value
chainKey [blake2s.Size]byte // chain key
presharedKey NoiseSymmetricKey // psk
@ -205,49 +205,64 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
}
hash = mixHash(hash, msg.Static[:])
// find peer
// lookup peer
peer := device.LookupPeer(peerPK)
if peer == nil {
return nil
}
handshake := &peer.handshake
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
// decrypt timestamp
// verify identity
var timestamp TAI64N
func() {
var key [chacha20poly1305.KeySize]byte
chainKey, key = KDF2(
chainKey[:],
handshake.precomputedStaticStatic[:],
)
aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
ok := func() bool {
// read lock handshake
handshake.mutex.RLock()
defer handshake.mutex.RUnlock()
// decrypt timestamp
func() {
var key [chacha20poly1305.KeySize]byte
chainKey, key = KDF2(
chainKey[:],
handshake.precomputedStaticStatic[:],
)
aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
}()
if err != nil {
return false
}
hash = mixHash(hash, msg.Timestamp[:])
// TODO: check for flood attack
// check for replay attack
return timestamp.After(handshake.lastTimestamp)
}()
if err != nil {
if !ok {
return nil
}
hash = mixHash(hash, msg.Timestamp[:])
// check for replay attack
if !timestamp.After(handshake.lastTimestamp) {
return nil
}
// TODO: check for flood attack
// update handshake state
handshake.mutex.Lock()
handshake.hash = hash
handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender
handshake.remoteEphemeral = msg.Ephemeral
handshake.lastTimestamp = timestamp
handshake.state = HandshakeInitiationConsumed
handshake.mutex.Unlock()
return peer
}
@ -320,47 +335,67 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
return nil
}
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
if handshake.state != HandshakeInitiationCreated {
return nil
}
var (
hash [blake2s.Size]byte
chainKey [blake2s.Size]byte
)
// finish 3-way DH
ok := func() bool {
hash := mixHash(handshake.hash, msg.Ephemeral[:])
chainKey := handshake.chainKey
// read lock handshake
func() {
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
chainKey = mixKey(chainKey, ss[:])
ss = device.privateKey.sharedSecret(msg.Ephemeral)
chainKey = mixKey(chainKey, ss[:])
handshake.mutex.RLock()
defer handshake.mutex.RUnlock()
if handshake.state != HandshakeInitiationCreated {
return false
}
// finish 3-way DH
hash = mixHash(handshake.hash, msg.Ephemeral[:])
chainKey = handshake.chainKey
func() {
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
chainKey = mixKey(chainKey, ss[:])
ss = device.privateKey.sharedSecret(msg.Ephemeral)
chainKey = mixKey(chainKey, ss[:])
}()
// add preshared key (psk)
var tau [blake2s.Size]byte
var key [chacha20poly1305.KeySize]byte
chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
hash = mixHash(hash, tau[:])
// authenticate
aead, _ := chacha20poly1305.New(key[:])
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
if err != nil {
return false
}
hash = mixHash(hash, msg.Empty[:])
return true
}()
// add preshared key (psk)
var tau [blake2s.Size]byte
var key [chacha20poly1305.KeySize]byte
chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
hash = mixHash(hash, tau[:])
// authenticate
aead, _ := chacha20poly1305.New(key[:])
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
if err != nil {
if !ok {
return nil
}
hash = mixHash(hash, msg.Empty[:])
// update handshake state
handshake.mutex.Lock()
handshake.hash = hash
handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender
handshake.state = HandshakeResponseConsumed
handshake.mutex.Unlock()
return lookup.peer
}

View file

@ -6,29 +6,6 @@ import (
"testing"
)
func assertNil(t *testing.T, err error) {
if err != nil {
t.Fatal(err)
}
}
func assertEqual(t *testing.T, a []byte, b []byte) {
if bytes.Compare(a, b) != 0 {
t.Fatal(a, "!=", b)
}
}
func newDevice(t *testing.T) *Device {
var device Device
sk, err := newPrivateKey()
if err != nil {
t.Fatal(err)
}
device.Init()
device.SetPrivateKey(sk)
return &device
}
func TestCurveWrappers(t *testing.T) {
sk1, err := newPrivateKey()
assertNil(t, err)
@ -49,8 +26,8 @@ func TestCurveWrappers(t *testing.T) {
func TestNoiseHandshake(t *testing.T) {
dev1 := newDevice(t)
dev2 := newDevice(t)
dev1 := randDevice(t)
dev2 := randDevice(t)
peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
peer2 := dev1.NewPeer(dev2.privateKey.publicKey())

View file

@ -3,18 +3,18 @@ package main
import (
"encoding/hex"
"errors"
"golang.org/x/crypto/chacha20poly1305"
)
const (
NoisePublicKeySize = 32
NoisePrivateKeySize = 32
NoiseSymmetricKeySize = 32
NoisePublicKeySize = 32
NoisePrivateKeySize = 32
)
type (
NoisePublicKey [NoisePublicKeySize]byte
NoisePrivateKey [NoisePrivateKeySize]byte
NoiseSymmetricKey [NoiseSymmetricKeySize]byte
NoiseSymmetricKey [chacha20poly1305.KeySize]byte
NoiseNonce uint64 // padded to 12-bytes
)
@ -30,6 +30,15 @@ func loadExactHex(dst []byte, src string) error {
return nil
}
func (key NoisePrivateKey) IsZero() bool {
for _, b := range key[:] {
if b != 0 {
return false
}
}
return true
}
func (key *NoisePrivateKey) FromHex(src string) error {
return loadExactHex(key[:], src)
}

View file

@ -7,9 +7,7 @@ import (
"time"
)
const (
OutboundQueueSize = 64
)
const ()
type Peer struct {
mutex sync.RWMutex
@ -18,10 +16,26 @@ type Peer struct {
keyPairs KeyPairs
handshake Handshake
device *Device
queueInbound chan []byte
queueOutbound chan *OutboundWorkQueueElement
queueOutboundRouting chan []byte
mac MacStatePeer
tx_bytes uint64
rx_bytes uint64
time struct {
lastSend time.Time // last send message
}
signal struct {
newHandshake chan bool
flushNonceQueue chan bool // empty queued packets
stopSending chan bool // stop sending pipeline
stopInitiator chan bool // stop initiator timer
}
timer struct {
sendKeepalive time.Timer
handshakeTimeout time.Timer
}
queue struct {
nonce chan []byte // nonce / pre-handshake queue
outbound chan *QueueOutboundElement // sequential ordering of work
}
mac MacStatePeer
}
func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
@ -33,7 +47,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
peer.device = device
peer.keyPairs.Init()
peer.mac.Init(pk)
peer.queueOutbound = make(chan *OutboundWorkQueueElement, OutboundQueueSize)
peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
peer.queue.nonce = make(chan []byte, QueueOutboundSize)
// map public key
@ -54,5 +69,20 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
handshake.mutex.Unlock()
peer.mutex.Unlock()
// start workers
peer.signal.stopSending = make(chan bool, 1)
peer.signal.stopInitiator = make(chan bool, 1)
peer.signal.newHandshake = make(chan bool, 1)
peer.signal.flushNonceQueue = make(chan bool, 1)
go peer.RoutineNonce()
go peer.RoutineHandshakeInitiator()
return &peer
}
func (peer *Peer) Close() {
peer.signal.stopSending <- true
peer.signal.stopInitiator <- true
}

View file

@ -12,9 +12,20 @@ type RoutingTable struct {
mutex sync.RWMutex
}
func (table *RoutingTable) AllowedIPs(peer *Peer) []net.IPNet {
table.mutex.RLock()
defer table.mutex.RUnlock()
allowed := make([]net.IPNet, 10)
table.IPv4.AllowedIPs(peer, allowed)
table.IPv6.AllowedIPs(peer, allowed)
return allowed
}
func (table *RoutingTable) Reset() {
table.mutex.Lock()
defer table.mutex.Unlock()
table.IPv4 = nil
table.IPv6 = nil
}
@ -22,6 +33,7 @@ func (table *RoutingTable) Reset() {
func (table *RoutingTable) RemovePeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
table.IPv4 = table.IPv4.RemovePeer(peer)
table.IPv6 = table.IPv6.RemovePeer(peer)
}

View file

@ -5,107 +5,159 @@ import (
"golang.org/x/crypto/chacha20poly1305"
"net"
"sync"
"time"
)
/* Handles outbound flow
*
* 1. TUN queue
* 2. Routing
* 3. Per peer queuing
* 4. (work queuing)
* 2. Routing (sequential)
* 3. Nonce assignment (sequential)
* 4. Encryption (parallel)
* 5. Transmission (sequential)
*
* The order of packets (per peer) is maintained.
* The functions in this file occure (roughly) in the order packets are processed.
*/
type OutboundWorkQueueElement struct {
wg sync.WaitGroup
/* A work unit
*
* The sequential consumers will attempt to take the lock,
* workers release lock when they have completed work on the packet.
*/
type QueueOutboundElement struct {
mutex sync.Mutex
packet []byte
nonce uint64
keyPair *KeyPair
}
func (peer *Peer) HandshakeWorker(handshakeQueue []byte) {
func (peer *Peer) FlushNonceQueue() {
elems := len(peer.queue.nonce)
for i := 0; i < elems; i += 1 {
select {
case <-peer.queue.nonce:
default:
return
}
}
}
func (device *Device) SendPacket(packet []byte) {
// lookup peer
var peer *Peer
switch packet[0] >> 4 {
case IPv4version:
dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.routingTable.LookupIPv4(dst)
case IPv6version:
dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.routingTable.LookupIPv6(dst)
default:
device.log.Debug.Println("receieved packet with unknown IP version")
return
}
if peer == nil {
return
}
// insert into peer queue
func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) {
for {
select {
case peer.queueOutboundRouting <- packet:
case peer.queue.outbound <- elem:
default:
select {
case <-peer.queueOutboundRouting:
case <-peer.queue.outbound:
default:
}
continue
}
break
}
}
/* Go routine
/* Reads packets from the TUN and inserts
* into nonce queue for peer
*
*
* 1. waits for handshake.
* 2. assigns key pair & nonce
* 3. inserts to working queue
*
* TODO: avoid dynamic allocation of work queue elements
* Obs. Single instance per TUN device
*/
func (peer *Peer) RoutineOutboundNonceWorker() {
func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
for {
// read packet
packet := make([]byte, 1<<16) // TODO: Fix & avoid dynamic allocation
size, err := tun.Read(packet)
if err != nil {
device.log.Error.Println("Failed to read packet from TUN device:", err)
continue
}
packet = packet[:size]
if len(packet) < IPv4headerSize {
device.log.Error.Println("Packet too short, length:", len(packet))
continue
}
device.log.Debug.Println("New packet on TUN:", packet) // TODO: Slow debugging, remove.
// lookup peer
var peer *Peer
switch packet[0] >> 4 {
case IPv4version:
dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.routingTable.LookupIPv4(dst)
case IPv6version:
dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.routingTable.LookupIPv6(dst)
default:
device.log.Debug.Println("Receieved packet with unknown IP version")
return
}
if peer == nil {
device.log.Debug.Println("No peer configured for IP")
return
}
// insert into nonce/pre-handshake queue
for {
select {
case peer.queue.nonce <- packet:
default:
select {
case <-peer.queue.nonce:
default:
}
continue
}
break
}
}
}
/* Queues packets when there is no handshake.
* Then assigns nonces to packets sequentially
* and creates "work" structs for workers
*
* TODO: Avoid dynamic allocation of work queue elements
*
* Obs. A single instance per peer
*/
func (peer *Peer) RoutineNonce() {
var packet []byte
var keyPair *KeyPair
var flushTimer time.Timer
for {
// wait for packet
if packet == nil {
packet = <-peer.queueOutboundRouting
select {
case packet = <-peer.queue.nonce:
case <-peer.signal.stopSending:
close(peer.queue.outbound)
return
}
}
// wait for key pair
for keyPair == nil {
flushTimer.Reset(time.Second * 10)
// TODO: Handshake or NOP
peer.signal.newHandshake <- true
select {
case <-peer.keyPairs.newKeyPair:
keyPair = peer.keyPairs.Current()
continue
case <-flushTimer.C:
size := len(peer.queueOutboundRouting)
for i := 0; i < size; i += 1 {
<-peer.queueOutboundRouting
}
case <-peer.signal.flushNonceQueue:
peer.FlushNonceQueue()
packet = nil
continue
case <-peer.signal.stopSending:
close(peer.queue.outbound)
return
}
break
}
// process current packet
@ -114,14 +166,13 @@ func (peer *Peer) RoutineOutboundNonceWorker() {
// create work element
work := new(OutboundWorkQueueElement)
work.wg.Add(1)
work := new(QueueOutboundElement) // TODO: profile, maybe use pool
work.keyPair = keyPair
work.packet = packet
work.nonce = keyPair.sendNonce
work.mutex.Lock()
packet = nil
peer.queueOutbound <- work
keyPair.sendNonce += 1
// drop packets until there is space
@ -129,46 +180,36 @@ func (peer *Peer) RoutineOutboundNonceWorker() {
func() {
for {
select {
case peer.device.queueWorkOutbound <- work:
case peer.device.queue.encryption <- work:
return
default:
drop := <-peer.device.queueWorkOutbound
drop := <-peer.device.queue.encryption
drop.packet = nil
drop.wg.Done()
drop.mutex.Unlock()
}
}
}()
peer.queue.outbound <- work
}
}
}
/* Go routine
*
* sequentially reads packets from queue and sends to endpoint
/* Encrypts the elements in the queue
* and marks them for sequential consumption (by releasing the mutex)
*
* Obs. One instance per core
*/
func (peer *Peer) RoutineSequential() {
for work := range peer.queueOutbound {
work.wg.Wait()
if work.packet == nil {
continue
}
if peer.endpoint == nil {
continue
}
peer.device.conn.WriteToUDP(work.packet, peer.endpoint)
}
}
func (device *Device) RoutineEncryptionWorker() {
func (device *Device) RoutineEncryption() {
var nonce [chacha20poly1305.NonceSize]byte
for work := range device.queueWorkOutbound {
for work := range device.queue.encryption {
// pad packet
padding := device.mtu - len(work.packet)
if padding < 0 {
// drop
work.packet = nil
work.wg.Done()
work.mutex.Unlock()
}
for n := 0; n < padding; n += 1 {
work.packet = append(work.packet, 0)
@ -183,6 +224,30 @@ func (device *Device) RoutineEncryptionWorker() {
work.packet,
nil,
)
work.wg.Done()
work.mutex.Unlock()
}
}
/* Sequentially reads packets from queue and sends to endpoint
*
* Obs. Single instance per peer.
* The routine terminates then the outbound queue is closed.
*/
func (peer *Peer) RoutineSequential() {
for work := range peer.queue.outbound {
work.mutex.Lock()
func() {
peer.mutex.RLock()
defer peer.mutex.RUnlock()
if work.packet == nil {
return
}
if peer.endpoint == nil {
return
}
peer.device.conn.WriteToUDP(work.packet, peer.endpoint)
peer.timer.sendKeepalive.Reset(peer.persistentKeepaliveInterval)
}()
work.mutex.Unlock()
}
}

View file

@ -1,15 +1,20 @@
package main
import (
"errors"
"net"
)
/* Binary trie
*
* The net.IPs used here are not formatted the
* same way as those created by the "net" functions.
* Here the IPs are slices of either 4 or 16 byte (not always 16)
*
* Syncronization done seperatly
* See: routing.go
*
* Todo: Better commenting
* TODO: Better commenting
*/
type Trie struct {
@ -24,7 +29,7 @@ type Trie struct {
}
/* Finds length of matching prefix
* Maybe there is a faster way
* TODO: Make faster
*
* Assumption: len(ip1) == len(ip2)
*/
@ -189,3 +194,25 @@ func (node *Trie) Count() uint {
r := node.child[1].Count()
return l + r
}
func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) {
if node.peer == p {
var mask net.IPNet
mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8)
if len(node.bits) == net.IPv4len {
mask.IP = net.IPv4(
node.bits[0],
node.bits[1],
node.bits[2],
node.bits[3],
)
} else if len(node.bits) == net.IPv6len {
mask.IP = node.bits
} else {
panic(errors.New("bug: unexpected address length"))
}
results = append(results, mask)
}
node.child[0].AllowedIPs(p, results)
node.child[1].AllowedIPs(p, results)
}

View file

@ -1,6 +1,6 @@
package main
type TUN interface {
type TUNDevice interface {
Read([]byte) (int, error)
Write([]byte) (int, error)
Name() string

View file

@ -9,9 +9,7 @@ import (
"unsafe"
)
/* Platform dependent functions for interacting with
* TUN devices on linux systems
*
/* Implementation of the TUN device interface for linux
*/
const CloneDevicePath = "/dev/net/tun"
@ -45,7 +43,7 @@ func (tun *NativeTun) Read(d []byte) (int, error) {
return tun.fd.Read(d)
}
func CreateTUN(name string) (TUN, error) {
func CreateTUN(name string) (TUNDevice, error) {
// Open clone device
fd, err := os.OpenFile(CloneDevicePath, os.O_RDWR, 0)
if err != nil {
@ -53,7 +51,7 @@ func CreateTUN(name string) (TUN, error) {
}
// Prepare ifreq struct
var ifr [18]byte
var ifr [128]byte
var flags uint16 = IFF_TUN | IFF_NO_PI
nameBytes := []byte(name)
if len(nameBytes) >= IFNAMSIZ {