Completed initial version of outbound flow

This commit is contained in:
Mathias Hall-Andersen 2017-06-30 14:41:08 +02:00
parent 7e185db141
commit ba3e486667
17 changed files with 491 additions and 287 deletions

View file

@ -8,7 +8,6 @@ import (
"net" "net"
"strconv" "strconv"
"strings" "strings"
"time"
) )
// #include <errno.h> // #include <errno.h>
@ -51,9 +50,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
send("private_key=" + device.privateKey.ToHex()) send("private_key=" + device.privateKey.ToHex())
} }
if device.address != nil { send(fmt.Sprintf("listen_port=%d", device.net.addr.Port))
send(fmt.Sprintf("listen_port=%d", device.address.Port))
}
for _, peer := range device.peers { for _, peer := range device.peers {
func() { func() {
@ -106,7 +103,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
} }
key := parts[0] key := parts[0]
value := parts[1] value := parts[1]
logger.Println("Key-value pair: (", key, ",", value, ")") // TODO: Remove, leaks private key to log
switch key { switch key {
@ -118,13 +114,13 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
device.privateKey = NoisePrivateKey{} device.privateKey = NoisePrivateKey{}
device.mutex.Unlock() device.mutex.Unlock()
} else { } else {
device.mutex.Lock() var sk NoisePrivateKey
err := device.privateKey.FromHex(value) err := sk.FromHex(value)
device.mutex.Unlock()
if err != nil { if err != nil {
logger.Println("Failed to set private_key:", err) logger.Println("Failed to set private_key:", err)
return &IPCError{Code: ipcErrorInvalidValue} return &IPCError{Code: ipcErrorInvalidValue}
} }
device.SetPrivateKey(sk)
} }
case "listen_port": case "listen_port":
@ -134,12 +130,10 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logger.Println("Failed to set listen_port:", err) logger.Println("Failed to set listen_port:", err)
return &IPCError{Code: ipcErrorInvalidValue} return &IPCError{Code: ipcErrorInvalidValue}
} }
device.mutex.Lock() device.net.mutex.Lock()
if device.address == nil { device.net.addr.Port = port
device.address = &net.UDPAddr{} device.net.conn, err = net.ListenUDP("udp", device.net.addr)
} device.net.mutex.Unlock()
device.address.Port = port
device.mutex.Unlock()
case "fwmark": case "fwmark":
logger.Println("FWMark not handled yet") logger.Println("FWMark not handled yet")
@ -200,13 +194,13 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
} }
case "endpoint": case "endpoint":
ip := net.ParseIP(value) addr, err := net.ResolveUDPAddr("udp", value)
if ip == nil { if err != nil {
logger.Println("Failed to set endpoint:", value) logger.Println("Failed to set endpoint:", value)
return &IPCError{Code: ipcErrorInvalidValue} return &IPCError{Code: ipcErrorInvalidValue}
} }
peer.mutex.Lock() peer.mutex.Lock()
// peer.endpoint = ip FIX peer.endpoint = addr
peer.mutex.Unlock() peer.mutex.Unlock()
case "persistent_keepalive_interval": case "persistent_keepalive_interval":
@ -216,7 +210,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorInvalidValue} return &IPCError{Code: ipcErrorInvalidValue}
} }
peer.mutex.Lock() peer.mutex.Lock()
peer.persistentKeepaliveInterval = time.Duration(secs) * time.Second peer.persistentKeepaliveInterval = uint64(secs)
peer.mutex.Unlock() peer.mutex.Unlock()
case "replace_allowed_ips": case "replace_allowed_ips":

View file

