diff --git a/event.go b/event.go new file mode 100644 index 0000000..ccf57c2 --- /dev/null +++ b/event.go @@ -0,0 +1,43 @@ +package main + +import ( + "sync/atomic" + "time" +) + +type Event struct { + guard int32 + next time.Time + interval time.Duration + C chan struct{} +} + +func newEvent(interval time.Duration) *Event { + return &Event{ + guard: 0, + next: time.Now(), + interval: interval, + C: make(chan struct{}, 1), + } +} + +func (e *Event) Clear() { + select { + case <-e.C: + default: + } +} + +func (e *Event) Fire() { + if atomic.SwapInt32(&e.guard, 1) != 0 { + return + } + if now := time.Now(); now.After(e.next) { + select { + case e.C <- struct{}{}: + default: + } + e.next = now.Add(e.interval) + } + atomic.StoreInt32(&e.guard, 0) +} diff --git a/noise-protocol.go b/noise-protocol.go index b880ede..35e95ef 100644 --- a/noise-protocol.go +++ b/noise-protocol.go @@ -571,7 +571,7 @@ func (peer *Peer) NewKeyPair() *KeyPair { } else { kp.previous = kp.current kp.current = keyPair - peer.signal.newKeyPair.Send() + peer.event.newKeyPair.Fire() } } else { diff --git a/peer.go b/peer.go index 9703b58..0b947fd 100644 --- a/peer.go +++ b/peer.go @@ -15,6 +15,7 @@ import ( const ( PeerRoutineNumber = 4 + EventInterval = 10 * time.Millisecond ) type Peer struct { @@ -40,26 +41,23 @@ type Peer struct { nextKeepalive time.Time } + event struct { + dataSent *Event + dataReceived *Event + anyAuthenticatedPacketReceived *Event + anyAuthenticatedPacketTraversal *Event + handshakeCompleted *Event + handshakePushDeadline *Event + handshakeBegin *Event + ephemeralKeyCreated *Event + newKeyPair *Event + } + signal struct { - newKeyPair Signal // size 1, new key pair was generated - handshakeCompleted Signal // size 1, handshake completed - handshakeBegin Signal // size 1, begin new handshake begin - flushNonceQueue Signal // size 1, empty queued packets - messageSend Signal // size 1, message was send to peer - messageReceived Signal // size 1, authenticated message recv + flushNonceQueue chan struct{} // size 0, empty queued packets } timer struct { - - // state related to WireGuard timers - - keepalivePersistent Timer // set for persistent keep-alive - keepalivePassive Timer // set upon receiving messages - zeroAllKeys Timer // zero all key material - handshakeNew Timer // begin a new handshake (stale) - handshakeDeadline Timer // complete handshake timeout - handshakeTimeout Timer // current handshake message timeout - sendLastMinuteHandshake AtomicBool needAnotherKeepalive AtomicBool } @@ -113,12 +111,17 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { peer.device = device peer.isRunning.Set(false) - peer.timer.zeroAllKeys = NewTimer() - peer.timer.keepalivePersistent = NewTimer() - peer.timer.keepalivePassive = NewTimer() - peer.timer.handshakeNew = NewTimer() - peer.timer.handshakeDeadline = NewTimer() - peer.timer.handshakeTimeout = NewTimer() + // events + + peer.event.dataSent = newEvent(EventInterval) + peer.event.dataReceived = newEvent(EventInterval) + peer.event.anyAuthenticatedPacketReceived = newEvent(EventInterval) + peer.event.anyAuthenticatedPacketTraversal = newEvent(EventInterval) + peer.event.handshakeCompleted = newEvent(EventInterval) + peer.event.handshakePushDeadline = newEvent(EventInterval) + peer.event.handshakeBegin = newEvent(EventInterval) + peer.event.ephemeralKeyCreated = newEvent(EventInterval) + peer.event.newKeyPair = newEvent(EventInterval) // map public key @@ -200,7 +203,7 @@ func (peer *Peer) Start() { } device := peer.device - device.log.Debug.Println(peer.String() + ": Starting...") + device.log.Debug.Println(peer, ": Starting...") // sanity check : these should be 0 @@ -209,10 +212,7 @@ func (peer *Peer) Start() { // prepare queues and signals - peer.signal.newKeyPair = NewSignal() - peer.signal.handshakeBegin = NewSignal() - peer.signal.handshakeCompleted = NewSignal() - peer.signal.flushNonceQueue = NewSignal() + peer.signal.flushNonceQueue = make(chan struct{}) peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize) peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) @@ -247,7 +247,7 @@ func (peer *Peer) Stop() { } device := peer.device - device.log.Debug.Println(peer.String() + ": Stopping...") + device.log.Debug.Println(peer, ": Stopping...") // stop & wait for ongoing peer routines @@ -255,15 +255,6 @@ func (peer *Peer) Stop() { peer.routines.stop.Broadcast() peer.routines.stopping.Wait() - // stop timers - - peer.timer.keepalivePersistent.Stop() - peer.timer.keepalivePassive.Stop() - peer.timer.zeroAllKeys.Stop() - peer.timer.handshakeNew.Stop() - peer.timer.handshakeDeadline.Stop() - peer.timer.handshakeTimeout.Stop() - // close queues close(peer.queue.nonce) @@ -272,10 +263,8 @@ func (peer *Peer) Stop() { // close signals - peer.signal.newKeyPair.Close() - peer.signal.handshakeBegin.Close() - peer.signal.handshakeCompleted.Close() - peer.signal.flushNonceQueue.Close() + close(peer.signal.flushNonceQueue) + peer.signal.flushNonceQueue = nil // clear key pairs diff --git a/receive.go b/receive.go index 156ade5..1d8b718 100644 --- a/receive.go +++ b/receive.go @@ -212,6 +212,9 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { case MessageCookieReplyType: okay = len(packet) == MessageCookieReplySize + + default: + logDebug.Println("Received message with unknown type") } if okay { @@ -453,8 +456,8 @@ func (device *Device) RoutineHandshake() { // update timers - peer.TimerAnyAuthenticatedPacketTraversal() - peer.TimerAnyAuthenticatedPacketReceived() + peer.event.anyAuthenticatedPacketTraversal.Fire() + peer.event.anyAuthenticatedPacketReceived.Fire() // update endpoint @@ -462,7 +465,7 @@ func (device *Device) RoutineHandshake() { peer.endpoint = elem.endpoint peer.mutex.Unlock() - logDebug.Println(peer.String() + ": Received handshake initiation") + logDebug.Println(peer, ": Received handshake initiation") // create response @@ -475,7 +478,7 @@ func (device *Device) RoutineHandshake() { peer.TimerEphemeralKeyCreated() peer.NewKeyPair() - logDebug.Println(peer.String(), "Creating handshake response") + logDebug.Println(peer, ": Creating handshake response") writer := bytes.NewBuffer(temp[:0]) binary.Write(writer, binary.LittleEndian, response) @@ -486,9 +489,9 @@ func (device *Device) RoutineHandshake() { err = peer.SendBuffer(packet) if err == nil { - peer.TimerAnyAuthenticatedPacketTraversal() + peer.event.anyAuthenticatedPacketTraversal.Fire() } else { - logError.Println(peer.String(), "Failed to send handshake response", err) + logError.Println(peer, ": Failed to send handshake response", err) } case MessageResponseType: @@ -520,15 +523,15 @@ func (device *Device) RoutineHandshake() { peer.endpoint = elem.endpoint peer.mutex.Unlock() - logDebug.Println(peer.String() + ": Received handshake response") + logDebug.Println(peer, ": Received handshake response") peer.TimerEphemeralKeyCreated() // update timers - peer.TimerAnyAuthenticatedPacketTraversal() - peer.TimerAnyAuthenticatedPacketReceived() - peer.TimerHandshakeComplete() + peer.event.anyAuthenticatedPacketTraversal.Fire() + peer.event.anyAuthenticatedPacketReceived.Fire() + peer.event.handshakeCompleted.Fire() // derive key-pair @@ -547,10 +550,10 @@ func (peer *Peer) RoutineSequentialReceiver() { defer func() { peer.routines.stopping.Done() - logDebug.Println(peer.String() + ": Routine: sequential receiver - stopped") + logDebug.Println(peer, ": Routine: sequential receiver - stopped") }() - logDebug.Println(peer.String() + ": Routine: sequential receiver - started") + logDebug.Println(peer, ": Routine: sequential receiver - started") peer.routines.starting.Done() @@ -581,8 +584,8 @@ func (peer *Peer) RoutineSequentialReceiver() { continue } - peer.TimerAnyAuthenticatedPacketTraversal() - peer.TimerAnyAuthenticatedPacketReceived() + peer.event.anyAuthenticatedPacketTraversal.Fire() + peer.event.anyAuthenticatedPacketReceived.Fire() peer.KeepKeyFreshReceiving() // check if using new key-pair @@ -590,7 +593,7 @@ func (peer *Peer) RoutineSequentialReceiver() { kp := &peer.keyPairs kp.mutex.Lock() if kp.next == elem.keyPair { - peer.TimerHandshakeComplete() + peer.event.handshakeCompleted.Fire() if kp.previous != nil { device.DeleteKeyPair(kp.previous) } @@ -609,10 +612,10 @@ func (peer *Peer) RoutineSequentialReceiver() { // check for keep-alive if len(elem.packet) == 0 { - logDebug.Println("Received keep-alive from", peer.String()) + logDebug.Println(peer, ": Received keep-alive") continue } - peer.TimerDataReceived() + peer.event.dataReceived.Fire() // verify source and strip padding @@ -639,7 +642,7 @@ func (peer *Peer) RoutineSequentialReceiver() { if device.routing.table.LookupIPv4(src) != peer { logInfo.Println( "IPv4 packet with disallowed source address from", - peer.String(), + peer, ) continue } @@ -666,14 +669,14 @@ func (peer *Peer) RoutineSequentialReceiver() { src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] if device.routing.table.LookupIPv6(src) != peer { logInfo.Println( - "IPv6 packet with disallowed source address from", - peer.String(), + peer, + "sent packet with disallowed IPv6 source", ) continue } default: - logInfo.Println("Packet with invalid IP version from", peer.String()) + logInfo.Println("Packet with invalid IP version from", peer) continue } diff --git a/send.go b/send.go index 24c7f32..7423e3b 100644 --- a/send.go +++ b/send.go @@ -50,7 +50,7 @@ type QueueOutboundElement struct { peer *Peer // related peer } -func (peer *Peer) FlushNonceQueue() { +func (peer *Peer) flushNonceQueue() { elems := len(peer.queue.nonce) for i := 0; i < elems; i++ { select { @@ -180,7 +180,7 @@ func (device *Device) RoutineReadFromTUN() { // insert into nonce/pre-handshake queue if peer.isRunning.Get() { - peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) + peer.event.handshakePushDeadline.Fire() addToOutboundQueue(peer.queue.nonce, elem) elem = device.NewOutboundElement() } @@ -201,11 +201,11 @@ func (peer *Peer) RoutineNonce() { defer func() { peer.routines.stopping.Done() - logDebug.Println(peer.String() + ": Routine: nonce worker - stopped") + logDebug.Println(peer, ": Routine: nonce worker - stopped") }() peer.routines.starting.Done() - logDebug.Println(peer.String() + ": Routine: nonce worker - started") + logDebug.Println(peer, ": Routine: nonce worker - started") for { NextPacket: @@ -222,6 +222,9 @@ func (peer *Peer) RoutineNonce() { // wait for key pair for { + + peer.event.newKeyPair.Clear() + keyPair = peer.keyPairs.Current() if keyPair != nil && keyPair.sendNonce < RejectAfterMessages { if time.Now().Sub(keyPair.created) < RejectAfterTime { @@ -229,16 +232,14 @@ func (peer *Peer) RoutineNonce() { } } - peer.signal.handshakeBegin.Send() + peer.event.handshakeBegin.Fire() - logDebug.Println(peer.String() + ": Awaiting key-pair") + logDebug.Println(peer, ": Awaiting key-pair") select { - case <-peer.signal.newKeyPair.Wait(): - logDebug.Println(peer.String() + ": Obtained awaited key-pair") - case <-peer.signal.flushNonceQueue.Wait(): - logDebug.Println(peer.String() + ": Flushing nonce queue") - peer.FlushNonceQueue() + case <-peer.event.newKeyPair.C: + logDebug.Println(peer, ": Obtained awaited key-pair") + case <-peer.signal.flushNonceQueue: goto NextPacket case <-peer.routines.stop.Wait(): return @@ -357,10 +358,10 @@ func (peer *Peer) RoutineSequentialSender() { defer func() { peer.routines.stopping.Done() - logDebug.Println(peer.String() + ": Routine: sequential sender - stopped") + logDebug.Println(peer, ": Routine: sequential sender - stopped") }() - logDebug.Println(peer.String() + ": Routine: sequential sender - started") + logDebug.Println(peer, ": Routine: sequential sender - started") peer.routines.starting.Done() @@ -387,16 +388,16 @@ func (peer *Peer) RoutineSequentialSender() { err := peer.SendBuffer(elem.packet) device.PutMessageBuffer(elem.buffer) if err != nil { - logDebug.Println("Failed to send authenticated packet to peer", peer.String()) + logDebug.Println("Failed to send authenticated packet to peer", peer) continue } atomic.AddUint64(&peer.stats.txBytes, length) // update timers - peer.TimerAnyAuthenticatedPacketTraversal() + peer.event.anyAuthenticatedPacketTraversal.Fire() if len(elem.packet) != MessageKeepaliveSize { - peer.TimerDataSent() + peer.event.dataSent.Fire() } peer.KeepKeyFreshSending() } diff --git a/signal.go b/signal.go index 4d51bfa..606da52 100644 --- a/signal.go +++ b/signal.go @@ -5,6 +5,13 @@ package main +func signalSend(s chan<- struct{}) { + select { + case s <- struct{}{}: + default: + } +} + type Signal struct { enabled AtomicBool C chan struct{} diff --git a/timer.go b/timer.go deleted file mode 100644 index aeab5d9..0000000 --- a/timer.go +++ /dev/null @@ -1,70 +0,0 @@ -/* SPDX-License-Identifier: GPL-2.0 - * - * Copyright (C) 2017-2018 Jason A. Donenfeld . All Rights Reserved. - */ - -package main - -import ( - "sync" - "time" -) - -type Timer struct { - mutex sync.Mutex - pending bool - timer *time.Timer -} - -/* Starts the timer if not already pending - */ -func (t *Timer) Start(dur time.Duration) bool { - t.mutex.Lock() - defer t.mutex.Unlock() - - started := !t.pending - if started { - t.timer.Reset(dur) - } - return started -} - -func (t *Timer) Stop() { - t.mutex.Lock() - defer t.mutex.Unlock() - - t.timer.Stop() - select { - case <-t.timer.C: - default: - } - t.pending = false -} - -func (t *Timer) Pending() bool { - t.mutex.Lock() - defer t.mutex.Unlock() - - return t.pending -} - -func (t *Timer) Reset(dur time.Duration) { - t.mutex.Lock() - defer t.mutex.Unlock() - t.timer.Reset(dur) -} - -func (t *Timer) Wait() <-chan time.Time { - return t.timer.C -} - -func NewTimer() (t Timer) { - t.pending = false - t.timer = time.NewTimer(time.Hour) - t.timer.Stop() - select { - case <-t.timer.C: - default: - } - return -} diff --git a/timers.go b/timers.go index 835191f..08d0561 100644 --- a/timers.go +++ b/timers.go @@ -15,8 +15,6 @@ import ( /* NOTE: * Notion of validity - * - * */ /* Called when a new authenticated message has been send @@ -29,10 +27,10 @@ func (peer *Peer) KeepKeyFreshSending() { } nonce := atomic.LoadUint64(&kp.sendNonce) if nonce > RekeyAfterMessages { - peer.signal.handshakeBegin.Send() + peer.event.handshakeBegin.Fire() } if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime { - peer.signal.handshakeBegin.Send() + peer.event.handshakeBegin.Fire() } } @@ -56,7 +54,7 @@ func (peer *Peer) KeepKeyFreshReceiving() { if send { // do a last minute attempt at initiating a new handshake peer.timer.sendLastMinuteHandshake.Set(true) - peer.signal.handshakeBegin.Send() + peer.event.handshakeBegin.Fire() } } @@ -76,57 +74,13 @@ func (peer *Peer) SendKeepAlive() bool { } } -/* Event: - * Sent non-empty (authenticated) transport message - */ -func (peer *Peer) TimerDataSent() { - peer.timer.keepalivePassive.Stop() - peer.timer.handshakeNew.Start(NewHandshakeTime) -} - -/* Event: - * Received non-empty (authenticated) transport message - * - * Action: - * Set a timer to confirm the message using a keep-alive (if not already set) - */ -func (peer *Peer) TimerDataReceived() { - if !peer.timer.keepalivePassive.Start(KeepaliveTimeout) { - peer.timer.needAnotherKeepalive.Set(true) - } -} - -/* Event: - * Any (authenticated) packet received - */ -func (peer *Peer) TimerAnyAuthenticatedPacketReceived() { - peer.timer.handshakeNew.Stop() -} - -/* Event: - * Any authenticated packet send / received. - * - * Action: - * Push persistent keep-alive into the future - */ -func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() { - interval := peer.persistentKeepaliveInterval - if interval > 0 { - duration := time.Duration(interval) * time.Second - peer.timer.keepalivePersistent.Reset(duration) - } -} - /* Called after successfully completing a handshake. * i.e. after: * * - Valid handshake response * - First transport message under the "next" key */ -func (peer *Peer) TimerHandshakeComplete() { - peer.signal.handshakeCompleted.Send() - peer.device.log.Info.Println(peer.String() + ": New handshake completed") -} +// peer.device.log.Info.Println(peer, ": New handshake completed") /* Event: * An ephemeral key is generated @@ -141,17 +95,14 @@ func (peer *Peer) TimerHandshakeComplete() { * upon failure to complete a handshake */ func (peer *Peer) TimerEphemeralKeyCreated() { - peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3) + peer.event.ephemeralKeyCreated.Fire() + // peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3) } /* Sends a new handshake initiation message to the peer (endpoint) */ func (peer *Peer) sendNewHandshake() error { - // temporarily disable the handshake complete signal - - peer.signal.handshakeCompleted.Disable() - // create initiation message msg, err := peer.device.CreateMessageInitiation(peer) @@ -169,21 +120,15 @@ func (peer *Peer) sendNewHandshake() error { // send to endpoint - peer.TimerAnyAuthenticatedPacketTraversal() + peer.event.anyAuthenticatedPacketTraversal.Fire() - err = peer.SendBuffer(packet) - if err == nil { - peer.signal.handshakeCompleted.Enable() - } + return peer.SendBuffer(packet) +} - // set timeout - - jitter := time.Millisecond * time.Duration(rand.Uint32()%334) - - peer.timer.keepalivePassive.Stop() - peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter) - - return err +func newTimer() *time.Timer { + timer := time.NewTimer(time.Hour) + timer.Stop() + return timer } func (peer *Peer) RoutineTimerHandler() { @@ -194,24 +139,30 @@ func (peer *Peer) RoutineTimerHandler() { logDebug := device.log.Debug defer func() { - logDebug.Println(peer.String() + ": Routine: timer handler - stopped") + logDebug.Println(peer, ": Routine: timer handler - stopped") peer.routines.stopping.Done() }() - logDebug.Println(peer.String() + ": Routine: timer handler - started") + logDebug.Println(peer, ": Routine: timer handler - started") // reset all timers - peer.timer.keepalivePassive.Stop() - peer.timer.handshakeDeadline.Stop() - peer.timer.handshakeTimeout.Stop() - peer.timer.handshakeNew.Stop() - peer.timer.zeroAllKeys.Stop() + enableHandshake := true + + pendingHandshakeNew := false + pendingKeepalivePassive := false + + timerKeepalivePassive := newTimer() + timerHandshakeDeadline := newTimer() + timerHandshakeTimeout := newTimer() + timerHandshakeNew := newTimer() + timerZeroAllKeys := newTimer() + timerKeepalivePersistent := newTimer() interval := peer.persistentKeepaliveInterval if interval > 0 { duration := time.Duration(interval) * time.Second - peer.timer.keepalivePersistent.Reset(duration) + timerKeepalivePersistent.Reset(duration) } // signal synchronised setup complete @@ -228,34 +179,56 @@ func (peer *Peer) RoutineTimerHandler() { case <-peer.routines.stop.Wait(): return + /* events */ + + case <-peer.event.dataSent.C: + timerKeepalivePassive.Stop() + if !pendingHandshakeNew { + timerHandshakeNew.Reset(NewHandshakeTime) + } + + case <-peer.event.dataReceived.C: + if pendingKeepalivePassive { + peer.timer.needAnotherKeepalive.Set(true) // TODO: make local + } else { + timerKeepalivePassive.Reset(KeepaliveTimeout) + } + + case <-peer.event.anyAuthenticatedPacketTraversal.C: + interval := peer.persistentKeepaliveInterval + if interval > 0 { + duration := time.Duration(interval) * time.Second + timerKeepalivePersistent.Reset(duration) + } + /* timers */ // keep-alive - case <-peer.timer.keepalivePersistent.Wait(): + case <-timerKeepalivePersistent.C: interval := peer.persistentKeepaliveInterval if interval > 0 { - logDebug.Println(peer.String() + ": Send keep-alive (persistent)") - peer.timer.keepalivePassive.Stop() + logDebug.Println(peer, ": Send keep-alive (persistent)") + timerKeepalivePassive.Stop() peer.SendKeepAlive() } - case <-peer.timer.keepalivePassive.Wait(): + case <-timerKeepalivePassive.C: - logDebug.Println(peer.String() + ": Send keep-alive (passive)") + logDebug.Println(peer, ": Send keep-alive (passive)") peer.SendKeepAlive() if peer.timer.needAnotherKeepalive.Swap(false) { - peer.timer.keepalivePassive.Reset(KeepaliveTimeout) + timerKeepalivePassive.Reset(KeepaliveTimeout) } // clear key material timer - case <-peer.timer.zeroAllKeys.Wait(): + case <-timerZeroAllKeys.C: - logDebug.Println(peer.String() + ": Clear all key-material (timer event)") + logDebug.Println(peer, ": Clear all key-material (timer event)") hs := &peer.handshake hs.mutex.Lock() @@ -287,11 +260,11 @@ func (peer *Peer) RoutineTimerHandler() { // handshake timers - case <-peer.timer.handshakeNew.Wait(): - logInfo.Println(peer.String() + ": Retrying handshake (timer event)") - peer.signal.handshakeBegin.Send() + case <-timerHandshakeTimeout.C: - case <-peer.timer.handshakeTimeout.Wait(): + // allow new handshake to be send + + enableHandshake = true // clear source (in case this is causing problems) @@ -305,52 +278,84 @@ func (peer *Peer) RoutineTimerHandler() { err := peer.sendNewHandshake() + // set timeout + + jitter := time.Millisecond * time.Duration(rand.Uint32()%334) + timerKeepalivePassive.Stop() + timerHandshakeTimeout.Reset(RekeyTimeout + jitter) + if err != nil { - logInfo.Println(peer.String()+": Failed to send handshake initiation", err) + logInfo.Println(peer, ": Failed to send handshake initiation", err) } else { - logDebug.Println(peer.String() + ": Send handshake initiation (subsequent)") + logDebug.Println(peer, ": Send handshake initiation (subsequent)") } - case <-peer.timer.handshakeDeadline.Wait(): + // disable further handshakes + + peer.event.handshakeBegin.Clear() + enableHandshake = false + + case <-timerHandshakeDeadline.C: // clear all queued packets and stop keep-alive - logInfo.Println(peer.String() + ": Handshake negotiation timed-out") + logInfo.Println(peer, ": Handshake negotiation timed-out") - peer.signal.flushNonceQueue.Send() - peer.timer.keepalivePersistent.Stop() - peer.signal.handshakeBegin.Enable() + peer.flushNonceQueue() + signalSend(peer.signal.flushNonceQueue) + timerKeepalivePersistent.Stop() - /* signals */ + // disable further handshakes - case <-peer.signal.handshakeBegin.Wait(): + peer.event.handshakeBegin.Clear() + enableHandshake = true - peer.signal.handshakeBegin.Disable() + case <-peer.event.handshakeBegin.C: + + if !enableHandshake { + continue + } + + logDebug.Println(peer, ": Event, Handshake Begin") err := peer.sendNewHandshake() + // set timeout + + jitter := time.Millisecond * time.Duration(rand.Uint32()%334) + timerKeepalivePassive.Stop() + timerHandshakeTimeout.Reset(RekeyTimeout + jitter) + if err != nil { - logInfo.Println(peer.String()+": Failed to send handshake initiation", err) + logInfo.Println(peer, ": Failed to send handshake initiation", err) } else { - logDebug.Println(peer.String() + ": Send handshake initiation (initial)") + logDebug.Println(peer, ": Send handshake initiation (initial)") } - peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) + timerHandshakeDeadline.Reset(RekeyAttemptTime) - case <-peer.signal.handshakeCompleted.Wait(): + // disable further handshakes - logInfo.Println(peer.String() + ": Handshake completed") + peer.event.handshakeBegin.Clear() + enableHandshake = false + + case <-peer.event.handshakeCompleted.C: + + logInfo.Println(peer, ": Handshake completed") atomic.StoreInt64( &peer.stats.lastHandshakeNano, time.Now().UnixNano(), ) - peer.timer.handshakeTimeout.Stop() - peer.timer.handshakeDeadline.Stop() - peer.signal.handshakeBegin.Enable() - + timerHandshakeTimeout.Stop() + timerHandshakeDeadline.Stop() peer.timer.sendLastMinuteHandshake.Set(false) + + // allow further handshakes + + peer.event.handshakeBegin.Clear() + enableHandshake = true } } } diff --git a/uapi.go b/uapi.go index a7ef662..c87a536 100644 --- a/uapi.go +++ b/uapi.go @@ -253,12 +253,10 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { logError.Println("Failed to create new peer:", err) return &IPCError{Code: ipcErrorInvalid} } - logDebug.Println("UAPI: Created new peer:", peer.String()) + logDebug.Println("UAPI: Created new peer:", peer) } - peer.mutex.Lock() - peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) - peer.mutex.Unlock() + peer.event.handshakePushDeadline.Fire() case "remove": @@ -269,7 +267,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { return &IPCError{Code: ipcErrorInvalid} } if !dummy { - logDebug.Println("UAPI: Removing peer:", peer.String()) + logDebug.Println("UAPI: Removing peer:", peer) device.RemovePeer(peer.handshake.remoteStatic) } peer = &Peer{} @@ -279,7 +277,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { // update PSK - logDebug.Println("UAPI: Updating pre-shared key for peer:", peer.String()) + logDebug.Println("UAPI: Updating pre-shared key for peer:", peer) peer.handshake.mutex.Lock() err := peer.handshake.presharedKey.FromHex(value) @@ -294,7 +292,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { // set endpoint destination - logDebug.Println("UAPI: Updating endpoint for peer:", peer.String()) + logDebug.Println("UAPI: Updating endpoint for peer:", peer) err := func() error { peer.mutex.Lock() @@ -304,7 +302,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { return err } peer.endpoint = endpoint - peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) + peer.event.handshakePushDeadline.Fire() return nil }() @@ -317,7 +315,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { // update keep-alive interval - logDebug.Println("UAPI: Updating persistent_keepalive_interval for peer:", peer.String()) + logDebug.Println("UAPI: Updating persistent_keepalive_interval for peer:", peer) secs, err := strconv.ParseUint(value, 10, 16) if err != nil { @@ -342,7 +340,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { case "replace_allowed_ips": - logDebug.Println("UAPI: Removing all allowed IPs for peer:", peer.String()) + logDebug.Println("UAPI: Removing all allowed IPs for peer:", peer) if value != "true" { logError.Println("Failed to set replace_allowed_ips, invalid value:", value) @@ -359,7 +357,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { case "allowed_ip": - logDebug.Println("UAPI: Adding allowed_ip to peer:", peer.String()) + logDebug.Println("UAPI: Adding allowed_ip to peer:", peer) _, network, err := net.ParseCIDR(value) if err != nil {