diff --git a/src/noise_protocol.go b/src/noise_protocol.go index 9e5fdd8..2f9e1d5 100644 --- a/src/noise_protocol.go +++ b/src/noise_protocol.go @@ -532,7 +532,6 @@ func (peer *Peer) NewKeyPair() *KeyPair { kp := &peer.keyPairs kp.mutex.Lock() - // TODO: Adapt kernel behavior noise.c:161 if isInitiator { if kp.previous != nil { device.DeleteKeyPair(kp.previous) @@ -545,7 +544,7 @@ func (peer *Peer) NewKeyPair() *KeyPair { } else { kp.previous = kp.current kp.current = keyPair - signalSend(peer.signal.newKeyPair) // TODO: This more places (after confirming the key) + peer.signal.newKeyPair.Send() } } else { diff --git a/src/peer.go b/src/peer.go index f3eb6c2..f582556 100644 --- a/src/peer.go +++ b/src/peer.go @@ -28,30 +28,26 @@ type Peer struct { nextKeepalive time.Time } signal struct { - newKeyPair chan struct{} // (size 1) : a new key pair was generated - handshakeBegin chan struct{} // (size 1) : request that a new handshake be started ("queue handshake") - handshakeCompleted chan struct{} // (size 1) : handshake completed - handshakeReset chan struct{} // (size 1) : reset handshake negotiation state - flushNonceQueue chan struct{} // (size 1) : empty queued packets - messageSend chan struct{} // (size 1) : a message was send to the peer - messageReceived chan struct{} // (size 1) : an authenticated message was received - stop chan struct{} // (size 0) : close to stop all goroutines for peer + newKeyPair Signal // size 1, new key pair was generated + handshakeCompleted Signal // size 1, handshake completed + handshakeBegin Signal // size 1, begin new handshake begin + flushNonceQueue Signal // size 1, empty queued packets + messageSend Signal // size 1, message was send to peer + messageReceived Signal // size 1, authenticated message recv + stop Signal // size 0, stop all goroutines } timer struct { // state related to WireGuard timers - keepalivePersistent *time.Timer // set for persistent keepalives - keepalivePassive *time.Timer // set upon recieving messages - newHandshake *time.Timer // begin a new handshake (after Keepalive + RekeyTimeout) - zeroAllKeys *time.Timer // zero all key material (after RejectAfterTime*3) - handshakeDeadline *time.Timer // Current handshake must be completed + keepalivePersistent Timer // set for persistent keepalives + keepalivePassive Timer // set upon recieving messages + newHandshake Timer // begin a new handshake (stale) + zeroAllKeys Timer // zero all key material + handshakeDeadline Timer // complete handshake timeout + handshakeTimeout Timer // current handshake message timeout - pendingKeepalivePassive bool - pendingNewHandshake bool - pendingZeroAllKeys bool - - needAnotherKeepalive bool sendLastMinuteHandshake bool + needAnotherKeepalive bool } queue struct { nonce chan *QueueOutboundElement // nonce / pre-handshake queue @@ -71,10 +67,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { peer.mac.Init(pk) peer.device = device - peer.timer.keepalivePersistent = NewStoppedTimer() - peer.timer.keepalivePassive = NewStoppedTimer() - peer.timer.newHandshake = NewStoppedTimer() - peer.timer.zeroAllKeys = NewStoppedTimer() + peer.timer.keepalivePersistent = NewTimer() + peer.timer.keepalivePassive = NewTimer() + peer.timer.newHandshake = NewTimer() + peer.timer.zeroAllKeys = NewTimer() + peer.timer.handshakeDeadline = NewTimer() + peer.timer.handshakeTimeout = NewTimer() // assign id for debugging @@ -102,7 +100,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { handshake := &peer.handshake handshake.mutex.Lock() handshake.remoteStatic = pk - handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic) + handshake.precomputedStaticStatic = + device.privateKey.sharedSecret(handshake.remoteStatic) handshake.mutex.Unlock() // reset endpoint @@ -117,16 +116,14 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { // prepare signaling & routines - peer.signal.stop = make(chan struct{}) - peer.signal.newKeyPair = make(chan struct{}, 1) - peer.signal.handshakeBegin = make(chan struct{}, 1) - peer.signal.handshakeReset = make(chan struct{}, 1) - peer.signal.handshakeCompleted = make(chan struct{}, 1) - peer.signal.flushNonceQueue = make(chan struct{}, 1) + peer.signal.stop = NewSignal() + peer.signal.newKeyPair = NewSignal() + peer.signal.handshakeBegin = NewSignal() + peer.signal.handshakeCompleted = NewSignal() + peer.signal.flushNonceQueue = NewSignal() go peer.RoutineNonce() go peer.RoutineTimerHandler() - go peer.RoutineHandshakeInitiator() go peer.RoutineSequentialSender() go peer.RoutineSequentialReceiver() @@ -163,5 +160,5 @@ func (peer *Peer) String() string { } func (peer *Peer) Close() { - close(peer.signal.stop) + peer.signal.stop.Broadcast() } diff --git a/src/receive.go b/src/receive.go index 0b0efbf..7d493b0 100644 --- a/src/receive.go +++ b/src/receive.go @@ -482,7 +482,8 @@ func (peer *Peer) RoutineSequentialReceiver() { for { select { - case <-peer.signal.stop: + + case <-peer.signal.stop.Wait(): logDebug.Println("Routine, sequential receiver, stopped for peer", peer.id) return diff --git a/src/send.go b/src/send.go index 52872f6..35a4a6e 100644 --- a/src/send.go +++ b/src/send.go @@ -164,7 +164,7 @@ func (device *Device) RoutineReadFromTUN() { // insert into nonce/pre-handshake queue - signalSend(peer.signal.handshakeReset) + peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) addToOutboundQueue(peer.queue.nonce, elem) elem = device.NewOutboundElement() } @@ -186,7 +186,7 @@ func (peer *Peer) RoutineNonce() { for { NextPacket: select { - case <-peer.signal.stop: + case <-peer.signal.stop.Wait(): return case elem := <-peer.queue.nonce: @@ -201,16 +201,17 @@ func (peer *Peer) RoutineNonce() { } } - signalSend(peer.signal.handshakeBegin) + peer.signal.handshakeBegin.Send() + logDebug.Println("Awaiting key-pair for", peer.String()) select { - case <-peer.signal.newKeyPair: - case <-peer.signal.flushNonceQueue: + case <-peer.signal.newKeyPair.Wait(): + case <-peer.signal.flushNonceQueue.Wait(): logDebug.Println("Clearing queue for", peer.String()) peer.FlushNonceQueue() goto NextPacket - case <-peer.signal.stop: + case <-peer.signal.stop.Wait(): return } } @@ -309,8 +310,10 @@ func (peer *Peer) RoutineSequentialSender() { for { select { - case <-peer.signal.stop: - logDebug.Println("Routine, sequential sender, stopped for", peer.String()) + + case <-peer.signal.stop.Wait(): + logDebug.Println( + "Routine, sequential sender, stopped for", peer.String()) return case elem := <-peer.queue.outbound: diff --git a/src/signal.go b/src/signal.go new file mode 100644 index 0000000..96b21bb --- /dev/null +++ b/src/signal.go @@ -0,0 +1,45 @@ +package main + +type Signal struct { + enabled AtomicBool + C chan struct{} +} + +func NewSignal() (s Signal) { + s.C = make(chan struct{}, 1) + s.Enable() + return +} + +func (s *Signal) Disable() { + s.enabled.Set(false) + s.Clear() +} + +func (s *Signal) Enable() { + s.enabled.Set(true) +} + +func (s *Signal) Send() { + if s.enabled.Get() { + select { + case s.C <- struct{}{}: + default: + } + } +} + +func (s Signal) Clear() { + select { + case <-s.C: + default: + } +} + +func (s Signal) Broadcast() { + close(s.C) // unblocks all selectors +} + +func (s Signal) Wait() chan struct{} { + return s.C +} diff --git a/src/timer.go b/src/timer.go new file mode 100644 index 0000000..3def253 --- /dev/null +++ b/src/timer.go @@ -0,0 +1,65 @@ +package main + +import ( + "time" +) + +type Timer struct { + pending AtomicBool + timer *time.Timer +} + +/* Starts the timer if not already pending + */ +func (t *Timer) Start(dur time.Duration) bool { + set := t.pending.Swap(true) + if !set { + t.timer.Reset(dur) + return true + } + return false +} + +/* Stops the timer + */ +func (t *Timer) Stop() { + set := t.pending.Swap(true) + if set { + t.timer.Stop() + select { + case <-t.timer.C: + default: + } + } + t.pending.Set(false) +} + +func (t *Timer) Pending() bool { + return t.pending.Get() +} + +func (t *Timer) Reset(dur time.Duration) { + t.pending.Set(false) + t.Start(dur) +} + +func (t *Timer) Push(dur time.Duration) { + if t.pending.Get() { + t.Reset(dur) + } +} + +func (t *Timer) Wait() <-chan time.Time { + return t.timer.C +} + +func NewTimer() (t Timer) { + t.pending.Set(false) + t.timer = time.NewTimer(0) + t.timer.Stop() + select { + case <-t.timer.C: + default: + } + return +} diff --git a/src/timers.go b/src/timers.go index 5848b2a..64aeca8 100644 --- a/src/timers.go +++ b/src/timers.go @@ -18,10 +18,10 @@ func (peer *Peer) KeepKeyFreshSending() { } nonce := atomic.LoadUint64(&kp.sendNonce) if nonce > RekeyAfterMessages { - signalSend(peer.signal.handshakeBegin) + peer.signal.handshakeBegin.Send() } if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime { - signalSend(peer.signal.handshakeBegin) + peer.signal.handshakeBegin.Send() } } @@ -44,7 +44,7 @@ func (peer *Peer) KeepKeyFreshReceiving() { send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving if send { // do a last minute attempt at initiating a new handshake - signalSend(peer.signal.handshakeBegin) + peer.signal.handshakeBegin.Send() peer.timer.sendLastMinuteHandshake = true } } @@ -69,34 +69,36 @@ func (peer *Peer) SendKeepAlive() bool { * Sent non-empty (authenticated) transport message */ func (peer *Peer) TimerDataSent() { - timerStop(peer.timer.keepalivePassive) - if !peer.timer.pendingNewHandshake { - peer.timer.pendingNewHandshake = true + peer.timer.keepalivePassive.Stop() + if peer.timer.newHandshake.Pending() { peer.timer.newHandshake.Reset(NewHandshakeTime) } } /* Event: * Received non-empty (authenticated) transport message + * + * Action: + * Set a timer to confirm the message using a keep-alive (if not already set) */ func (peer *Peer) TimerDataReceived() { - if peer.timer.pendingKeepalivePassive { + if !peer.timer.keepalivePassive.Start(KeepaliveTimeout) { peer.timer.needAnotherKeepalive = true - return } - peer.timer.pendingKeepalivePassive = false - peer.timer.keepalivePassive.Reset(KeepaliveTimeout) } /* Event: * Any (authenticated) packet received */ func (peer *Peer) TimerAnyAuthenticatedPacketReceived() { - timerStop(peer.timer.newHandshake) + peer.timer.newHandshake.Stop() } /* Event: * Any authenticated packet send / received. + * + * Action: + * Push persistent keep-alive into the future */ func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() { interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) @@ -117,7 +119,7 @@ func (peer *Peer) TimerHandshakeComplete() { &peer.stats.lastHandshakeNano, time.Now().UnixNano(), ) - signalSend(peer.signal.handshakeCompleted) + peer.signal.handshakeCompleted.Send() peer.device.log.Info.Println("Negotiated new handshake for", peer.String()) } @@ -129,7 +131,8 @@ func (peer *Peer) TimerHandshakeComplete() { * CreateMessageInitiation * CreateMessageResponse * - * Schedules the deletion of all key material + * Action: + * Schedule the deletion of all key material * upon failure to complete a handshake */ func (peer *Peer) TimerEphemeralKeyCreated() { @@ -139,18 +142,18 @@ func (peer *Peer) TimerEphemeralKeyCreated() { func (peer *Peer) RoutineTimerHandler() { device := peer.device + logInfo := device.log.Info logDebug := device.log.Debug logDebug.Println("Routine, timer handler, started for peer", peer.String()) for { select { - case <-peer.signal.stop: - return + /* timers */ - // keep-alives + // keep-alive - case <-peer.timer.keepalivePersistent.C: + case <-peer.timer.keepalivePersistent.Wait(): interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) if interval > 0 { @@ -158,7 +161,7 @@ func (peer *Peer) RoutineTimerHandler() { peer.SendKeepAlive() } - case <-peer.timer.keepalivePassive.C: + case <-peer.timer.keepalivePassive.Wait(): logDebug.Println("Sending keep-alive to", peer.String()) @@ -169,17 +172,9 @@ func (peer *Peer) RoutineTimerHandler() { peer.timer.needAnotherKeepalive = false } - // unresponsive session + // clear key material timer - case <-peer.timer.newHandshake.C: - - logDebug.Println("Retrying handshake with", peer.String(), "due to lack of reply") - - signalSend(peer.signal.handshakeBegin) - - // clear key material - - case <-peer.timer.zeroAllKeys.C: + case <-peer.timer.zeroAllKeys.Wait(): logDebug.Println("Clearing all key material for", peer.String()) @@ -215,125 +210,106 @@ func (peer *Peer) RoutineTimerHandler() { setZero(hs.chainKey[:]) setZero(hs.hash[:]) hs.mutex.Unlock() - } - } -} -/* This is the state machine for handshake initiation - * - * Associated with this routine is the signal "handshakeBegin" - * The routine will read from the "handshakeBegin" channel - * at most every RekeyTimeout seconds - */ -func (peer *Peer) RoutineHandshakeInitiator() { - device := peer.device + // handshake timers - logInfo := device.log.Info - logError := device.log.Error - logDebug := device.log.Debug - logDebug.Println("Routine, handshake initiator, started for", peer.String()) + case <-peer.timer.newHandshake.Wait(): + logInfo.Println("Retrying handshake with", peer.String()) + peer.signal.handshakeBegin.Send() - var temp [256]byte + case <-peer.timer.handshakeTimeout.Wait(): - for { + // clear source (in case this is causing problems) - // wait for signal - - select { - case <-peer.signal.handshakeBegin: - case <-peer.signal.stop: - return - } - - // set deadline - - BeginHandshakes: - - signalClear(peer.signal.handshakeReset) - deadline := time.NewTimer(RekeyAttemptTime) - - AttemptHandshakes: - - for attempts := uint(1); ; attempts++ { - - // check if deadline reached - - select { - case <-deadline.C: - logInfo.Println("Handshake negotiation timed out for:", peer.String()) - signalSend(peer.signal.flushNonceQueue) - timerStop(peer.timer.keepalivePersistent) - break - case <-peer.signal.stop: - return - default: + peer.mutex.Lock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() } + peer.mutex.Unlock() - signalClear(peer.signal.handshakeCompleted) + // send new handshake - // create initiation message - - msg, err := peer.device.CreateMessageInitiation(peer) + err := peer.sendNewHandshake() if err != nil { - logError.Println("Failed to create handshake initiation message:", err) - break AttemptHandshakes + logInfo.Println( + "Failed to send handshake to peer:", peer.String()) } - // marshal handshake message + case <-peer.timer.handshakeDeadline.Wait(): - writer := bytes.NewBuffer(temp[:0]) - binary.Write(writer, binary.LittleEndian, msg) - packet := writer.Bytes() - peer.mac.AddMacs(packet) + // clear all queued packets and stop keep-alive - // send to endpoint + logInfo.Println( + "Handshake negotiation timed out for:", peer.String()) - err = peer.SendBuffer(packet) - jitter := time.Millisecond * time.Duration(rand.Uint32()%334) - timeout := time.NewTimer(RekeyTimeout + jitter) - if err == nil { - peer.TimerAnyAuthenticatedPacketTraversal() - logDebug.Println( - "Handshake initiation attempt", - attempts, "sent to", peer.String(), - ) - } else { - logError.Println( - "Failed to send handshake initiation message to", - peer.String(), ":", err, - ) + peer.signal.flushNonceQueue.Send() + peer.timer.keepalivePersistent.Stop() + peer.signal.handshakeBegin.Enable() + + /* signals */ + + case <-peer.signal.stop.Wait(): + return + + case <-peer.signal.handshakeBegin.Wait(): + + peer.signal.handshakeBegin.Disable() + + err := peer.sendNewHandshake() + if err != nil { + logInfo.Println( + "Failed to send handshake to peer:", peer.String()) } - // wait for handshake or timeout + peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) - select { + case <-peer.signal.handshakeCompleted.Wait(): - case <-peer.signal.stop: - return + logInfo.Println( + "Handshake completed for:", peer.String()) - case <-peer.signal.handshakeCompleted: - <-timeout.C - peer.timer.sendLastMinuteHandshake = false - break AttemptHandshakes - - case <-peer.signal.handshakeReset: - <-timeout.C - goto BeginHandshakes - - case <-timeout.C: - - // clear source address of peer - - peer.mutex.Lock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - peer.mutex.Unlock() - } + peer.timer.handshakeTimeout.Stop() + peer.timer.handshakeDeadline.Stop() + peer.signal.handshakeBegin.Enable() } - - // clear signal set in the meantime - - signalClear(peer.signal.handshakeBegin) } } + +/* Sends a new handshake initiation message to the peer (endpoint) + */ +func (peer *Peer) sendNewHandshake() error { + + // temporarily disable the handshake complete signal + + peer.signal.handshakeCompleted.Disable() + + // create initiation message + + msg, err := peer.device.CreateMessageInitiation(peer) + if err != nil { + return err + } + + // marshal handshake message + + var buff [MessageInitiationSize]byte + writer := bytes.NewBuffer(buff[:0]) + binary.Write(writer, binary.LittleEndian, msg) + packet := writer.Bytes() + peer.mac.AddMacs(packet) + + // send to endpoint + + err = peer.SendBuffer(packet) + if err == nil { + peer.TimerAnyAuthenticatedPacketTraversal() + peer.signal.handshakeCompleted.Enable() + } + + // set timeout + + jitter := time.Millisecond * time.Duration(rand.Uint32()%334) + peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter) + + return err +} diff --git a/src/uapi.go b/src/uapi.go index 7ab3c4a..155f483 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -221,7 +221,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { return &IPCError{Code: ipcErrorInvalid} } } - signalSend(peer.signal.handshakeReset) + peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) dummy = false } @@ -265,7 +265,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { return err } peer.endpoint = endpoint - signalSend(peer.signal.handshakeReset) + peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) return nil }()