@ -5,15 +5,15 @@ import (
) )
const ( const (
RekeyAfterMessage = (1 << 64) - (1 << 16) - 1 RekeyAfterMessages = (1 << 64) - (1 << 16) - 1
RekeyAfterTime = time.Second * 120 RekeyAfterTime = time.Second * 120
RekeyAttemptTime = time.Second * 90 RekeyAttemptTime = time.Second * 90
RekeyTimeout = time.Second * 5 // TODO: Exponential backoff RekeyTimeout = time.Second * 5 // TODO: Exponential backoff
RejectAfterTime = time.Second * 180 RejectAfterTime = time.Second * 180
RejectAfterMessage = (1 << 64) - (1 << 4) - 1 RejectAfterMessages = (1 << 64) - (1 << 4) - 1
KeepaliveTimeout = time.Second * 10 KeepaliveTimeout = time.Second * 10
CookieRefreshTime = time.Second * 2 CookieRefreshTime = time.Second * 2
MaxHandshakeAttempTime = time.Second * 90 MaxHandshakeAttemptTime = time.Second * 90
) )
const ( const (

View file

@ -8,15 +8,20 @@ import (
type Device struct { type Device struct {
mtu int mtu int
log *Logger // collection of loggers for levels
idCounter uint // for assigning debug ids to peers
fwMark uint32 fwMark uint32
address *net.UDPAddr // UDP source address net struct {
// seperate for performance reasons
mutex sync.RWMutex
addr *net.UDPAddr // UDP source address
conn *net.UDPConn // UDP "connection" conn *net.UDPConn // UDP "connection"
}
mutex sync.RWMutex mutex sync.RWMutex
privateKey NoisePrivateKey privateKey NoisePrivateKey
publicKey NoisePublicKey publicKey NoisePublicKey
routingTable RoutingTable routingTable RoutingTable
indices IndexTable indices IndexTable
log *Logger
queue struct { queue struct {
encryption chan *QueueOutboundElement // parallel work queue encryption chan *QueueOutboundElement // parallel work queue
} }
@ -44,17 +49,29 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
} }
} }
func NewDevice(tun TUNDevice) *Device { func NewDevice(tun TUNDevice, logLevel int) *Device {
device := new(Device) device := new(Device)
device.mutex.Lock() device.mutex.Lock()
defer device.mutex.Unlock() defer device.mutex.Unlock()
device.log = NewLogger() device.log = NewLogger(logLevel)
device.peers = make(map[NoisePublicKey]*Peer) device.peers = make(map[NoisePublicKey]*Peer)
device.indices.Init() device.indices.Init()
device.routingTable.Reset() device.routingTable.Reset()
// listen
device.net.mutex.Lock()
device.net.conn, _ = net.ListenUDP("udp", device.net.addr)
addr := device.net.conn.LocalAddr()
device.net.addr, _ = net.ResolveUDPAddr(addr.Network(), addr.String())
device.net.mutex.Unlock()
// create queues
device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
// start workers // start workers
for i := 0; i < runtime.NumCPU(); i += 1 { for i := 0; i < runtime.NumCPU(); i += 1 {
@ -92,5 +109,11 @@ func (device *Device) RemoveAllPeers() {
peer.mutex.Lock() peer.mutex.Lock()
delete(device.peers, key) delete(device.peers, key)
peer.Close() peer.Close()
peer.mutex.Unlock()
} }
} }
func (device *Device) Close() {
device.RemoveAllPeers()
close(device.queue.encryption)
}

View file

@ -24,17 +24,85 @@ func (peer *Peer) SendKeepAlive() bool {
return true return true
} }
func StoppedTimer() *time.Timer {
timer := time.NewTimer(time.Hour)
if !timer.Stop() {
<-timer.C
}
return timer
}
/* Called when a new authenticated message has been send
*
* TODO: This might be done in a faster way
*/
func (peer *Peer) KeepKeyFreshSending() {
send := func() bool {
peer.keyPairs.mutex.RLock()
defer peer.keyPairs.mutex.RUnlock()
kp := peer.keyPairs.current
if kp == nil {
return false
}
if !kp.isInitiator {
return false
}
nonce := atomic.LoadUint64(&kp.sendNonce)
if nonce > RekeyAfterMessages {
return true
}
return time.Now().Sub(kp.created) > RekeyAfterTime
}()
if send {
sendSignal(peer.signal.handshakeBegin)
}
}
/* This is the state machine for handshake initiation
*
* Associated with this routine is the signal "handshakeBegin"
* The routine will read from the "handshakeBegin" channel
* at most every RekeyTimeout or with exponential backoff
*
* Implements exponential backoff for retries
*/
func (peer *Peer) RoutineHandshakeInitiator() { func (peer *Peer) RoutineHandshakeInitiator() {
var ongoing bool
var begun time.Time
var attempts uint
var timeout time.Timer
device := peer.device
work := new(QueueOutboundElement) work := new(QueueOutboundElement)
buffer := make([]byte, 0, 1024) device := peer.device
buffer := make([]byte, 1024)
logger := device.log.Debug
timeout := time.NewTimer(time.Hour)
queueHandshakeInitiation := func() error { logger.Println("Routine, handshake initator, started for peer", peer.id)
func() {
for {
var attempts uint
var deadline time.Time
select {
case <-peer.signal.handshakeBegin:
case <-peer.signal.stop:
return
}
HandshakeLoop:
for run := true; run; {
// clear completed signal
select {
case <-peer.signal.handshakeCompleted:
case <-peer.signal.stop:
return
default:
}
// queue handshake
err := func() error {
work.mutex.Lock() work.mutex.Lock()
defer work.mutex.Unlock() defer work.mutex.Unlock()
@ -45,70 +113,74 @@ func (peer *Peer) RoutineHandshakeInitiator() {
return err return err
} }
// create "work" element // marshal
writer := bytes.NewBuffer(buffer[:0]) writer := bytes.NewBuffer(buffer[:0])
binary.Write(writer, binary.LittleEndian, &msg) binary.Write(writer, binary.LittleEndian, msg)
work.packet = writer.Bytes() work.packet = writer.Bytes()
peer.mac.AddMacs(work.packet) peer.mac.AddMacs(work.packet)
peer.InsertOutbound(work) peer.InsertOutbound(work)
return nil return nil
}()
if err != nil {
device.log.Error.Println("Failed to create initiation message:", err)
break
}
if attempts == 0 {
deadline = time.Now().Add(MaxHandshakeAttemptTime)
} }
for { // set timeout
if !timeout.Stop() {
select { select {
case <-peer.signal.stopInitiator: case <-timeout.C:
default:
}
}
timeout.Reset((1 << attempts) * RekeyTimeout)
attempts += 1
device.log.Debug.Println("Handshake initiation attempt", attempts, "queued for peer", peer.id)
time.Sleep(RekeyTimeout)
// wait for handshake or timeout
select {
case <-peer.signal.stop:
return return
case <-peer.signal.newHandshake: case <-peer.signal.handshakeCompleted:
if ongoing { break HandshakeLoop
continue
default:
select {
case <-peer.signal.stop:
return
case <-peer.signal.handshakeCompleted:
break HandshakeLoop
case <-timeout.C:
nextTimeout := (1 << attempts) * RekeyTimeout
if deadline.Before(time.Now().Add(nextTimeout)) {
// we do not have time for another attempt
peer.signal.flushNonceQueue <- struct{}{}
if !peer.timer.sendKeepalive.Stop() {
<-peer.timer.sendKeepalive.C
} }
break HandshakeLoop
// create handshake
err := queueHandshakeInitiation()
if err != nil {
device.log.Error.Println("Failed to create initiation message:", err)
}
// log when we began
begun = time.Now()
ongoing = true
attempts = 0
timeout.Reset(RekeyTimeout)
case <-peer.timer.sendKeepalive.C:
// active keep-alives
peer.SendKeepAlive()
case <-peer.timer.handshakeTimeout.C:
// check if we can stop trying
if time.Now().Sub(begun) > MaxHandshakeAttempTime {
peer.signal.flushNonceQueue <- true
peer.timer.sendKeepalive.Stop()
ongoing = false
continue
}
// otherwise, try again (exponental backoff)
attempts += 1
err := queueHandshakeInitiation()
if err != nil {
device.log.Error.Println("Failed to create initiation message:", err)
}
peer.timer.handshakeTimeout.Reset((1 << attempts) * RekeyTimeout)
} }
} }
} }
}
}
}()
/* Handles packets related to handshake logger.Println("Routine, handshake initator, stopped for peer", peer.id)
}
/* Handles incomming packets related to handshake
* *
* *
*/ */
@ -140,33 +212,12 @@ func (device *Device) HandshakeWorker(queue chan struct {
// check for cookie // check for cookie
case MessageCookieReplyType: case MessageCookieReplyType:
if len(elem.msg) != MessageCookieReplySize {
case MessageTransportType: continue
} }
default:
device.log.Error.Println("Invalid message type in handshake queue")
} }
} }
func (device *Device) KeepKeyFresh(peer *Peer) {
send := func() bool {
peer.keyPairs.mutex.RLock()
defer peer.keyPairs.mutex.RUnlock()
kp := peer.keyPairs.current
if kp == nil {
return false
}
nonce := atomic.LoadUint64(&kp.sendNonce)
if nonce > RekeyAfterMessage {
return true
}
return kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime
}()
if send {
}
} }

