diff --git a/src/constants.go b/src/constants.go index 6fbb7a0..4e8d521 100644 --- a/src/constants.go +++ b/src/constants.go @@ -16,6 +16,10 @@ const ( MaxHandshakeAttemptTime = time.Second * 90 ) +const ( + RekeyAfterTimeReceiving = RekeyAfterTime - KeepaliveTimeout - RekeyTimeout +) + const ( QueueOutboundSize = 1024 QueueInboundSize = 1024 diff --git a/src/device.go b/src/device.go index 0564068..12d1ed9 100644 --- a/src/device.go +++ b/src/device.go @@ -31,16 +31,11 @@ type Device struct { signal struct { stop chan struct{} } - congestionState int32 // used as an atomic bool - peers map[NoisePublicKey]*Peer - mac MACStateDevice + underLoad int32 // used as an atomic bool + peers map[NoisePublicKey]*Peer + mac MACStateDevice } -const ( - CongestionStateUnderLoad = iota - CongestionStateOkay -) - func (device *Device) SetPrivateKey(sk NoisePrivateKey) { device.mutex.Lock() defer device.mutex.Unlock() @@ -99,10 +94,12 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { go device.RoutineDecryption() go device.RoutineHandshake() } + go device.RoutineBusyMonitor() go device.RoutineReadFromTUN(tun) go device.RoutineReceiveIncomming() go device.RoutineWriteToTUN(tun) + return device } diff --git a/src/handshake.go b/src/handshake.go deleted file mode 100644 index de607df..0000000 --- a/src/handshake.go +++ /dev/null @@ -1,153 +0,0 @@ -package main - -import ( - "bytes" - "encoding/binary" - "sync/atomic" - "time" -) - -/* Sends a keep-alive if no packets queued for peer - * - * Used by initiator of handshake and with active keep-alive - */ -func (peer *Peer) SendKeepAlive() bool { - elem := peer.device.NewOutboundElement() - elem.packet = nil - if len(peer.queue.nonce) == 0 { - select { - case peer.queue.nonce <- elem: - return true - default: - return false - } - } - return true -} - -/* Called when a new authenticated message has been send - * - * TODO: This might be done in a faster way - */ -func (peer *Peer) KeepKeyFreshSending() { - send := func() bool { - peer.keyPairs.mutex.RLock() - defer peer.keyPairs.mutex.RUnlock() - - kp := peer.keyPairs.current - if kp == nil { - return false - } - - if !kp.isInitiator { - return false - } - - nonce := atomic.LoadUint64(&kp.sendNonce) - if nonce > RekeyAfterMessages { - return true - } - return time.Now().Sub(kp.created) > RekeyAfterTime - }() - if send { - sendSignal(peer.signal.handshakeBegin) - } -} - -/* This is the state machine for handshake initiation - * - * Associated with this routine is the signal "handshakeBegin" - * The routine will read from the "handshakeBegin" channel - * at most every RekeyTimeout seconds - */ -func (peer *Peer) RoutineHandshakeInitiator() { - device := peer.device - logger := device.log.Debug - timeout := stoppedTimer() - - var elem *QueueOutboundElement - - logger.Println("Routine, handshake initator, started for peer", peer.id) - - func() { - for { - var attempts uint - var deadline time.Time - - // wait for signal - - select { - case <-peer.signal.handshakeBegin: - case <-peer.signal.stop: - return - } - - HandshakeLoop: - for { - // clear completed signal - - select { - case <-peer.signal.handshakeCompleted: - case <-peer.signal.stop: - return - default: - } - - // create initiation - - if elem != nil { - elem.Drop() - } - elem = device.NewOutboundElement() - - msg, err := device.CreateMessageInitiation(peer) - if err != nil { - device.log.Error.Println("Failed to create initiation message:", err) - break - } - - // marshal & schedule for sending - - writer := bytes.NewBuffer(elem.data[:0]) - binary.Write(writer, binary.LittleEndian, msg) - elem.packet = writer.Bytes() - peer.mac.AddMacs(elem.packet) - addToOutboundQueue(peer.queue.outbound, elem) - - if attempts == 0 { - deadline = time.Now().Add(MaxHandshakeAttemptTime) - } - - // set timeout - - attempts += 1 - stopTimer(timeout) - timeout.Reset(RekeyTimeout) - device.log.Debug.Println("Handshake initiation attempt", attempts, "queued for peer", peer.id) - - // wait for handshake or timeout - - select { - case <-peer.signal.stop: - return - - case <-peer.signal.handshakeCompleted: - device.log.Debug.Println("Handshake complete") - break HandshakeLoop - - case <-timeout.C: - device.log.Debug.Println("Timeout") - if deadline.Before(time.Now().Add(RekeyTimeout)) { - peer.signal.flushNonceQueue <- struct{}{} - if !peer.timer.sendKeepalive.Stop() { - <-peer.timer.sendKeepalive.C - } - break HandshakeLoop - } - } - } - } - }() - - logger.Println("Routine, handshake initator, stopped for peer", peer.id) -} diff --git a/src/keypair.go b/src/keypair.go index 3caa0c8..b24dbe4 100644 --- a/src/keypair.go +++ b/src/keypair.go @@ -23,19 +23,6 @@ type KeyPairs struct { next *KeyPair // not yet "confirmed by transport" } -/* Called during recieving to confirm the handshake - * was completed correctly - */ -func (kp *KeyPairs) Used(key *KeyPair) { - if key == kp.next { - kp.mutex.Lock() - kp.previous = kp.current - kp.current = key - kp.next = nil - kp.mutex.Unlock() - } -} - func (kp *KeyPairs) Current() *KeyPair { kp.mutex.RLock() defer kp.mutex.RUnlock() diff --git a/src/misc.go b/src/misc.go index dd4fa63..75561b2 100644 --- a/src/misc.go +++ b/src/misc.go @@ -4,6 +4,14 @@ import ( "time" ) +/* We use int32 as atomic bools + * (since booleans are not natively supported by sync/atomic) + */ +const ( + AtomicFalse = iota + AtomicTrue +) + func min(a uint, b uint) uint { if a > b { return b @@ -11,14 +19,21 @@ func min(a uint, b uint) uint { return a } -func sendSignal(c chan struct{}) { +func signalSend(c chan struct{}) { select { case c <- struct{}{}: default: } } -func stopTimer(timer *time.Timer) { +func signalClear(c chan struct{}) { + select { + case <-c: + default: + } +} + +func timerStop(timer *time.Timer) { if !timer.Stop() { select { case <-timer.C: @@ -27,8 +42,8 @@ func stopTimer(timer *time.Timer) { } } -func stoppedTimer() *time.Timer { +func NewStoppedTimer() *time.Timer { timer := time.NewTimer(time.Hour) - stopTimer(timer) + timerStop(timer) return timer } diff --git a/src/noise_protocol.go b/src/noise_protocol.go index 9a9d918..a90fe4c 100644 --- a/src/noise_protocol.go +++ b/src/noise_protocol.go @@ -478,7 +478,7 @@ func (peer *Peer) NewKeyPair() *KeyPair { } kp.previous = kp.current kp.current = keyPair - sendSignal(peer.signal.newKeyPair) + signalSend(peer.signal.newKeyPair) } else { kp.next = keyPair } diff --git a/src/peer.go b/src/peer.go index fadc43f..c8dc5c0 100644 --- a/src/peer.go +++ b/src/peer.go @@ -20,25 +20,36 @@ type Peer struct { txBytes uint64 rxBytes uint64 time struct { + mutex sync.RWMutex lastSend time.Time // last send message lastHandshake time.Time // last completed handshake + nextKeepalive time.Time } signal struct { newKeyPair chan struct{} // (size 1) : a new key pair was generated handshakeBegin chan struct{} // (size 1) : request that a new handshake be started ("queue handshake") handshakeCompleted chan struct{} // (size 1) : handshake completed flushNonceQueue chan struct{} // (size 1) : empty queued packets + messageSend chan struct{} // (size 1) : a message was send to the peer + messageReceived chan struct{} // (size 1) : an authenticated message was received stop chan struct{} // (size 0) : close to stop all goroutines for peer } timer struct { - sendKeepalive *time.Timer - handshakeTimeout *time.Timer + /* Both keep-alive timers acts as one (see timers.go) + * They are kept seperate to simplify the implementation. + */ + keepalivePersistent *time.Timer // set for persistent keepalives + keepaliveAcknowledgement *time.Timer // set upon recieving messages + zeroAllKeys *time.Timer // zero all key material after RejectAfterTime*3 } queue struct { nonce chan *QueueOutboundElement // nonce / pre-handshake queue outbound chan *QueueOutboundElement // sequential ordering of work inbound chan *QueueInboundElement // sequential ordering of work } + flags struct { + keepaliveWaiting int32 + } mac MACStatePeer } @@ -51,7 +62,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer { peer.mac.Init(pk) peer.device = device - peer.timer.sendKeepalive = stoppedTimer() + + peer.timer.keepalivePersistent = NewStoppedTimer() + peer.timer.keepaliveAcknowledgement = NewStoppedTimer() + peer.timer.zeroAllKeys = NewStoppedTimer() + + peer.flags.keepaliveWaiting = AtomicFalse // assign id for debugging @@ -82,7 +98,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer { peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) - // prepare signaling + // prepare signaling & routines peer.signal.stop = make(chan struct{}) peer.signal.newKeyPair = make(chan struct{}, 1) @@ -90,9 +106,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer { peer.signal.handshakeCompleted = make(chan struct{}, 1) peer.signal.flushNonceQueue = make(chan struct{}, 1) - // outbound pipeline - go peer.RoutineNonce() + go peer.RoutineTimerHandler() go peer.RoutineHandshakeInitiator() go peer.RoutineSequentialSender() go peer.RoutineSequentialReceiver() diff --git a/src/receive.go b/src/receive.go index c788dcf..e780c66 100644 --- a/src/receive.go +++ b/src/receive.go @@ -10,11 +10,6 @@ import ( "time" ) -const ( - ElementStateOkay = iota - ElementStateDropped -) - type QueueHandshakeElement struct { msgType uint32 packet []byte @@ -22,7 +17,7 @@ type QueueHandshakeElement struct { } type QueueInboundElement struct { - state uint32 + dropped int32 mutex sync.Mutex packet []byte counter uint64 @@ -30,11 +25,11 @@ type QueueInboundElement struct { } func (elem *QueueInboundElement) Drop() { - atomic.StoreUint32(&elem.state, ElementStateDropped) + atomic.StoreInt32(&elem.dropped, AtomicTrue) } func (elem *QueueInboundElement) IsDropped() bool { - return atomic.LoadUint32(&elem.state) == ElementStateDropped + return atomic.LoadInt32(&elem.dropped) == AtomicTrue } func addToInboundQueue( @@ -101,9 +96,9 @@ func (device *Device) RoutineBusyMonitor() { // update busy state if busy { - atomic.StoreInt32(&device.congestionState, CongestionStateUnderLoad) + atomic.StoreInt32(&device.underLoad, AtomicTrue) } else { - atomic.StoreInt32(&device.congestionState, CongestionStateOkay) + atomic.StoreInt32(&device.underLoad, AtomicFalse) } timer.Reset(interval) @@ -216,7 +211,7 @@ func (device *Device) RoutineReceiveIncomming() { work := new(QueueInboundElement) work.packet = packet work.keyPair = keyPair - work.state = ElementStateOkay + work.dropped = AtomicFalse work.mutex.Lock() // add to decryption queues @@ -303,7 +298,7 @@ func (device *Device) RoutineHandshake() { // verify mac2 - busy := atomic.LoadInt32(&device.congestionState) == CongestionStateUnderLoad + busy := atomic.LoadInt32(&device.underLoad) == AtomicTrue if busy && !device.mac.CheckMAC2(elem.packet, elem.source) { sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type" @@ -397,13 +392,12 @@ func (device *Device) RoutineHandshake() { ) return } - sendSignal(peer.signal.handshakeCompleted) - logDebug.Println("Recieved valid response message for peer", peer.id) kp := peer.NewKeyPair() if kp == nil { logDebug.Println("Failed to derieve key-pair") } peer.SendKeepAlive() + peer.EventHandshakeComplete() default: device.log.Error.Println("Invalid message type in handshake queue") @@ -438,9 +432,25 @@ func (peer *Peer) RoutineSequentialReceiver() { // check for replay - // update timers + // time (passive) keep-alive - // refresh key material + peer.TimerStartKeepalive() + + // refresh key material (rekey) + + peer.KeepKeyFreshReceiving() + + // check if confirming handshake + + kp := &peer.keyPairs + kp.mutex.Lock() + if kp.next == elem.keyPair { + peer.EventHandshakeComplete() + kp.previous = kp.current + kp.current = kp.next + kp.next = nil + } + kp.mutex.Unlock() // check for keep-alive @@ -491,7 +501,7 @@ func (peer *Peer) RoutineSequentialReceiver() { } default: - device.log.Debug.Println("Receieved packet with unknown IP version") + logDebug.Println("Receieved packet with unknown IP version") return } diff --git a/src/send.go b/src/send.go index a02f5cb..5ea9a8f 100644 --- a/src/send.go +++ b/src/send.go @@ -31,7 +31,7 @@ import ( * (to allow the construction of transport messages in-place) */ type QueueOutboundElement struct { - state uint32 + dropped int32 mutex sync.Mutex data [MaxMessageSize]byte packet []byte // slice of "data" (always!) @@ -61,11 +61,11 @@ func (device *Device) NewOutboundElement() *QueueOutboundElement { } func (elem *QueueOutboundElement) Drop() { - atomic.StoreUint32(&elem.state, ElementStateDropped) + atomic.StoreInt32(&elem.dropped, AtomicTrue) } func (elem *QueueOutboundElement) IsDropped() bool { - return atomic.LoadUint32(&elem.state) == ElementStateDropped + return atomic.LoadInt32(&elem.dropped) == AtomicTrue } func addToOutboundQueue( @@ -86,6 +86,25 @@ func addToOutboundQueue( } } +func addToEncryptionQueue( + queue chan *QueueOutboundElement, + element *QueueOutboundElement, +) { + for { + select { + case queue <- element: + return + default: + select { + case old := <-queue: + old.Drop() + old.mutex.Unlock() + default: + } + } + } +} + /* Reads packets from the TUN and inserts * into nonce queue for peer * @@ -196,9 +215,7 @@ func (peer *Peer) RoutineNonce() { break } } - logDebug.Println("Key pair:", keyPair) - - sendSignal(peer.signal.handshakeBegin) + signalSend(peer.signal.handshakeBegin) logDebug.Println("Waiting for key-pair, peer", peer.id) select { @@ -225,12 +242,13 @@ func (peer *Peer) RoutineNonce() { elem.keyPair = keyPair elem.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1 + elem.dropped = AtomicFalse elem.peer = peer elem.mutex.Lock() - // add to parallel processing and sequential consuming queue + // add to parallel and sequential queue - addToOutboundQueue(device.queue.encryption, elem) + addToEncryptionQueue(device.queue.encryption, elem) addToOutboundQueue(peer.queue.outbound, elem) elem = nil } @@ -246,6 +264,9 @@ func (peer *Peer) RoutineNonce() { func (device *Device) RoutineEncryption() { var nonce [chacha20poly1305.NonceSize]byte for work := range device.queue.encryption { + + // check if dropped + if work.IsDropped() { continue } @@ -289,25 +310,25 @@ func (device *Device) RoutineEncryption() { * The routine terminates then the outbound queue is closed. */ func (peer *Peer) RoutineSequentialSender() { - logDebug := peer.device.log.Debug - logDebug.Println("Routine, sequential sender, started for peer", peer.id) - device := peer.device + logDebug := device.log.Debug + logDebug.Println("Routine, sequential sender, started for peer", peer.id) + for { select { case <-peer.signal.stop: logDebug.Println("Routine, sequential sender, stopped for peer", peer.id) return case work := <-peer.queue.outbound: + work.mutex.Lock() if work.IsDropped() { continue } - work.mutex.Lock() + func() { - if work.packet == nil { - return - } + + // send to endpoint peer.mutex.RLock() defer peer.mutex.RUnlock() @@ -331,12 +352,9 @@ func (peer *Peer) RoutineSequentialSender() { } atomic.AddUint64(&peer.txBytes, uint64(len(work.packet))) - // shift keep-alive timer + // reset keep-alive (passive keep-alives / acknowledgements) - if peer.persistentKeepaliveInterval != 0 { - interval := time.Duration(peer.persistentKeepaliveInterval) * time.Second - peer.timer.sendKeepalive.Reset(interval) - } + peer.TimerResetKeepalive() }() } } diff --git a/src/timers.go b/src/timers.go new file mode 100644 index 0000000..26926c2 --- /dev/null +++ b/src/timers.go @@ -0,0 +1,303 @@ +package main + +import ( + "bytes" + "encoding/binary" + "golang.org/x/crypto/blake2s" + "sync/atomic" + "time" +) + +/* Called when a new authenticated message has been send + * + */ +func (peer *Peer) KeepKeyFreshSending() { + send := func() bool { + peer.keyPairs.mutex.RLock() + defer peer.keyPairs.mutex.RUnlock() + + kp := peer.keyPairs.current + if kp == nil { + return false + } + + if !kp.isInitiator { + return false + } + + nonce := atomic.LoadUint64(&kp.sendNonce) + return nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTime + }() + if send { + signalSend(peer.signal.handshakeBegin) + } +} + +/* Called when a new authenticated message has been recevied + * + */ +func (peer *Peer) KeepKeyFreshReceiving() { + send := func() bool { + peer.keyPairs.mutex.RLock() + defer peer.keyPairs.mutex.RUnlock() + + kp := peer.keyPairs.current + if kp == nil { + return false + } + + if !kp.isInitiator { + return false + } + + nonce := atomic.LoadUint64(&kp.sendNonce) + return nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving + }() + if send { + signalSend(peer.signal.handshakeBegin) + } +} + +/* Called after succesfully completing a handshake. + * i.e. after: + * - Valid handshake response + * - First transport message under the "next" key + */ +func (peer *Peer) EventHandshakeComplete() { + peer.device.log.Debug.Println("Handshake completed") + peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3) + signalSend(peer.signal.handshakeCompleted) +} + +/* Queues a keep-alive if no packets are queued for peer + */ +func (peer *Peer) SendKeepAlive() bool { + elem := peer.device.NewOutboundElement() + elem.packet = nil + if len(peer.queue.nonce) == 0 { + select { + case peer.queue.nonce <- elem: + return true + default: + return false + } + } + return true +} + +/* Starts the "keep-alive" timer + * (if not already running), + * in response to incomming messages + */ +func (peer *Peer) TimerStartKeepalive() { + + // check if acknowledgement timer set yet + + var waiting int32 = AtomicTrue + waiting = atomic.SwapInt32(&peer.flags.keepaliveWaiting, waiting) + if waiting == AtomicTrue { + return + } + + // timer not yet set, start it + + wait := KeepaliveTimeout + interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) + if interval > 0 { + duration := time.Duration(interval) * time.Second + if duration < wait { + wait = duration + } + } +} + +/* Resets both keep-alive timers + */ +func (peer *Peer) TimerResetKeepalive() { + + // reset persistent timer + + interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) + if interval > 0 { + peer.timer.keepalivePersistent.Reset( + time.Duration(interval) * time.Second, + ) + } + + // stop acknowledgement timer + + timerStop(peer.timer.keepaliveAcknowledgement) + atomic.StoreInt32(&peer.flags.keepaliveWaiting, AtomicFalse) +} + +func (peer *Peer) BeginHandshakeInitiation() (*QueueOutboundElement, error) { + + // create initiation + + elem := peer.device.NewOutboundElement() + msg, err := peer.device.CreateMessageInitiation(peer) + if err != nil { + return nil, err + } + + // marshal & schedule for sending + + writer := bytes.NewBuffer(elem.data[:0]) + binary.Write(writer, binary.LittleEndian, msg) + elem.packet = writer.Bytes() + peer.mac.AddMacs(elem.packet) + addToOutboundQueue(peer.queue.outbound, elem) + return elem, err +} + +func (peer *Peer) RoutineTimerHandler() { + device := peer.device + + logDebug := device.log.Debug + logDebug.Println("Routine, timer handler, started for peer", peer.id) + + for { + select { + + case <-peer.signal.stop: + return + + // keep-alives + + case <-peer.timer.keepalivePersistent.C: + + logDebug.Println("Sending persistent keep-alive to peer", peer.id) + + peer.SendKeepAlive() + peer.TimerResetKeepalive() + + case <-peer.timer.keepaliveAcknowledgement.C: + + logDebug.Println("Sending passive persistent keep-alive to peer", peer.id) + + peer.SendKeepAlive() + peer.TimerResetKeepalive() + + // clear key material + + case <-peer.timer.zeroAllKeys.C: + + logDebug.Println("Clearing all key material for peer", peer.id) + + // zero out key pairs + + func() { + kp := &peer.keyPairs + kp.mutex.Lock() + // best we can do is wait for GC :( ? + kp.current = nil + kp.previous = nil + kp.next = nil + kp.mutex.Unlock() + }() + + // zero out handshake + + func() { + hs := &peer.handshake + hs.mutex.Lock() + hs.localEphemeral = NoisePrivateKey{} + hs.remoteEphemeral = NoisePublicKey{} + hs.chainKey = [blake2s.Size]byte{} + hs.hash = [blake2s.Size]byte{} + hs.mutex.Unlock() + }() + } + } +} + +/* This is the state machine for handshake initiation + * + * Associated with this routine is the signal "handshakeBegin" + * The routine will read from the "handshakeBegin" channel + * at most every RekeyTimeout seconds + */ +func (peer *Peer) RoutineHandshakeInitiator() { + device := peer.device + + var elem *QueueOutboundElement + + logError := device.log.Error + logDebug := device.log.Debug + logDebug.Println("Routine, handshake initator, started for peer", peer.id) + + for run := true; run; { + var err error + var attempts uint + var deadline time.Time + + // wait for signal + + select { + case <-peer.signal.handshakeBegin: + case <-peer.signal.stop: + return + } + + // wait for handshake + + run = func() bool { + for { + // clear completed signal + + select { + case <-peer.signal.handshakeCompleted: + case <-peer.signal.stop: + return false + default: + } + + // create initiation + + if elem != nil { + elem.Drop() + } + elem, err = peer.BeginHandshakeInitiation() + if err != nil { + logError.Println("Failed to create initiation message:", err) + break + } + + // set timeout + + attempts += 1 + if attempts == 1 { + deadline = time.Now().Add(MaxHandshakeAttemptTime) + } + timeout := time.NewTimer(RekeyTimeout) + logDebug.Println("Handshake initiation attempt", attempts, "queued for peer", peer.id) + + // wait for handshake or timeout + + select { + case <-peer.signal.stop: + return true + + case <-peer.signal.handshakeCompleted: + <-timeout.C + return true + + case <-timeout.C: + logDebug.Println("Timeout") + + // check if sufficient time for retry + + if deadline.Before(time.Now().Add(RekeyTimeout)) { + signalSend(peer.signal.flushNonceQueue) + timerStop(peer.timer.keepalivePersistent) + timerStop(peer.timer.keepaliveAcknowledgement) + return true + } + } + } + return true + }() + + signalClear(peer.signal.handshakeBegin) + } +}