Work on UAPI
Cross-platform API (get operation) Handshake initiation creation process Outbound packet flow Fixes from code-review
This commit is contained in:
parent
8236f3afa2
commit
1f0976a26c
9
src/Makefile
Normal file
9
src/Makefile
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
BINARY=wireguard-go
|
||||||
|
|
||||||
|
build:
|
||||||
|
go build -o ${BINARY}
|
||||||
|
|
||||||
|
clean:
|
||||||
|
if [ -f ${BINARY} ]; then rm ${BINARY}; fi
|
||||||
|
|
||||||
|
.PHONY: clean
|
|
@ -11,7 +11,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
/* todo : use real error code
|
/* TODO : use real error code
|
||||||
* Many of which will be the same
|
* Many of which will be the same
|
||||||
*/
|
*/
|
||||||
const (
|
const (
|
||||||
|
@ -37,8 +37,55 @@ func (s *IPCError) ErrorCode() int {
|
||||||
return s.Code
|
return s.Code
|
||||||
}
|
}
|
||||||
|
|
||||||
func ipcGetOperation(socket *bufio.ReadWriter, dev *Device) {
|
func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
|
||||||
|
|
||||||
|
device.mutex.RLock()
|
||||||
|
defer device.mutex.RUnlock()
|
||||||
|
|
||||||
|
// create lines
|
||||||
|
|
||||||
|
lines := make([]string, 0, 100)
|
||||||
|
send := func(line string) {
|
||||||
|
lines = append(lines, line)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !device.privateKey.IsZero() {
|
||||||
|
send("private_key=" + device.privateKey.ToHex())
|
||||||
|
}
|
||||||
|
|
||||||
|
if device.address != nil {
|
||||||
|
send(fmt.Sprintf("listen_port=%d", device.address.Port))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peer := range device.peers {
|
||||||
|
func() {
|
||||||
|
peer.mutex.RLock()
|
||||||
|
defer peer.mutex.RUnlock()
|
||||||
|
send("public_key=" + peer.handshake.remoteStatic.ToHex())
|
||||||
|
send("preshared_key=" + peer.handshake.presharedKey.ToHex())
|
||||||
|
if peer.endpoint != nil {
|
||||||
|
send("endpoint=" + peer.endpoint.String())
|
||||||
|
}
|
||||||
|
send(fmt.Sprintf("tx_bytes=%d", peer.tx_bytes))
|
||||||
|
send(fmt.Sprintf("rx_bytes=%d", peer.rx_bytes))
|
||||||
|
send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
|
||||||
|
for _, ip := range device.routingTable.AllowedIPs(peer) {
|
||||||
|
send("allowed_ip=" + ip.String())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// send lines
|
||||||
|
|
||||||
|
for _, line := range lines {
|
||||||
|
device.log.Debug.Println("config:", line)
|
||||||
|
_, err := socket.WriteString(line + "\n")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||||
|
@ -179,7 +226,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ipcListen(dev *Device, socket io.ReadWriter) error {
|
func ipcListen(device *Device, socket io.ReadWriter) error {
|
||||||
|
|
||||||
buffered := func(s io.ReadWriter) *bufio.ReadWriter {
|
buffered := func(s io.ReadWriter) *bufio.ReadWriter {
|
||||||
reader := bufio.NewReader(s)
|
reader := bufio.NewReader(s)
|
||||||
|
@ -187,6 +234,8 @@ func ipcListen(dev *Device, socket io.ReadWriter) error {
|
||||||
return bufio.NewReadWriter(reader, writer)
|
return bufio.NewReadWriter(reader, writer)
|
||||||
}(socket)
|
}(socket)
|
||||||
|
|
||||||
|
defer buffered.Flush()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
op, err := buffered.ReadString('\n')
|
op, err := buffered.ReadString('\n')
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -197,17 +246,26 @@ func ipcListen(dev *Device, socket io.ReadWriter) error {
|
||||||
switch op {
|
switch op {
|
||||||
|
|
||||||
case "set=1\n":
|
case "set=1\n":
|
||||||
err := ipcSetOperation(dev, buffered)
|
err := ipcSetOperation(device, buffered)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(buffered, "errno=%d\n", err.ErrorCode())
|
fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
|
||||||
return err
|
return err
|
||||||
} else {
|
} else {
|
||||||
fmt.Fprintf(buffered, "errno=0\n")
|
fmt.Fprintf(buffered, "errno=0\n\n")
|
||||||
}
|
}
|
||||||
buffered.Flush()
|
buffered.Flush()
|
||||||
|
|
||||||
case "get=1\n":
|
case "get=1\n":
|
||||||
|
err := ipcGetOperation(device, buffered)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(buffered, "errno=1\n\n") // fix
|
||||||
|
return err
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(buffered, "errno=0\n\n")
|
||||||
|
}
|
||||||
|
buffered.Flush()
|
||||||
|
|
||||||
|
case "\n":
|
||||||
default:
|
default:
|
||||||
return errors.New("handle this please")
|
return errors.New("handle this please")
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,9 +8,14 @@ const (
|
||||||
RekeyAfterMessage = (1 << 64) - (1 << 16) - 1
|
RekeyAfterMessage = (1 << 64) - (1 << 16) - 1
|
||||||
RekeyAfterTime = time.Second * 120
|
RekeyAfterTime = time.Second * 120
|
||||||
RekeyAttemptTime = time.Second * 90
|
RekeyAttemptTime = time.Second * 90
|
||||||
RekeyTimeout = time.Second * 5
|
RekeyTimeout = time.Second * 5 // TODO: Exponential backoff
|
||||||
RejectAfterTime = time.Second * 180
|
RejectAfterTime = time.Second * 180
|
||||||
RejectAfterMessage = (1 << 64) - (1 << 4) - 1
|
RejectAfterMessage = (1 << 64) - (1 << 4) - 1
|
||||||
KeepaliveTimeout = time.Second * 10
|
KeepaliveTimeout = time.Second * 10
|
||||||
CookieRefreshTime = time.Second * 2
|
CookieRefreshTime = time.Second * 2
|
||||||
|
MaxHandshakeAttempTime = time.Second * 90
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
QueueOutboundSize = 1024
|
||||||
)
|
)
|
||||||
|
|
|
@ -2,6 +2,7 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -16,7 +17,9 @@ type Device struct {
|
||||||
routingTable RoutingTable
|
routingTable RoutingTable
|
||||||
indices IndexTable
|
indices IndexTable
|
||||||
log *Logger
|
log *Logger
|
||||||
queueWorkOutbound chan *OutboundWorkQueueElement
|
queue struct {
|
||||||
|
encryption chan *QueueOutboundElement // parallel work queue
|
||||||
|
}
|
||||||
peers map[NoisePublicKey]*Peer
|
peers map[NoisePublicKey]*Peer
|
||||||
mac MacStateDevice
|
mac MacStateDevice
|
||||||
}
|
}
|
||||||
|
@ -41,7 +44,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) Init() {
|
func NewDevice(tun TUNDevice) *Device {
|
||||||
|
device := new(Device)
|
||||||
|
|
||||||
device.mutex.Lock()
|
device.mutex.Lock()
|
||||||
defer device.mutex.Unlock()
|
defer device.mutex.Unlock()
|
||||||
|
|
||||||
|
@ -49,6 +54,14 @@ func (device *Device) Init() {
|
||||||
device.peers = make(map[NoisePublicKey]*Peer)
|
device.peers = make(map[NoisePublicKey]*Peer)
|
||||||
device.indices.Init()
|
device.indices.Init()
|
||||||
device.routingTable.Reset()
|
device.routingTable.Reset()
|
||||||
|
|
||||||
|
// start workers
|
||||||
|
|
||||||
|
for i := 0; i < runtime.NumCPU(); i += 1 {
|
||||||
|
go device.RoutineEncryption()
|
||||||
|
}
|
||||||
|
go device.RoutineReadFromTUN(tun)
|
||||||
|
return device
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
|
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
|
||||||
|
|
172
src/handshake.go
Normal file
172
src/handshake.go
Normal file
|
@ -0,0 +1,172 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"net"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* Sends a keep-alive if no packets queued for peer
|
||||||
|
*
|
||||||
|
* Used by initiator of handshake and with active keep-alive
|
||||||
|
*/
|
||||||
|
func (peer *Peer) SendKeepAlive() bool {
|
||||||
|
if len(peer.queue.nonce) == 0 {
|
||||||
|
select {
|
||||||
|
case peer.queue.nonce <- []byte{}:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) RoutineHandshakeInitiator() {
|
||||||
|
var ongoing bool
|
||||||
|
var begun time.Time
|
||||||
|
var attempts uint
|
||||||
|
var timeout time.Timer
|
||||||
|
|
||||||
|
device := peer.device
|
||||||
|
work := new(QueueOutboundElement)
|
||||||
|
buffer := make([]byte, 0, 1024)
|
||||||
|
|
||||||
|
queueHandshakeInitiation := func() error {
|
||||||
|
work.mutex.Lock()
|
||||||
|
defer work.mutex.Unlock()
|
||||||
|
|
||||||
|
// create initiation
|
||||||
|
|
||||||
|
msg, err := device.CreateMessageInitiation(peer)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// create "work" element
|
||||||
|
|
||||||
|
writer := bytes.NewBuffer(buffer[:0])
|
||||||
|
binary.Write(writer, binary.LittleEndian, &msg)
|
||||||
|
work.packet = writer.Bytes()
|
||||||
|
peer.mac.AddMacs(work.packet)
|
||||||
|
peer.InsertOutbound(work)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-peer.signal.stopInitiator:
|
||||||
|
return
|
||||||
|
|
||||||
|
case <-peer.signal.newHandshake:
|
||||||
|
if ongoing {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// create handshake
|
||||||
|
|
||||||
|
err := queueHandshakeInitiation()
|
||||||
|
if err != nil {
|
||||||
|
device.log.Error.Println("Failed to create initiation message:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// log when we began
|
||||||
|
|
||||||
|
begun = time.Now()
|
||||||
|
ongoing = true
|
||||||
|
attempts = 0
|
||||||
|
timeout.Reset(RekeyTimeout)
|
||||||
|
|
||||||
|
case <-peer.timer.sendKeepalive.C:
|
||||||
|
|
||||||
|
// active keep-alives
|
||||||
|
|
||||||
|
peer.SendKeepAlive()
|
||||||
|
|
||||||
|
case <-peer.timer.handshakeTimeout.C:
|
||||||
|
|
||||||
|
// check if we can stop trying
|
||||||
|
|
||||||
|
if time.Now().Sub(begun) > MaxHandshakeAttempTime {
|
||||||
|
peer.signal.flushNonceQueue <- true
|
||||||
|
peer.timer.sendKeepalive.Stop()
|
||||||
|
ongoing = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// otherwise, try again (exponental backoff)
|
||||||
|
|
||||||
|
attempts += 1
|
||||||
|
err := queueHandshakeInitiation()
|
||||||
|
if err != nil {
|
||||||
|
device.log.Error.Println("Failed to create initiation message:", err)
|
||||||
|
}
|
||||||
|
peer.timer.handshakeTimeout.Reset((1 << attempts) * RekeyTimeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Handles packets related to handshake
|
||||||
|
*
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
func (device *Device) HandshakeWorker(queue chan struct {
|
||||||
|
msg []byte
|
||||||
|
msgType uint32
|
||||||
|
addr *net.UDPAddr
|
||||||
|
}) {
|
||||||
|
for {
|
||||||
|
elem := <-queue
|
||||||
|
|
||||||
|
switch elem.msgType {
|
||||||
|
case MessageInitiationType:
|
||||||
|
if len(elem.msg) != MessageInitiationSize {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// check for cookie
|
||||||
|
|
||||||
|
var msg MessageInitiation
|
||||||
|
|
||||||
|
binary.Read(nil, binary.LittleEndian, &msg)
|
||||||
|
|
||||||
|
case MessageResponseType:
|
||||||
|
if len(elem.msg) != MessageResponseSize {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// check for cookie
|
||||||
|
|
||||||
|
case MessageCookieReplyType:
|
||||||
|
|
||||||
|
case MessageTransportType:
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) KeepKeyFresh(peer *Peer) {
|
||||||
|
|
||||||
|
send := func() bool {
|
||||||
|
peer.keyPairs.mutex.RLock()
|
||||||
|
defer peer.keyPairs.mutex.RUnlock()
|
||||||
|
|
||||||
|
kp := peer.keyPairs.current
|
||||||
|
if kp == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce := atomic.LoadUint64(&kp.sendNonce)
|
||||||
|
if nonce > RekeyAfterMessage {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime
|
||||||
|
}()
|
||||||
|
|
||||||
|
if send {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
64
src/helper_test.go
Normal file
64
src/helper_test.go
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* Helpers for writing unit tests
|
||||||
|
*/
|
||||||
|
|
||||||
|
type DummyTUN struct {
|
||||||
|
name string
|
||||||
|
mtu uint
|
||||||
|
packets chan []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *DummyTUN) Name() string {
|
||||||
|
return tun.name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *DummyTUN) MTU() uint {
|
||||||
|
return tun.mtu
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *DummyTUN) Write(d []byte) (int, error) {
|
||||||
|
tun.packets <- d
|
||||||
|
return len(d), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *DummyTUN) Read(d []byte) (int, error) {
|
||||||
|
t := <-tun.packets
|
||||||
|
copy(d, t)
|
||||||
|
return len(t), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateDummyTUN(name string) (TUNDevice, error) {
|
||||||
|
var dummy DummyTUN
|
||||||
|
dummy.mtu = 1024
|
||||||
|
dummy.packets = make(chan []byte, 100)
|
||||||
|
return &dummy, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertNil(t *testing.T, err error) {
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertEqual(t *testing.T, a []byte, b []byte) {
|
||||||
|
if bytes.Compare(a, b) != 0 {
|
||||||
|
t.Fatal(a, "!=", b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func randDevice(t *testing.T) *Device {
|
||||||
|
sk, err := newPrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
tun, _ := CreateDummyTUN("dummy")
|
||||||
|
device := NewDevice(tun)
|
||||||
|
device.SetPrivateKey(sk)
|
||||||
|
return device
|
||||||
|
}
|
|
@ -8,6 +8,7 @@ const (
|
||||||
IPv4version = 4
|
IPv4version = 4
|
||||||
IPv4offsetSrc = 12
|
IPv4offsetSrc = 12
|
||||||
IPv4offsetDst = IPv4offsetSrc + net.IPv4len
|
IPv4offsetDst = IPv4offsetSrc + net.IPv4len
|
||||||
|
IPv4headerSize = 20
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
|
@ -8,8 +8,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMAC1(t *testing.T) {
|
func TestMAC1(t *testing.T) {
|
||||||
dev1 := newDevice(t)
|
dev1 := randDevice(t)
|
||||||
dev2 := newDevice(t)
|
dev2 := randDevice(t)
|
||||||
|
|
||||||
peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
|
peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
|
||||||
peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
|
peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
|
||||||
|
@ -34,12 +34,10 @@ func TestMACs(t *testing.T) {
|
||||||
msg []byte,
|
msg []byte,
|
||||||
receiver uint32,
|
receiver uint32,
|
||||||
) bool {
|
) bool {
|
||||||
var device1 Device
|
device1 := randDevice(t)
|
||||||
device1.Init()
|
|
||||||
device1.SetPrivateKey(sk1)
|
device1.SetPrivateKey(sk1)
|
||||||
|
|
||||||
var device2 Device
|
device2 := randDevice(t)
|
||||||
device2.Init()
|
|
||||||
device2.SetPrivateKey(sk2)
|
device2.SetPrivateKey(sk2)
|
||||||
|
|
||||||
peer1 := device2.NewPeer(device1.privateKey.publicKey())
|
peer1 := device2.NewPeer(device1.privateKey.publicKey())
|
||||||
|
|
53
src/main.go
53
src/main.go
|
@ -1,36 +1,30 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
fd, err := CreateTUN("test0")
|
|
||||||
fmt.Println(fd, err)
|
|
||||||
|
|
||||||
queue := make(chan []byte, 1000)
|
|
||||||
|
|
||||||
// var device Device
|
|
||||||
|
|
||||||
// go OutgoingRoutingWorker(&device, queue)
|
|
||||||
|
|
||||||
for {
|
|
||||||
tmp := make([]byte, 1<<16)
|
|
||||||
n, err := fd.Read(tmp)
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
queue <- tmp[:n]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* TODO: Fix logging
|
||||||
|
*/
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
// Open TUN device
|
||||||
|
|
||||||
|
// TODO: Fix capabilities
|
||||||
|
|
||||||
|
tun, err := CreateTUN("test0")
|
||||||
|
log.Println(tun, err)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
device := NewDevice(tun)
|
||||||
|
|
||||||
|
// Start configuration lister
|
||||||
|
|
||||||
l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock")
|
l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("listen error:", err)
|
log.Fatal("listen error:", err)
|
||||||
|
@ -41,12 +35,9 @@ func main() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("accept error:", err)
|
log.Fatal("accept error:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var dev Device
|
|
||||||
go func(conn net.Conn) {
|
go func(conn net.Conn) {
|
||||||
err := ipcListen(&dev, conn)
|
err := ipcListen(device, conn)
|
||||||
fmt.Println(err)
|
log.Println(err)
|
||||||
}(fd)
|
}(fd)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
|
@ -77,7 +77,7 @@ type MessageCookieReply struct {
|
||||||
|
|
||||||
type Handshake struct {
|
type Handshake struct {
|
||||||
state int
|
state int
|
||||||
mutex sync.Mutex
|
mutex sync.RWMutex
|
||||||
hash [blake2s.Size]byte // hash value
|
hash [blake2s.Size]byte // hash value
|
||||||
chainKey [blake2s.Size]byte // chain key
|
chainKey [blake2s.Size]byte // chain key
|
||||||
presharedKey NoiseSymmetricKey // psk
|
presharedKey NoiseSymmetricKey // psk
|
||||||
|
@ -205,19 +205,26 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
||||||
}
|
}
|
||||||
hash = mixHash(hash, msg.Static[:])
|
hash = mixHash(hash, msg.Static[:])
|
||||||
|
|
||||||
// find peer
|
// lookup peer
|
||||||
|
|
||||||
peer := device.LookupPeer(peerPK)
|
peer := device.LookupPeer(peerPK)
|
||||||
if peer == nil {
|
if peer == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
handshake := &peer.handshake
|
handshake := &peer.handshake
|
||||||
handshake.mutex.Lock()
|
|
||||||
defer handshake.mutex.Unlock()
|
// verify identity
|
||||||
|
|
||||||
|
var timestamp TAI64N
|
||||||
|
ok := func() bool {
|
||||||
|
|
||||||
|
// read lock handshake
|
||||||
|
|
||||||
|
handshake.mutex.RLock()
|
||||||
|
defer handshake.mutex.RUnlock()
|
||||||
|
|
||||||
// decrypt timestamp
|
// decrypt timestamp
|
||||||
|
|
||||||
var timestamp TAI64N
|
|
||||||
func() {
|
func() {
|
||||||
var key [chacha20poly1305.KeySize]byte
|
var key [chacha20poly1305.KeySize]byte
|
||||||
chainKey, key = KDF2(
|
chainKey, key = KDF2(
|
||||||
|
@ -228,26 +235,34 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
||||||
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
|
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
|
||||||
}()
|
}()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return false
|
||||||
}
|
}
|
||||||
hash = mixHash(hash, msg.Timestamp[:])
|
hash = mixHash(hash, msg.Timestamp[:])
|
||||||
|
|
||||||
|
// TODO: check for flood attack
|
||||||
|
|
||||||
// check for replay attack
|
// check for replay attack
|
||||||
|
|
||||||
if !timestamp.After(handshake.lastTimestamp) {
|
return timestamp.After(handshake.lastTimestamp)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: check for flood attack
|
|
||||||
|
|
||||||
// update handshake state
|
// update handshake state
|
||||||
|
|
||||||
|
handshake.mutex.Lock()
|
||||||
|
|
||||||
handshake.hash = hash
|
handshake.hash = hash
|
||||||
handshake.chainKey = chainKey
|
handshake.chainKey = chainKey
|
||||||
handshake.remoteIndex = msg.Sender
|
handshake.remoteIndex = msg.Sender
|
||||||
handshake.remoteEphemeral = msg.Ephemeral
|
handshake.remoteEphemeral = msg.Ephemeral
|
||||||
handshake.lastTimestamp = timestamp
|
handshake.lastTimestamp = timestamp
|
||||||
handshake.state = HandshakeInitiationConsumed
|
handshake.state = HandshakeInitiationConsumed
|
||||||
|
|
||||||
|
handshake.mutex.Unlock()
|
||||||
|
|
||||||
return peer
|
return peer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -320,16 +335,26 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
handshake.mutex.Lock()
|
var (
|
||||||
defer handshake.mutex.Unlock()
|
hash [blake2s.Size]byte
|
||||||
|
chainKey [blake2s.Size]byte
|
||||||
|
)
|
||||||
|
|
||||||
|
ok := func() bool {
|
||||||
|
|
||||||
|
// read lock handshake
|
||||||
|
|
||||||
|
handshake.mutex.RLock()
|
||||||
|
defer handshake.mutex.RUnlock()
|
||||||
|
|
||||||
if handshake.state != HandshakeInitiationCreated {
|
if handshake.state != HandshakeInitiationCreated {
|
||||||
return nil
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// finish 3-way DH
|
// finish 3-way DH
|
||||||
|
|
||||||
hash := mixHash(handshake.hash, msg.Ephemeral[:])
|
hash = mixHash(handshake.hash, msg.Ephemeral[:])
|
||||||
chainKey := handshake.chainKey
|
chainKey = handshake.chainKey
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
|
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
|
||||||
|
@ -350,17 +375,27 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||||
aead, _ := chacha20poly1305.New(key[:])
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
|
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return false
|
||||||
}
|
}
|
||||||
hash = mixHash(hash, msg.Empty[:])
|
hash = mixHash(hash, msg.Empty[:])
|
||||||
|
return true
|
||||||
|
}()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// update handshake state
|
// update handshake state
|
||||||
|
|
||||||
|
handshake.mutex.Lock()
|
||||||
|
|
||||||
handshake.hash = hash
|
handshake.hash = hash
|
||||||
handshake.chainKey = chainKey
|
handshake.chainKey = chainKey
|
||||||
handshake.remoteIndex = msg.Sender
|
handshake.remoteIndex = msg.Sender
|
||||||
handshake.state = HandshakeResponseConsumed
|
handshake.state = HandshakeResponseConsumed
|
||||||
|
|
||||||
|
handshake.mutex.Unlock()
|
||||||
|
|
||||||
return lookup.peer
|
return lookup.peer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,29 +6,6 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func assertNil(t *testing.T, err error) {
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertEqual(t *testing.T, a []byte, b []byte) {
|
|
||||||
if bytes.Compare(a, b) != 0 {
|
|
||||||
t.Fatal(a, "!=", b)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newDevice(t *testing.T) *Device {
|
|
||||||
var device Device
|
|
||||||
sk, err := newPrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
device.Init()
|
|
||||||
device.SetPrivateKey(sk)
|
|
||||||
return &device
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCurveWrappers(t *testing.T) {
|
func TestCurveWrappers(t *testing.T) {
|
||||||
sk1, err := newPrivateKey()
|
sk1, err := newPrivateKey()
|
||||||
assertNil(t, err)
|
assertNil(t, err)
|
||||||
|
@ -49,8 +26,8 @@ func TestCurveWrappers(t *testing.T) {
|
||||||
|
|
||||||
func TestNoiseHandshake(t *testing.T) {
|
func TestNoiseHandshake(t *testing.T) {
|
||||||
|
|
||||||
dev1 := newDevice(t)
|
dev1 := randDevice(t)
|
||||||
dev2 := newDevice(t)
|
dev2 := randDevice(t)
|
||||||
|
|
||||||
peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
|
peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
|
||||||
peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
|
peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
|
||||||
|
|
|
@ -3,18 +3,18 @@ package main
|
||||||
import (
|
import (
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
NoisePublicKeySize = 32
|
NoisePublicKeySize = 32
|
||||||
NoisePrivateKeySize = 32
|
NoisePrivateKeySize = 32
|
||||||
NoiseSymmetricKeySize = 32
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
NoisePublicKey [NoisePublicKeySize]byte
|
NoisePublicKey [NoisePublicKeySize]byte
|
||||||
NoisePrivateKey [NoisePrivateKeySize]byte
|
NoisePrivateKey [NoisePrivateKeySize]byte
|
||||||
NoiseSymmetricKey [NoiseSymmetricKeySize]byte
|
NoiseSymmetricKey [chacha20poly1305.KeySize]byte
|
||||||
NoiseNonce uint64 // padded to 12-bytes
|
NoiseNonce uint64 // padded to 12-bytes
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -30,6 +30,15 @@ func loadExactHex(dst []byte, src string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (key NoisePrivateKey) IsZero() bool {
|
||||||
|
for _, b := range key[:] {
|
||||||
|
if b != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (key *NoisePrivateKey) FromHex(src string) error {
|
func (key *NoisePrivateKey) FromHex(src string) error {
|
||||||
return loadExactHex(key[:], src)
|
return loadExactHex(key[:], src)
|
||||||
}
|
}
|
||||||
|
|
44
src/peer.go
44
src/peer.go
|
@ -7,9 +7,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const ()
|
||||||
OutboundQueueSize = 64
|
|
||||||
)
|
|
||||||
|
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
@ -18,9 +16,25 @@ type Peer struct {
|
||||||
keyPairs KeyPairs
|
keyPairs KeyPairs
|
||||||
handshake Handshake
|
handshake Handshake
|
||||||
device *Device
|
device *Device
|
||||||
queueInbound chan []byte
|
tx_bytes uint64
|
||||||
queueOutbound chan *OutboundWorkQueueElement
|
rx_bytes uint64
|
||||||
queueOutboundRouting chan []byte
|
time struct {
|
||||||
|
lastSend time.Time // last send message
|
||||||
|
}
|
||||||
|
signal struct {
|
||||||
|
newHandshake chan bool
|
||||||
|
flushNonceQueue chan bool // empty queued packets
|
||||||
|
stopSending chan bool // stop sending pipeline
|
||||||
|
stopInitiator chan bool // stop initiator timer
|
||||||
|
}
|
||||||
|
timer struct {
|
||||||
|
sendKeepalive time.Timer
|
||||||
|
handshakeTimeout time.Timer
|
||||||
|
}
|
||||||
|
queue struct {
|
||||||
|
nonce chan []byte // nonce / pre-handshake queue
|
||||||
|
outbound chan *QueueOutboundElement // sequential ordering of work
|
||||||
|
}
|
||||||
mac MacStatePeer
|
mac MacStatePeer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -33,7 +47,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
|
||||||
peer.device = device
|
peer.device = device
|
||||||
peer.keyPairs.Init()
|
peer.keyPairs.Init()
|
||||||
peer.mac.Init(pk)
|
peer.mac.Init(pk)
|
||||||
peer.queueOutbound = make(chan *OutboundWorkQueueElement, OutboundQueueSize)
|
peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
|
||||||
|
peer.queue.nonce = make(chan []byte, QueueOutboundSize)
|
||||||
|
|
||||||
// map public key
|
// map public key
|
||||||
|
|
||||||
|
@ -54,5 +69,20 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
|
||||||
handshake.mutex.Unlock()
|
handshake.mutex.Unlock()
|
||||||
peer.mutex.Unlock()
|
peer.mutex.Unlock()
|
||||||
|
|
||||||
|
// start workers
|
||||||
|
|
||||||
|
peer.signal.stopSending = make(chan bool, 1)
|
||||||
|
peer.signal.stopInitiator = make(chan bool, 1)
|
||||||
|
peer.signal.newHandshake = make(chan bool, 1)
|
||||||
|
peer.signal.flushNonceQueue = make(chan bool, 1)
|
||||||
|
|
||||||
|
go peer.RoutineNonce()
|
||||||
|
go peer.RoutineHandshakeInitiator()
|
||||||
|
|
||||||
return &peer
|
return &peer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) Close() {
|
||||||
|
peer.signal.stopSending <- true
|
||||||
|
peer.signal.stopInitiator <- true
|
||||||
|
}
|
||||||
|
|
|
@ -12,9 +12,20 @@ type RoutingTable struct {
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (table *RoutingTable) AllowedIPs(peer *Peer) []net.IPNet {
|
||||||
|
table.mutex.RLock()
|
||||||
|
defer table.mutex.RUnlock()
|
||||||
|
|
||||||
|
allowed := make([]net.IPNet, 10)
|
||||||
|
table.IPv4.AllowedIPs(peer, allowed)
|
||||||
|
table.IPv6.AllowedIPs(peer, allowed)
|
||||||
|
return allowed
|
||||||
|
}
|
||||||
|
|
||||||
func (table *RoutingTable) Reset() {
|
func (table *RoutingTable) Reset() {
|
||||||
table.mutex.Lock()
|
table.mutex.Lock()
|
||||||
defer table.mutex.Unlock()
|
defer table.mutex.Unlock()
|
||||||
|
|
||||||
table.IPv4 = nil
|
table.IPv4 = nil
|
||||||
table.IPv6 = nil
|
table.IPv6 = nil
|
||||||
}
|
}
|
||||||
|
@ -22,6 +33,7 @@ func (table *RoutingTable) Reset() {
|
||||||
func (table *RoutingTable) RemovePeer(peer *Peer) {
|
func (table *RoutingTable) RemovePeer(peer *Peer) {
|
||||||
table.mutex.Lock()
|
table.mutex.Lock()
|
||||||
defer table.mutex.Unlock()
|
defer table.mutex.Unlock()
|
||||||
|
|
||||||
table.IPv4 = table.IPv4.RemovePeer(peer)
|
table.IPv4 = table.IPv4.RemovePeer(peer)
|
||||||
table.IPv6 = table.IPv6.RemovePeer(peer)
|
table.IPv6 = table.IPv6.RemovePeer(peer)
|
||||||
}
|
}
|
||||||
|
|
177
src/send.go
177
src/send.go
|
@ -5,30 +5,78 @@ import (
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
/* Handles outbound flow
|
/* Handles outbound flow
|
||||||
*
|
*
|
||||||
* 1. TUN queue
|
* 1. TUN queue
|
||||||
* 2. Routing
|
* 2. Routing (sequential)
|
||||||
* 3. Per peer queuing
|
* 3. Nonce assignment (sequential)
|
||||||
* 4. (work queuing)
|
* 4. Encryption (parallel)
|
||||||
|
* 5. Transmission (sequential)
|
||||||
*
|
*
|
||||||
|
* The order of packets (per peer) is maintained.
|
||||||
|
* The functions in this file occure (roughly) in the order packets are processed.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
type OutboundWorkQueueElement struct {
|
/* A work unit
|
||||||
wg sync.WaitGroup
|
*
|
||||||
|
* The sequential consumers will attempt to take the lock,
|
||||||
|
* workers release lock when they have completed work on the packet.
|
||||||
|
*/
|
||||||
|
type QueueOutboundElement struct {
|
||||||
|
mutex sync.Mutex
|
||||||
packet []byte
|
packet []byte
|
||||||
nonce uint64
|
nonce uint64
|
||||||
keyPair *KeyPair
|
keyPair *KeyPair
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) HandshakeWorker(handshakeQueue []byte) {
|
func (peer *Peer) FlushNonceQueue() {
|
||||||
|
elems := len(peer.queue.nonce)
|
||||||
|
for i := 0; i < elems; i += 1 {
|
||||||
|
select {
|
||||||
|
case <-peer.queue.nonce:
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) SendPacket(packet []byte) {
|
func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case peer.queue.outbound <- elem:
|
||||||
|
default:
|
||||||
|
select {
|
||||||
|
case <-peer.queue.outbound:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Reads packets from the TUN and inserts
|
||||||
|
* into nonce queue for peer
|
||||||
|
*
|
||||||
|
* Obs. Single instance per TUN device
|
||||||
|
*/
|
||||||
|
func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
|
||||||
|
for {
|
||||||
|
// read packet
|
||||||
|
|
||||||
|
packet := make([]byte, 1<<16) // TODO: Fix & avoid dynamic allocation
|
||||||
|
size, err := tun.Read(packet)
|
||||||
|
if err != nil {
|
||||||
|
device.log.Error.Println("Failed to read packet from TUN device:", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
packet = packet[:size]
|
||||||
|
if len(packet) < IPv4headerSize {
|
||||||
|
device.log.Error.Println("Packet too short, length:", len(packet))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
device.log.Debug.Println("New packet on TUN:", packet) // TODO: Slow debugging, remove.
|
||||||
|
|
||||||
// lookup peer
|
// lookup peer
|
||||||
|
|
||||||
|
@ -43,69 +91,73 @@ func (device *Device) SendPacket(packet []byte) {
|
||||||
peer = device.routingTable.LookupIPv6(dst)
|
peer = device.routingTable.LookupIPv6(dst)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
device.log.Debug.Println("receieved packet with unknown IP version")
|
device.log.Debug.Println("Receieved packet with unknown IP version")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer == nil {
|
if peer == nil {
|
||||||
|
device.log.Debug.Println("No peer configured for IP")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// insert into peer queue
|
// insert into nonce/pre-handshake queue
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case peer.queueOutboundRouting <- packet:
|
case peer.queue.nonce <- packet:
|
||||||
default:
|
default:
|
||||||
select {
|
select {
|
||||||
case <-peer.queueOutboundRouting:
|
case <-peer.queue.nonce:
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Go routine
|
/* Queues packets when there is no handshake.
|
||||||
|
* Then assigns nonces to packets sequentially
|
||||||
|
* and creates "work" structs for workers
|
||||||
*
|
*
|
||||||
|
* TODO: Avoid dynamic allocation of work queue elements
|
||||||
*
|
*
|
||||||
* 1. waits for handshake.
|
* Obs. A single instance per peer
|
||||||
* 2. assigns key pair & nonce
|
|
||||||
* 3. inserts to working queue
|
|
||||||
*
|
|
||||||
* TODO: avoid dynamic allocation of work queue elements
|
|
||||||
*/
|
*/
|
||||||
func (peer *Peer) RoutineOutboundNonceWorker() {
|
func (peer *Peer) RoutineNonce() {
|
||||||
var packet []byte
|
var packet []byte
|
||||||
var keyPair *KeyPair
|
var keyPair *KeyPair
|
||||||
var flushTimer time.Timer
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
|
||||||
// wait for packet
|
// wait for packet
|
||||||
|
|
||||||
if packet == nil {
|
if packet == nil {
|
||||||
packet = <-peer.queueOutboundRouting
|
select {
|
||||||
|
case packet = <-peer.queue.nonce:
|
||||||
|
case <-peer.signal.stopSending:
|
||||||
|
close(peer.queue.outbound)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait for key pair
|
// wait for key pair
|
||||||
|
|
||||||
for keyPair == nil {
|
for keyPair == nil {
|
||||||
flushTimer.Reset(time.Second * 10)
|
peer.signal.newHandshake <- true
|
||||||
// TODO: Handshake or NOP
|
|
||||||
select {
|
select {
|
||||||
case <-peer.keyPairs.newKeyPair:
|
case <-peer.keyPairs.newKeyPair:
|
||||||
keyPair = peer.keyPairs.Current()
|
keyPair = peer.keyPairs.Current()
|
||||||
continue
|
continue
|
||||||
case <-flushTimer.C:
|
case <-peer.signal.flushNonceQueue:
|
||||||
size := len(peer.queueOutboundRouting)
|
peer.FlushNonceQueue()
|
||||||
for i := 0; i < size; i += 1 {
|
|
||||||
<-peer.queueOutboundRouting
|
|
||||||
}
|
|
||||||
packet = nil
|
packet = nil
|
||||||
|
continue
|
||||||
|
case <-peer.signal.stopSending:
|
||||||
|
close(peer.queue.outbound)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// process current packet
|
// process current packet
|
||||||
|
@ -114,14 +166,13 @@ func (peer *Peer) RoutineOutboundNonceWorker() {
|
||||||
|
|
||||||
// create work element
|
// create work element
|
||||||
|
|
||||||
work := new(OutboundWorkQueueElement)
|
work := new(QueueOutboundElement) // TODO: profile, maybe use pool
|
||||||
work.wg.Add(1)
|
|
||||||
work.keyPair = keyPair
|
work.keyPair = keyPair
|
||||||
work.packet = packet
|
work.packet = packet
|
||||||
work.nonce = keyPair.sendNonce
|
work.nonce = keyPair.sendNonce
|
||||||
|
work.mutex.Lock()
|
||||||
|
|
||||||
packet = nil
|
packet = nil
|
||||||
peer.queueOutbound <- work
|
|
||||||
keyPair.sendNonce += 1
|
keyPair.sendNonce += 1
|
||||||
|
|
||||||
// drop packets until there is space
|
// drop packets until there is space
|
||||||
|
@ -129,46 +180,36 @@ func (peer *Peer) RoutineOutboundNonceWorker() {
|
||||||
func() {
|
func() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case peer.device.queueWorkOutbound <- work:
|
case peer.device.queue.encryption <- work:
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
drop := <-peer.device.queueWorkOutbound
|
drop := <-peer.device.queue.encryption
|
||||||
drop.packet = nil
|
drop.packet = nil
|
||||||
drop.wg.Done()
|
drop.mutex.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
peer.queue.outbound <- work
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Go routine
|
/* Encrypts the elements in the queue
|
||||||
*
|
* and marks them for sequential consumption (by releasing the mutex)
|
||||||
* sequentially reads packets from queue and sends to endpoint
|
|
||||||
*
|
*
|
||||||
|
* Obs. One instance per core
|
||||||
*/
|
*/
|
||||||
func (peer *Peer) RoutineSequential() {
|
func (device *Device) RoutineEncryption() {
|
||||||
for work := range peer.queueOutbound {
|
|
||||||
work.wg.Wait()
|
|
||||||
if work.packet == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if peer.endpoint == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
peer.device.conn.WriteToUDP(work.packet, peer.endpoint)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) RoutineEncryptionWorker() {
|
|
||||||
var nonce [chacha20poly1305.NonceSize]byte
|
var nonce [chacha20poly1305.NonceSize]byte
|
||||||
for work := range device.queueWorkOutbound {
|
for work := range device.queue.encryption {
|
||||||
|
|
||||||
// pad packet
|
// pad packet
|
||||||
|
|
||||||
padding := device.mtu - len(work.packet)
|
padding := device.mtu - len(work.packet)
|
||||||
if padding < 0 {
|
if padding < 0 {
|
||||||
|
// drop
|
||||||
work.packet = nil
|
work.packet = nil
|
||||||
work.wg.Done()
|
work.mutex.Unlock()
|
||||||
}
|
}
|
||||||
for n := 0; n < padding; n += 1 {
|
for n := 0; n < padding; n += 1 {
|
||||||
work.packet = append(work.packet, 0)
|
work.packet = append(work.packet, 0)
|
||||||
|
@ -183,6 +224,30 @@ func (device *Device) RoutineEncryptionWorker() {
|
||||||
work.packet,
|
work.packet,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
work.wg.Done()
|
work.mutex.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Sequentially reads packets from queue and sends to endpoint
|
||||||
|
*
|
||||||
|
* Obs. Single instance per peer.
|
||||||
|
* The routine terminates then the outbound queue is closed.
|
||||||
|
*/
|
||||||
|
func (peer *Peer) RoutineSequential() {
|
||||||
|
for work := range peer.queue.outbound {
|
||||||
|
work.mutex.Lock()
|
||||||
|
func() {
|
||||||
|
peer.mutex.RLock()
|
||||||
|
defer peer.mutex.RUnlock()
|
||||||
|
if work.packet == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if peer.endpoint == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
peer.device.conn.WriteToUDP(work.packet, peer.endpoint)
|
||||||
|
peer.timer.sendKeepalive.Reset(peer.persistentKeepaliveInterval)
|
||||||
|
}()
|
||||||
|
work.mutex.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
31
src/trie.go
31
src/trie.go
|
@ -1,15 +1,20 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
/* Binary trie
|
/* Binary trie
|
||||||
|
*
|
||||||
|
* The net.IPs used here are not formatted the
|
||||||
|
* same way as those created by the "net" functions.
|
||||||
|
* Here the IPs are slices of either 4 or 16 byte (not always 16)
|
||||||
*
|
*
|
||||||
* Syncronization done seperatly
|
* Syncronization done seperatly
|
||||||
* See: routing.go
|
* See: routing.go
|
||||||
*
|
*
|
||||||
* Todo: Better commenting
|
* TODO: Better commenting
|
||||||
*/
|
*/
|
||||||
|
|
||||||
type Trie struct {
|
type Trie struct {
|
||||||
|
@ -24,7 +29,7 @@ type Trie struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Finds length of matching prefix
|
/* Finds length of matching prefix
|
||||||
* Maybe there is a faster way
|
* TODO: Make faster
|
||||||
*
|
*
|
||||||
* Assumption: len(ip1) == len(ip2)
|
* Assumption: len(ip1) == len(ip2)
|
||||||
*/
|
*/
|
||||||
|
@ -189,3 +194,25 @@ func (node *Trie) Count() uint {
|
||||||
r := node.child[1].Count()
|
r := node.child[1].Count()
|
||||||
return l + r
|
return l + r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) {
|
||||||
|
if node.peer == p {
|
||||||
|
var mask net.IPNet
|
||||||
|
mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8)
|
||||||
|
if len(node.bits) == net.IPv4len {
|
||||||
|
mask.IP = net.IPv4(
|
||||||
|
node.bits[0],
|
||||||
|
node.bits[1],
|
||||||
|
node.bits[2],
|
||||||
|
node.bits[3],
|
||||||
|
)
|
||||||
|
} else if len(node.bits) == net.IPv6len {
|
||||||
|
mask.IP = node.bits
|
||||||
|
} else {
|
||||||
|
panic(errors.New("bug: unexpected address length"))
|
||||||
|
}
|
||||||
|
results = append(results, mask)
|
||||||
|
}
|
||||||
|
node.child[0].AllowedIPs(p, results)
|
||||||
|
node.child[1].AllowedIPs(p, results)
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
type TUN interface {
|
type TUNDevice interface {
|
||||||
Read([]byte) (int, error)
|
Read([]byte) (int, error)
|
||||||
Write([]byte) (int, error)
|
Write([]byte) (int, error)
|
||||||
Name() string
|
Name() string
|
||||||
|
|
|
@ -9,9 +9,7 @@ import (
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
/* Platform dependent functions for interacting with
|
/* Implementation of the TUN device interface for linux
|
||||||
* TUN devices on linux systems
|
|
||||||
*
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
const CloneDevicePath = "/dev/net/tun"
|
const CloneDevicePath = "/dev/net/tun"
|
||||||
|
@ -45,7 +43,7 @@ func (tun *NativeTun) Read(d []byte) (int, error) {
|
||||||
return tun.fd.Read(d)
|
return tun.fd.Read(d)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateTUN(name string) (TUN, error) {
|
func CreateTUN(name string) (TUNDevice, error) {
|
||||||
// Open clone device
|
// Open clone device
|
||||||
fd, err := os.OpenFile(CloneDevicePath, os.O_RDWR, 0)
|
fd, err := os.OpenFile(CloneDevicePath, os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -53,7 +51,7 @@ func CreateTUN(name string) (TUN, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare ifreq struct
|
// Prepare ifreq struct
|
||||||
var ifr [18]byte
|
var ifr [128]byte
|
||||||
var flags uint16 = IFF_TUN | IFF_NO_PI
|
var flags uint16 = IFF_TUN | IFF_NO_PI
|
||||||
nameBytes := []byte(name)
|
nameBytes := []byte(name)
|
||||||
if len(nameBytes) >= IFNAMSIZ {
|
if len(nameBytes) >= IFNAMSIZ {
|
||||||
|
|
Loading…
Reference in a new issue