View file

@ -35,7 +35,7 @@ func (tun *DummyTUN) Read(d []byte) (int, error) {
func CreateDummyTUN(name string) (TUNDevice, error) { func CreateDummyTUN(name string) (TUNDevice, error) {
var dummy DummyTUN var dummy DummyTUN
dummy.mtu = 1024 dummy.mtu = 0
dummy.packets = make(chan []byte, 100) dummy.packets = make(chan []byte, 100)
return &dummy, nil return &dummy, nil
} }
@ -58,7 +58,7 @@ func randDevice(t *testing.T) *Device {
t.Fatal(err) t.Fatal(err)
} }
tun, _ := CreateDummyTUN("dummy") tun, _ := CreateDummyTUN("dummy")
device := NewDevice(tun) device := NewDevice(tun, LogLevelError)
device.SetPrivateKey(sk) device.SetPrivateKey(sk)
return device return device
} }

View file

@ -41,7 +41,7 @@ func (table *IndexTable) Init() {
table.mutex.Unlock() table.mutex.Unlock()
} }
func (table *IndexTable) ClearIndex(index uint32) { func (table *IndexTable) Delete(index uint32) {
if index == 0 { if index == 0 {
return return
} }

View file

@ -13,6 +13,7 @@ type KeyPair struct {
sendNonce uint64 sendNonce uint64
isInitiator bool isInitiator bool
created time.Time created time.Time
id uint32
} }
type KeyPairs struct { type KeyPairs struct {
@ -20,14 +21,20 @@ type KeyPairs struct {
current *KeyPair current *KeyPair
previous *KeyPair previous *KeyPair
next *KeyPair // not yet "confirmed by transport" next *KeyPair // not yet "confirmed by transport"
newKeyPair chan bool // signals when "current" has been updated
} }
func (kp *KeyPairs) Init() { /* Called during recieving to confirm the handshake
* was completed correctly
*/
func (kp *KeyPairs) Used(key *KeyPair) {
if key == kp.next {
kp.mutex.Lock() kp.mutex.Lock()
kp.newKeyPair = make(chan bool, 5) kp.previous = kp.current
kp.current = key
kp.next = nil
kp.mutex.Unlock() kp.mutex.Unlock()
} }
}
func (kp *KeyPairs) Current() *KeyPair { func (kp *KeyPairs) Current() *KeyPair {
kp.mutex.RLock() kp.mutex.RLock()

View file

@ -1,6 +1,8 @@
package main package main
import ( import (
"io"
"io/ioutil"
"log" "log"
"os" "os"
) )
@ -17,17 +19,30 @@ type Logger struct {
Error *log.Logger Error *log.Logger
} }
func NewLogger() *Logger { func NewLogger(level int) *Logger {
output := os.Stdout
logger := new(Logger) logger := new(Logger)
logger.Debug = log.New(os.Stdout,
logErr, logInfo, logDebug := func() (io.Writer, io.Writer, io.Writer) {
if level >= LogLevelDebug {
return output, output, output
}
if level >= LogLevelInfo {
return output, output, ioutil.Discard
}
return output, ioutil.Discard, ioutil.Discard
}()
logger.Debug = log.New(logDebug,
"DEBUG: ", "DEBUG: ",
log.Ldate|log.Ltime|log.Lshortfile, log.Ldate|log.Ltime|log.Lshortfile,
) )
logger.Info = log.New(os.Stdout,
logger.Info = log.New(logInfo,
"INFO: ", "INFO: ",
log.Ldate|log.Ltime|log.Lshortfile, log.Ldate|log.Ltime|log.Lshortfile,
) )
logger.Error = log.New(os.Stdout, logger.Error = log.New(logErr,
"ERROR: ", "ERROR: ",
log.Ldate|log.Ltime|log.Lshortfile, log.Ldate|log.Ltime|log.Lshortfile,
) )

View file

@ -11,6 +11,9 @@ func TestMAC1(t *testing.T) {
dev1 := randDevice(t) dev1 := randDevice(t)
dev2 := randDevice(t) dev2 := randDevice(t)
defer dev1.Close()
defer dev2.Close()
peer1 := dev2.NewPeer(dev1.privateKey.publicKey()) peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
peer2 := dev1.NewPeer(dev2.privateKey.publicKey()) peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
@ -40,6 +43,9 @@ func TestMACs(t *testing.T) {
device2 := randDevice(t) device2 := randDevice(t)
device2.SetPrivateKey(sk2) device2.SetPrivateKey(sk2)
defer device1.Close()
defer device2.Close()
peer1 := device2.NewPeer(device1.privateKey.publicKey()) peer1 := device2.NewPeer(device1.privateKey.publicKey())
peer2 := device1.NewPeer(device2.privateKey.publicKey()) peer2 := device1.NewPeer(device2.privateKey.publicKey())

View file

@ -28,7 +28,7 @@ func main() {
return return
} }
device := NewDevice(tun) device := NewDevice(tun, LogLevelDebug)
// Start configuration lister // Start configuration lister

View file

@ -6,3 +6,10 @@ func min(a uint, b uint) uint {
} }
return a return a
} }
func sendSignal(c chan struct{}) {
select {
case c <- struct{}{}:
default:
}
}

View file

@ -33,6 +33,7 @@ func KDF2(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byt
HMAC(&prk, key, input) HMAC(&prk, key, input)
HMAC(&t0, prk[:], []byte{0x1}) HMAC(&t0, prk[:], []byte{0x1})
HMAC(&t1, prk[:], append(t0[:], 0x2)) HMAC(&t1, prk[:], append(t0[:], 0x2))
prk = [blake2s.Size]byte{}
return return
} }
@ -42,6 +43,7 @@ func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byt
HMAC(&t0, prk[:], []byte{0x1}) HMAC(&t0, prk[:], []byte{0x1})
HMAC(&t1, prk[:], append(t0[:], 0x2)) HMAC(&t1, prk[:], append(t0[:], 0x2))
HMAC(&t2, prk[:], append(t1[:], 0x3)) HMAC(&t2, prk[:], append(t1[:], 0x3))
prk = [blake2s.Size]byte{}
return return
} }

View file

@ -33,6 +33,7 @@ const (
const ( const (
MessageInitiationSize = 148 MessageInitiationSize = 148
MessageResponseSize = 92 MessageResponseSize = 92
MessageCookieReplySize = 64
) )
/* Type is an 8-bit field, followed by 3 nul bytes, /* Type is an 8-bit field, followed by 3 nul bytes,
@ -91,16 +92,11 @@ type Handshake struct {
} }
var ( var (
InitalChainKey [blake2s.Size]byte InitialChainKey [blake2s.Size]byte
InitalHash [blake2s.Size]byte InitialHash [blake2s.Size]byte
ZeroNonce [chacha20poly1305.NonceSize]byte ZeroNonce [chacha20poly1305.NonceSize]byte
) )
func init() {
InitalChainKey = blake2s.Sum256([]byte(NoiseConstruction))
InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
}
func mixKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte { func mixKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
return KDF1(c[:], data) return KDF1(c[:], data)
} }
@ -117,6 +113,13 @@ func (h *Handshake) mixKey(data []byte) {
h.chainKey = mixKey(h.chainKey, data) h.chainKey = mixKey(h.chainKey, data)
} }
/* Do basic precomputations
*/
func init() {
InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction))
InitialHash = mixHash(InitialChainKey, []byte(WGIdentifier))
}
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
@ -125,28 +128,30 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
// create ephemeral key // create ephemeral key
var err error var err error
handshake.chainKey = InitalChainKey handshake.hash = InitialHash
handshake.hash = mixHash(InitalHash, handshake.remoteStatic[:]) handshake.chainKey = InitialChainKey
handshake.localEphemeral, err = newPrivateKey() handshake.localEphemeral, err = newPrivateKey()
if err != nil { if err != nil {
return nil, err return nil, err
} }
device.indices.ClearIndex(handshake.localIndex)
handshake.localIndex, err = device.indices.NewIndex(peer)
// assign index // assign index
var msg MessageInitiation device.indices.Delete(handshake.localIndex)
handshake.localIndex, err = device.indices.NewIndex(peer)
msg.Type = MessageInitiationType
msg.Ephemeral = handshake.localEphemeral.publicKey()
if err != nil { if err != nil {
return nil, err return nil, err
} }
msg.Sender = handshake.localIndex handshake.mixHash(handshake.remoteStatic[:])
msg := MessageInitiation{
Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(),
Sender: handshake.localIndex,
}
handshake.mixKey(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:])
@ -185,9 +190,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
return nil return nil
} }
hash := mixHash(InitalHash, device.publicKey[:]) hash := mixHash(InitialHash, device.publicKey[:])
hash = mixHash(hash, msg.Ephemeral[:]) hash = mixHash(hash, msg.Ephemeral[:])
chainKey := mixKey(InitalChainKey, msg.Ephemeral[:]) chainKey := mixKey(InitialChainKey, msg.Ephemeral[:])
// decrypt static key // decrypt static key
@ -278,7 +283,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
// assign index // assign index
var err error var err error
device.indices.ClearIndex(handshake.localIndex) device.indices.Delete(handshake.localIndex)
handshake.localIndex, err = device.indices.NewIndex(peer) handshake.localIndex, err = device.indices.NewIndex(peer)
if err != nil { if err != nil {
return nil, err return nil, err
@ -420,10 +425,15 @@ func (peer *Peer) NewKeyPair() *KeyPair {
return nil return nil
} }
// zero handshake
handshake.chainKey = [blake2s.Size]byte{}
handshake.localEphemeral = NoisePrivateKey{}
peer.handshake.state = HandshakeZeroed
// create AEAD instances // create AEAD instances
var keyPair KeyPair keyPair := new(KeyPair)
keyPair.send, _ = chacha20poly1305.New(sendKey[:]) keyPair.send, _ = chacha20poly1305.New(sendKey[:])
keyPair.recv, _ = chacha20poly1305.New(recvKey[:]) keyPair.recv, _ = chacha20poly1305.New(recvKey[:])
keyPair.sendNonce = 0 keyPair.sendNonce = 0
@ -433,30 +443,32 @@ func (peer *Peer) NewKeyPair() *KeyPair {
peer.device.indices.Insert(handshake.localIndex, IndexTableEntry{ peer.device.indices.Insert(handshake.localIndex, IndexTableEntry{
peer: peer, peer: peer,
keyPair: &keyPair, keyPair: keyPair,
handshake: nil, handshake: nil,
}) })
handshake.localIndex = 0 handshake.localIndex = 0
// start timer for keypair
// rotate key pairs // rotate key pairs
func() {
kp := &peer.keyPairs kp := &peer.keyPairs
func() {
kp.mutex.Lock() kp.mutex.Lock()
defer kp.mutex.Unlock() defer kp.mutex.Unlock()
if isInitiator { if isInitiator {
kp.previous = peer.keyPairs.current if kp.previous != nil {
kp.current = &keyPair kp.previous.send = nil
kp.newKeyPair <- true kp.previous.recv = nil
peer.device.indices.Delete(kp.previous.id)
}
kp.previous = kp.current
kp.current = keyPair
sendSignal(peer.signal.newKeyPair)
} else { } else {
kp.next = &keyPair kp.next = keyPair
} }
}() }()
// zero handshake return keyPair
handshake.chainKey = [blake2s.Size]byte{}
handshake.localEphemeral = NoisePrivateKey{}
peer.handshake.state = HandshakeZeroed
return &keyPair
} }

View file

@ -25,10 +25,12 @@ func TestCurveWrappers(t *testing.T) {
} }
func TestNoiseHandshake(t *testing.T) { func TestNoiseHandshake(t *testing.T) {
dev1 := randDevice(t) dev1 := randDevice(t)
dev2 := randDevice(t) dev2 := randDevice(t)
defer dev1.Close()
defer dev2.Close()
peer1 := dev2.NewPeer(dev1.privateKey.publicKey()) peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
peer2 := dev1.NewPeer(dev2.privateKey.publicKey()) peer2 := dev1.NewPeer(dev2.privateKey.publicKey())

View file

@ -10,9 +10,10 @@ import (
const () const ()
type Peer struct { type Peer struct {
id uint
mutex sync.RWMutex mutex sync.RWMutex
endpoint *net.UDPAddr endpoint *net.UDPAddr
persistentKeepaliveInterval time.Duration // 0 = disabled persistentKeepaliveInterval uint64
keyPairs KeyPairs keyPairs KeyPairs
handshake Handshake handshake Handshake
device *Device device *Device
@ -20,16 +21,18 @@ type Peer struct {
rx_bytes uint64 rx_bytes uint64
time struct { time struct {
lastSend time.Time // last send message lastSend time.Time // last send message
lastHandshake time.Time // last completed handshake
} }
signal struct { signal struct {
newHandshake chan bool newKeyPair chan struct{} // (size 1) : a new key pair was generated
flushNonceQueue chan bool // empty queued packets handshakeBegin chan struct{} // (size 1) : request that a new handshake be started ("queue handshake")
stopSending chan bool // stop sending pipeline handshakeCompleted chan struct{} // (size 1) : handshake completed
stopInitiator chan bool // stop initiator timer flushNonceQueue chan struct{} // (size 1) : empty queued packets
stop chan struct{} // (size 0) : close to stop all goroutines for peer
} }
timer struct { timer struct {
sendKeepalive time.Timer sendKeepalive *time.Timer
handshakeTimeout time.Timer handshakeTimeout *time.Timer
} }
queue struct { queue struct {
nonce chan []byte // nonce / pre-handshake queue nonce chan []byte // nonce / pre-handshake queue
@ -39,25 +42,30 @@ type Peer struct {
} }
func (device *Device) NewPeer(pk NoisePublicKey) *Peer { func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
var peer Peer
// create peer // create peer
peer := new(Peer)
peer.mutex.Lock() peer.mutex.Lock()
defer peer.mutex.Unlock()
peer.device = device peer.device = device
peer.keyPairs.Init()
peer.mac.Init(pk) peer.mac.Init(pk)
peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
peer.queue.nonce = make(chan []byte, QueueOutboundSize) peer.queue.nonce = make(chan []byte, QueueOutboundSize)
peer.timer.sendKeepalive = StoppedTimer()
// assign id for debugging
device.mutex.Lock()
peer.id = device.idCounter
device.idCounter += 1
// map public key // map public key
device.mutex.Lock()
_, ok := device.peers[pk] _, ok := device.peers[pk]
if ok { if ok {
panic(errors.New("bug: adding existing peer")) panic(errors.New("bug: adding existing peer"))
} }
device.peers[pk] = &peer device.peers[pk] = peer
device.mutex.Unlock() device.mutex.Unlock()
// precompute DH // precompute DH
@ -67,22 +75,24 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
handshake.remoteStatic = pk handshake.remoteStatic = pk
handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic) handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic)
handshake.mutex.Unlock() handshake.mutex.Unlock()
peer.mutex.Unlock()
// start workers // prepare signaling
peer.signal.stopSending = make(chan bool, 1) peer.signal.stop = make(chan struct{})
peer.signal.stopInitiator = make(chan bool, 1) peer.signal.newKeyPair = make(chan struct{}, 1)
peer.signal.newHandshake = make(chan bool, 1) peer.signal.handshakeBegin = make(chan struct{}, 1)
peer.signal.flushNonceQueue = make(chan bool, 1) peer.signal.handshakeCompleted = make(chan struct{}, 1)
peer.signal.flushNonceQueue = make(chan struct{}, 1)
// outbound pipeline
go peer.RoutineNonce() go peer.RoutineNonce()
go peer.RoutineHandshakeInitiator() go peer.RoutineHandshakeInitiator()
go peer.RoutineSequentialSender()
return &peer return peer
} }
func (peer *Peer) Close() { func (peer *Peer) Close() {
peer.signal.stopSending <- true close(peer.signal.stop)
peer.signal.stopInitiator <- true
} }

View file

@ -5,6 +5,8 @@ import (
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"net" "net"
"sync" "sync"
"sync/atomic"
"time"
) )
/* Handles outbound flow /* Handles outbound flow
@ -29,6 +31,7 @@ type QueueOutboundElement struct {
packet []byte packet []byte
nonce uint64 nonce uint64
keyPair *KeyPair keyPair *KeyPair
peer *Peer
} }
func (peer *Peer) FlushNonceQueue() { func (peer *Peer) FlushNonceQueue() {
@ -46,6 +49,7 @@ func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) {
for { for {
select { select {
case peer.queue.outbound <- elem: case peer.queue.outbound <- elem:
return
default: default:
select { select {
case <-peer.queue.outbound: case <-peer.queue.outbound:
@ -61,11 +65,15 @@ func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) {
* Obs. Single instance per TUN device * Obs. Single instance per TUN device
*/ */
func (device *Device) RoutineReadFromTUN(tun TUNDevice) { func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
if tun.MTU() == 0 {
// Dummy
return
}
device.log.Debug.Println("Routine, TUN Reader: started") device.log.Debug.Println("Routine, TUN Reader: started")
for { for {
// read packet // read packet
device.log.Debug.Println("Read")
packet := make([]byte, 1<<16) // TODO: Fix & avoid dynamic allocation packet := make([]byte, 1<<16) // TODO: Fix & avoid dynamic allocation
size, err := tun.Read(packet) size, err := tun.Read(packet)
if err != nil { if err != nil {
@ -94,13 +102,16 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
default: default:
device.log.Debug.Println("Receieved packet with unknown IP version") device.log.Debug.Println("Receieved packet with unknown IP version")
return
} }
if peer == nil { if peer == nil {
device.log.Debug.Println("No peer configured for IP") device.log.Debug.Println("No peer configured for IP")
continue continue
} }
if peer.endpoint == nil {
device.log.Debug.Println("No known endpoint for peer", peer.id)
continue
}
// insert into nonce/pre-handshake queue // insert into nonce/pre-handshake queue
@ -131,33 +142,56 @@ func (peer *Peer) RoutineNonce() {
var packet []byte var packet []byte
var keyPair *KeyPair var keyPair *KeyPair
device := peer.device
logger := device.log.Debug
logger.Println("Routine, nonce worker, started for peer", peer.id)
func() {
for { for {
NextPacket:
// wait for packet // wait for packet
if packet == nil { if packet == nil {
select { select {
case packet = <-peer.queue.nonce: case packet = <-peer.queue.nonce:
case <-peer.signal.stopSending: case <-peer.signal.stop:
close(peer.queue.outbound)
return return
} }
} }
// wait for key pair // wait for key pair
for keyPair == nil { for {
peer.signal.newHandshake <- true
select { select {
case <-peer.keyPairs.newKeyPair: case <-peer.signal.newKeyPair:
default:
}
keyPair = peer.keyPairs.Current() keyPair = peer.keyPairs.Current()
continue if keyPair != nil && keyPair.sendNonce < RejectAfterMessages {
if time.Now().Sub(keyPair.created) < RejectAfterTime {
break
}
}
sendSignal(peer.signal.handshakeBegin)
logger.Println("Waiting for key-pair, peer", peer.id)
select {
case <-peer.signal.newKeyPair:
logger.Println("Key-pair negotiated for peer", peer.id)
goto NextPacket
case <-peer.signal.flushNonceQueue: case <-peer.signal.flushNonceQueue:
logger.Println("Clearing queue for peer", peer.id)
peer.FlushNonceQueue() peer.FlushNonceQueue()
packet = nil packet = nil
continue goto NextPacket
case <-peer.signal.stopSending:
close(peer.queue.outbound) case <-peer.signal.stop:
return return
} }
} }
@ -171,11 +205,11 @@ func (peer *Peer) RoutineNonce() {
work := new(QueueOutboundElement) // TODO: profile, maybe use pool work := new(QueueOutboundElement) // TODO: profile, maybe use pool
work.keyPair = keyPair work.keyPair = keyPair
work.packet = packet work.packet = packet
work.nonce = keyPair.sendNonce work.nonce = atomic.AddUint64(&keyPair.sendNonce, 1)
work.peer = peer
work.mutex.Lock() work.mutex.Lock()
packet = nil packet = nil
keyPair.sendNonce += 1
// drop packets until there is space // drop packets until there is space
@ -194,6 +228,9 @@ func (peer *Peer) RoutineNonce() {
peer.queue.outbound <- work peer.queue.outbound <- work
} }
} }
}()
logger.Println("Routine, nonce worker, stopped for peer", peer.id)
} }
/* Encrypts the elements in the queue /* Encrypts the elements in the queue
@ -227,6 +264,10 @@ func (device *Device) RoutineEncryption() {
nil, nil,
) )
work.mutex.Unlock() work.mutex.Unlock()
// initiate new handshake
work.peer.KeepKeyFreshSending()
} }
} }
@ -235,21 +276,54 @@ func (device *Device) RoutineEncryption() {
* Obs. Single instance per peer. * Obs. Single instance per peer.
* The routine terminates then the outbound queue is closed. * The routine terminates then the outbound queue is closed.
*/ */
func (peer *Peer) RoutineSequential() { func (peer *Peer) RoutineSequentialSender() {
for work := range peer.queue.outbound { logger := peer.device.log.Debug
logger.Println("Routine, sequential sender, started for peer", peer.id)
device := peer.device
for {
select {
case <-peer.signal.stop:
logger.Println("Routine, sequential sender, stopped for peer", peer.id)
return
case work := <-peer.queue.outbound:
work.mutex.Lock() work.mutex.Lock()
func() { func() {
peer.mutex.RLock()
defer peer.mutex.RUnlock()
if work.packet == nil { if work.packet == nil {
return return
} }
peer.mutex.RLock()
defer peer.mutex.RUnlock()
if peer.endpoint == nil { if peer.endpoint == nil {
logger.Println("No endpoint for peer:", peer.id)
return return
} }
peer.device.conn.WriteToUDP(work.packet, peer.endpoint)
peer.timer.sendKeepalive.Reset(peer.persistentKeepaliveInterval) device.net.mutex.RLock()
defer device.net.mutex.RUnlock()
if device.net.conn == nil {
logger.Println("No source for device")
return
}
logger.Println("Sending packet for peer", peer.id, work.packet)
_, err := device.net.conn.WriteToUDP(work.packet, peer.endpoint)
logger.Println("SEND:", peer.endpoint, err)
atomic.AddUint64(&peer.tx_bytes, uint64(len(work.packet)))
// shift keep-alive timer
if peer.persistentKeepaliveInterval != 0 {
interval := time.Duration(peer.persistentKeepaliveInterval) * time.Second
peer.timer.sendKeepalive.Reset(interval)
}
}() }()
work.mutex.Unlock() work.mutex.Unlock()
} }
} }
}

View file

@ -74,5 +74,6 @@ func CreateTUN(name string) (TUNDevice, error) {
return &NativeTun{ return &NativeTun{
fd: fd, fd: fd,
name: newName, name: newName,
mtu: 1500, // TODO: FIX
}, nil }, nil
} }