Merge branch 'master' of ssh://git.zx2c4.com/wireguard-go

This commit is contained in:
Jason A. Donenfeld 2018-05-05 04:20:16 +02:00
commit beab52258a
9 changed files with 243 additions and 267 deletions

43
event.go Normal file
View file

@ -0,0 +1,43 @@
package main
import (
"sync/atomic"
"time"
)
type Event struct {
guard int32
next time.Time
interval time.Duration
C chan struct{}
}
func newEvent(interval time.Duration) *Event {
return &Event{
guard: 0,
next: time.Now(),
interval: interval,
C: make(chan struct{}, 1),
}
}
func (e *Event) Clear() {
select {
case <-e.C:
default:
}
}
func (e *Event) Fire() {
if atomic.SwapInt32(&e.guard, 1) != 0 {
return
}
if now := time.Now(); now.After(e.next) {
select {
case e.C <- struct{}{}:
default:
}
e.next = now.Add(e.interval)
}
atomic.StoreInt32(&e.guard, 0)
}

View file

@ -571,7 +571,7 @@ func (peer *Peer) NewKeyPair() *KeyPair {
} else { } else {
kp.previous = kp.current kp.previous = kp.current
kp.current = keyPair kp.current = keyPair
peer.signal.newKeyPair.Send() peer.event.newKeyPair.Fire()
} }
} else { } else {

71
peer.go
View file

@ -15,6 +15,7 @@ import (
const ( const (
PeerRoutineNumber = 4 PeerRoutineNumber = 4
EventInterval = 10 * time.Millisecond
) )
type Peer struct { type Peer struct {
@ -40,26 +41,23 @@ type Peer struct {
nextKeepalive time.Time nextKeepalive time.Time
} }
event struct {
dataSent *Event
dataReceived *Event
anyAuthenticatedPacketReceived *Event
anyAuthenticatedPacketTraversal *Event
handshakeCompleted *Event
handshakePushDeadline *Event
handshakeBegin *Event
ephemeralKeyCreated *Event
newKeyPair *Event
}
signal struct { signal struct {
newKeyPair Signal // size 1, new key pair was generated flushNonceQueue chan struct{} // size 0, empty queued packets
handshakeCompleted Signal // size 1, handshake completed
handshakeBegin Signal // size 1, begin new handshake begin
flushNonceQueue Signal // size 1, empty queued packets
messageSend Signal // size 1, message was send to peer
messageReceived Signal // size 1, authenticated message recv
} }
timer struct { timer struct {
// state related to WireGuard timers
keepalivePersistent Timer // set for persistent keep-alive
keepalivePassive Timer // set upon receiving messages
zeroAllKeys Timer // zero all key material
handshakeNew Timer // begin a new handshake (stale)
handshakeDeadline Timer // complete handshake timeout
handshakeTimeout Timer // current handshake message timeout
sendLastMinuteHandshake AtomicBool sendLastMinuteHandshake AtomicBool
needAnotherKeepalive AtomicBool needAnotherKeepalive AtomicBool
} }
@ -113,12 +111,17 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.device = device peer.device = device
peer.isRunning.Set(false) peer.isRunning.Set(false)
peer.timer.zeroAllKeys = NewTimer() // events
peer.timer.keepalivePersistent = NewTimer()
peer.timer.keepalivePassive = NewTimer() peer.event.dataSent = newEvent(EventInterval)
peer.timer.handshakeNew = NewTimer() peer.event.dataReceived = newEvent(EventInterval)
peer.timer.handshakeDeadline = NewTimer() peer.event.anyAuthenticatedPacketReceived = newEvent(EventInterval)
peer.timer.handshakeTimeout = NewTimer() peer.event.anyAuthenticatedPacketTraversal = newEvent(EventInterval)
peer.event.handshakeCompleted = newEvent(EventInterval)
peer.event.handshakePushDeadline = newEvent(EventInterval)
peer.event.handshakeBegin = newEvent(EventInterval)
peer.event.ephemeralKeyCreated = newEvent(EventInterval)
peer.event.newKeyPair = newEvent(EventInterval)
// map public key // map public key
@ -200,7 +203,7 @@ func (peer *Peer) Start() {
} }
device := peer.device device := peer.device
device.log.Debug.Println(peer.String() + ": Starting...") device.log.Debug.Println(peer, ": Starting...")
// sanity check : these should be 0 // sanity check : these should be 0
@ -209,10 +212,7 @@ func (peer *Peer) Start() {
// prepare queues and signals // prepare queues and signals
peer.signal.newKeyPair = NewSignal() peer.signal.flushNonceQueue = make(chan struct{})
peer.signal.handshakeBegin = NewSignal()
peer.signal.handshakeCompleted = NewSignal()
peer.signal.flushNonceQueue = NewSignal()
peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize) peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
@ -247,7 +247,7 @@ func (peer *Peer) Stop() {
} }
device := peer.device device := peer.device
device.log.Debug.Println(peer.String() + ": Stopping...") device.log.Debug.Println(peer, ": Stopping...")
// stop & wait for ongoing peer routines // stop & wait for ongoing peer routines
@ -255,15 +255,6 @@ func (peer *Peer) Stop() {
peer.routines.stop.Broadcast() peer.routines.stop.Broadcast()
peer.routines.stopping.Wait() peer.routines.stopping.Wait()
// stop timers
peer.timer.keepalivePersistent.Stop()
peer.timer.keepalivePassive.Stop()
peer.timer.zeroAllKeys.Stop()
peer.timer.handshakeNew.Stop()
peer.timer.handshakeDeadline.Stop()
peer.timer.handshakeTimeout.Stop()
// close queues // close queues
close(peer.queue.nonce) close(peer.queue.nonce)
@ -272,10 +263,8 @@ func (peer *Peer) Stop() {
// close signals // close signals
peer.signal.newKeyPair.Close() close(peer.signal.flushNonceQueue)
peer.signal.handshakeBegin.Close() peer.signal.flushNonceQueue = nil
peer.signal.handshakeCompleted.Close()
peer.signal.flushNonceQueue.Close()
// clear key pairs // clear key pairs

View file

@ -212,6 +212,9 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
case MessageCookieReplyType: case MessageCookieReplyType:
okay = len(packet) == MessageCookieReplySize okay = len(packet) == MessageCookieReplySize
default:
logDebug.Println("Received message with unknown type")
} }
if okay { if okay {
@ -453,8 +456,8 @@ func (device *Device) RoutineHandshake() {
// update timers // update timers
peer.TimerAnyAuthenticatedPacketTraversal() peer.event.anyAuthenticatedPacketTraversal.Fire()
peer.TimerAnyAuthenticatedPacketReceived() peer.event.anyAuthenticatedPacketReceived.Fire()
// update endpoint // update endpoint
@ -462,7 +465,7 @@ func (device *Device) RoutineHandshake() {
peer.endpoint = elem.endpoint peer.endpoint = elem.endpoint
peer.mutex.Unlock() peer.mutex.Unlock()
logDebug.Println(peer.String() + ": Received handshake initiation") logDebug.Println(peer, ": Received handshake initiation")
// create response // create response
@ -475,7 +478,7 @@ func (device *Device) RoutineHandshake() {
peer.TimerEphemeralKeyCreated() peer.TimerEphemeralKeyCreated()
peer.NewKeyPair() peer.NewKeyPair()
logDebug.Println(peer.String(), "Creating handshake response") logDebug.Println(peer, ": Creating handshake response")
writer := bytes.NewBuffer(temp[:0]) writer := bytes.NewBuffer(temp[:0])
binary.Write(writer, binary.LittleEndian, response) binary.Write(writer, binary.LittleEndian, response)
@ -486,9 +489,9 @@ func (device *Device) RoutineHandshake() {
err = peer.SendBuffer(packet) err = peer.SendBuffer(packet)
if err == nil { if err == nil {
peer.TimerAnyAuthenticatedPacketTraversal() peer.event.anyAuthenticatedPacketTraversal.Fire()
} else { } else {
logError.Println(peer.String(), "Failed to send handshake response", err) logError.Println(peer, ": Failed to send handshake response", err)
} }
case MessageResponseType: case MessageResponseType:
@ -520,15 +523,15 @@ func (device *Device) RoutineHandshake() {
peer.endpoint = elem.endpoint peer.endpoint = elem.endpoint
peer.mutex.Unlock() peer.mutex.Unlock()
logDebug.Println(peer.String() + ": Received handshake response") logDebug.Println(peer, ": Received handshake response")
peer.TimerEphemeralKeyCreated() peer.TimerEphemeralKeyCreated()
// update timers // update timers
peer.TimerAnyAuthenticatedPacketTraversal() peer.event.anyAuthenticatedPacketTraversal.Fire()
peer.TimerAnyAuthenticatedPacketReceived() peer.event.anyAuthenticatedPacketReceived.Fire()
peer.TimerHandshakeComplete() peer.event.handshakeCompleted.Fire()
// derive key-pair // derive key-pair
@ -547,10 +550,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
defer func() { defer func() {
peer.routines.stopping.Done() peer.routines.stopping.Done()
logDebug.Println(peer.String() + ": Routine: sequential receiver - stopped") logDebug.Println(peer, ": Routine: sequential receiver - stopped")
}() }()
logDebug.Println(peer.String() + ": Routine: sequential receiver - started") logDebug.Println(peer, ": Routine: sequential receiver - started")
peer.routines.starting.Done() peer.routines.starting.Done()
@ -581,8 +584,8 @@ func (peer *Peer) RoutineSequentialReceiver() {
continue continue
} }
peer.TimerAnyAuthenticatedPacketTraversal() peer.event.anyAuthenticatedPacketTraversal.Fire()
peer.TimerAnyAuthenticatedPacketReceived() peer.event.anyAuthenticatedPacketReceived.Fire()
peer.KeepKeyFreshReceiving() peer.KeepKeyFreshReceiving()
// check if using new key-pair // check if using new key-pair
@ -590,7 +593,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
kp := &peer.keyPairs kp := &peer.keyPairs
kp.mutex.Lock() kp.mutex.Lock()
if kp.next == elem.keyPair { if kp.next == elem.keyPair {
peer.TimerHandshakeComplete() peer.event.handshakeCompleted.Fire()
if kp.previous != nil { if kp.previous != nil {
device.DeleteKeyPair(kp.previous) device.DeleteKeyPair(kp.previous)
} }
@ -609,10 +612,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
// check for keep-alive // check for keep-alive
if len(elem.packet) == 0 { if len(elem.packet) == 0 {
logDebug.Println("Received keep-alive from", peer.String()) logDebug.Println(peer, ": Received keep-alive")
continue continue
} }
peer.TimerDataReceived() peer.event.dataReceived.Fire()
// verify source and strip padding // verify source and strip padding
@ -639,7 +642,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
if device.routing.table.LookupIPv4(src) != peer { if device.routing.table.LookupIPv4(src) != peer {
logInfo.Println( logInfo.Println(
"IPv4 packet with disallowed source address from", "IPv4 packet with disallowed source address from",
peer.String(), peer,
) )
continue continue
} }
@ -666,14 +669,14 @@ func (peer *Peer) RoutineSequentialReceiver() {
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.routing.table.LookupIPv6(src) != peer { if device.routing.table.LookupIPv6(src) != peer {
logInfo.Println( logInfo.Println(
"IPv6 packet with disallowed source address from", peer,
peer.String(), "sent packet with disallowed IPv6 source",
) )
continue continue
} }
default: default:
logInfo.Println("Packet with invalid IP version from", peer.String()) logInfo.Println("Packet with invalid IP version from", peer)
continue continue
} }

33
send.go
View file

@ -50,7 +50,7 @@ type QueueOutboundElement struct {
peer *Peer // related peer peer *Peer // related peer
} }
func (peer *Peer) FlushNonceQueue() { func (peer *Peer) flushNonceQueue() {
elems := len(peer.queue.nonce) elems := len(peer.queue.nonce)
for i := 0; i < elems; i++ { for i := 0; i < elems; i++ {
select { select {
@ -180,7 +180,7 @@ func (device *Device) RoutineReadFromTUN() {
// insert into nonce/pre-handshake queue // insert into nonce/pre-handshake queue
if peer.isRunning.Get() { if peer.isRunning.Get() {
peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) peer.event.handshakePushDeadline.Fire()
addToOutboundQueue(peer.queue.nonce, elem) addToOutboundQueue(peer.queue.nonce, elem)
elem = device.NewOutboundElement() elem = device.NewOutboundElement()
} }
@ -201,11 +201,11 @@ func (peer *Peer) RoutineNonce() {
defer func() { defer func() {
peer.routines.stopping.Done() peer.routines.stopping.Done()
logDebug.Println(peer.String() + ": Routine: nonce worker - stopped") logDebug.Println(peer, ": Routine: nonce worker - stopped")
}() }()
peer.routines.starting.Done() peer.routines.starting.Done()
logDebug.Println(peer.String() + ": Routine: nonce worker - started") logDebug.Println(peer, ": Routine: nonce worker - started")
for { for {
NextPacket: NextPacket:
@ -222,6 +222,9 @@ func (peer *Peer) RoutineNonce() {
// wait for key pair // wait for key pair
for { for {
peer.event.newKeyPair.Clear()
keyPair = peer.keyPairs.Current() keyPair = peer.keyPairs.Current()
if keyPair != nil && keyPair.sendNonce < RejectAfterMessages { if keyPair != nil && keyPair.sendNonce < RejectAfterMessages {
if time.Now().Sub(keyPair.created) < RejectAfterTime { if time.Now().Sub(keyPair.created) < RejectAfterTime {
@ -229,16 +232,14 @@ func (peer *Peer) RoutineNonce() {
} }
} }
peer.signal.handshakeBegin.Send() peer.event.handshakeBegin.Fire()
logDebug.Println(peer.String() + ": Awaiting key-pair") logDebug.Println(peer, ": Awaiting key-pair")
select { select {
case <-peer.signal.newKeyPair.Wait(): case <-peer.event.newKeyPair.C:
logDebug.Println(peer.String() + ": Obtained awaited key-pair") logDebug.Println(peer, ": Obtained awaited key-pair")
case <-peer.signal.flushNonceQueue.Wait(): case <-peer.signal.flushNonceQueue:
logDebug.Println(peer.String() + ": Flushing nonce queue")
peer.FlushNonceQueue()
goto NextPacket goto NextPacket
case <-peer.routines.stop.Wait(): case <-peer.routines.stop.Wait():
return return
@ -357,10 +358,10 @@ func (peer *Peer) RoutineSequentialSender() {
defer func() { defer func() {
peer.routines.stopping.Done() peer.routines.stopping.Done()
logDebug.Println(peer.String() + ": Routine: sequential sender - stopped") logDebug.Println(peer, ": Routine: sequential sender - stopped")
}() }()
logDebug.Println(peer.String() + ": Routine: sequential sender - started") logDebug.Println(peer, ": Routine: sequential sender - started")
peer.routines.starting.Done() peer.routines.starting.Done()
@ -387,16 +388,16 @@ func (peer *Peer) RoutineSequentialSender() {
err := peer.SendBuffer(elem.packet) err := peer.SendBuffer(elem.packet)
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
if err != nil { if err != nil {
logDebug.Println("Failed to send authenticated packet to peer", peer.String()) logDebug.Println("Failed to send authenticated packet to peer", peer)
continue continue
} }
atomic.AddUint64(&peer.stats.txBytes, length) atomic.AddUint64(&peer.stats.txBytes, length)
// update timers // update timers
peer.TimerAnyAuthenticatedPacketTraversal() peer.event.anyAuthenticatedPacketTraversal.Fire()
if len(elem.packet) != MessageKeepaliveSize { if len(elem.packet) != MessageKeepaliveSize {
peer.TimerDataSent() peer.event.dataSent.Fire()
} }
peer.KeepKeyFreshSending() peer.KeepKeyFreshSending()
} }

View file

@ -5,6 +5,13 @@
package main package main
func signalSend(s chan<- struct{}) {
select {
case s <- struct{}{}:
default:
}
}
type Signal struct { type Signal struct {
enabled AtomicBool enabled AtomicBool
C chan struct{} C chan struct{}

View file

@ -1,70 +0,0 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
*/
package main
import (
"sync"
"time"
)
type Timer struct {
mutex sync.Mutex
pending bool
timer *time.Timer
}
/* Starts the timer if not already pending
*/
func (t *Timer) Start(dur time.Duration) bool {
t.mutex.Lock()
defer t.mutex.Unlock()
started := !t.pending
if started {
t.timer.Reset(dur)
}
return started
}
func (t *Timer) Stop() {
t.mutex.Lock()
defer t.mutex.Unlock()
t.timer.Stop()
select {
case <-t.timer.C:
default:
}
t.pending = false
}
func (t *Timer) Pending() bool {
t.mutex.Lock()
defer t.mutex.Unlock()
return t.pending
}
func (t *Timer) Reset(dur time.Duration) {
t.mutex.Lock()
defer t.mutex.Unlock()
t.timer.Reset(dur)
}
func (t *Timer) Wait() <-chan time.Time {
return t.timer.C
}
func NewTimer() (t Timer) {
t.pending = false
t.timer = time.NewTimer(time.Hour)
t.timer.Stop()
select {
case <-t.timer.C:
default:
}
return
}

219
timers.go
View file

@ -15,8 +15,6 @@ import (
/* NOTE: /* NOTE:
* Notion of validity * Notion of validity
*
*
*/ */
/* Called when a new authenticated message has been send /* Called when a new authenticated message has been send
@ -29,10 +27,10 @@ func (peer *Peer) KeepKeyFreshSending() {
} }
nonce := atomic.LoadUint64(&kp.sendNonce) nonce := atomic.LoadUint64(&kp.sendNonce)
if nonce > RekeyAfterMessages { if nonce > RekeyAfterMessages {
peer.signal.handshakeBegin.Send() peer.event.handshakeBegin.Fire()
} }
if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime { if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime {
peer.signal.handshakeBegin.Send() peer.event.handshakeBegin.Fire()
} }
} }
@ -56,7 +54,7 @@ func (peer *Peer) KeepKeyFreshReceiving() {
if send { if send {
// do a last minute attempt at initiating a new handshake // do a last minute attempt at initiating a new handshake
peer.timer.sendLastMinuteHandshake.Set(true) peer.timer.sendLastMinuteHandshake.Set(true)
peer.signal.handshakeBegin.Send() peer.event.handshakeBegin.Fire()
} }
} }
@ -76,57 +74,13 @@ func (peer *Peer) SendKeepAlive() bool {
} }
} }
/* Event:
* Sent non-empty (authenticated) transport message
*/
func (peer *Peer) TimerDataSent() {
peer.timer.keepalivePassive.Stop()
peer.timer.handshakeNew.Start(NewHandshakeTime)
}
/* Event:
* Received non-empty (authenticated) transport message
*
* Action:
* Set a timer to confirm the message using a keep-alive (if not already set)
*/
func (peer *Peer) TimerDataReceived() {
if !peer.timer.keepalivePassive.Start(KeepaliveTimeout) {
peer.timer.needAnotherKeepalive.Set(true)
}
}
/* Event:
* Any (authenticated) packet received
*/
func (peer *Peer) TimerAnyAuthenticatedPacketReceived() {
peer.timer.handshakeNew.Stop()
}
/* Event:
* Any authenticated packet send / received.
*
* Action:
* Push persistent keep-alive into the future
*/
func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {
interval := peer.persistentKeepaliveInterval
if interval > 0 {
duration := time.Duration(interval) * time.Second
peer.timer.keepalivePersistent.Reset(duration)
}
}
/* Called after successfully completing a handshake. /* Called after successfully completing a handshake.
* i.e. after: * i.e. after:
* *
* - Valid handshake response * - Valid handshake response
* - First transport message under the "next" key * - First transport message under the "next" key
*/ */
func (peer *Peer) TimerHandshakeComplete() { // peer.device.log.Info.Println(peer, ": New handshake completed")
peer.signal.handshakeCompleted.Send()
peer.device.log.Info.Println(peer.String() + ": New handshake completed")
}
/* Event: /* Event:
* An ephemeral key is generated * An ephemeral key is generated
@ -141,17 +95,14 @@ func (peer *Peer) TimerHandshakeComplete() {
* upon failure to complete a handshake * upon failure to complete a handshake
*/ */
func (peer *Peer) TimerEphemeralKeyCreated() { func (peer *Peer) TimerEphemeralKeyCreated() {
peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3) peer.event.ephemeralKeyCreated.Fire()
// peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
} }
/* Sends a new handshake initiation message to the peer (endpoint) /* Sends a new handshake initiation message to the peer (endpoint)
*/ */
func (peer *Peer) sendNewHandshake() error { func (peer *Peer) sendNewHandshake() error {
// temporarily disable the handshake complete signal
peer.signal.handshakeCompleted.Disable()
// create initiation message // create initiation message
msg, err := peer.device.CreateMessageInitiation(peer) msg, err := peer.device.CreateMessageInitiation(peer)
@ -169,21 +120,15 @@ func (peer *Peer) sendNewHandshake() error {
// send to endpoint // send to endpoint
peer.TimerAnyAuthenticatedPacketTraversal() peer.event.anyAuthenticatedPacketTraversal.Fire()
err = peer.SendBuffer(packet) return peer.SendBuffer(packet)
if err == nil { }
peer.signal.handshakeCompleted.Enable()
}
// set timeout func newTimer() *time.Timer {
timer := time.NewTimer(time.Hour)
jitter := time.Millisecond * time.Duration(rand.Uint32()%334) timer.Stop()
return timer
peer.timer.keepalivePassive.Stop()
peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter)
return err
} }
func (peer *Peer) RoutineTimerHandler() { func (peer *Peer) RoutineTimerHandler() {
@ -194,24 +139,30 @@ func (peer *Peer) RoutineTimerHandler() {
logDebug := device.log.Debug logDebug := device.log.Debug
defer func() { defer func() {
logDebug.Println(peer.String() + ": Routine: timer handler - stopped") logDebug.Println(peer, ": Routine: timer handler - stopped")
peer.routines.stopping.Done() peer.routines.stopping.Done()
}() }()
logDebug.Println(peer.String() + ": Routine: timer handler - started") logDebug.Println(peer, ": Routine: timer handler - started")
// reset all timers // reset all timers
peer.timer.keepalivePassive.Stop() enableHandshake := true
peer.timer.handshakeDeadline.Stop()
peer.timer.handshakeTimeout.Stop() pendingHandshakeNew := false
peer.timer.handshakeNew.Stop() pendingKeepalivePassive := false
peer.timer.zeroAllKeys.Stop()
timerKeepalivePassive := newTimer()
timerHandshakeDeadline := newTimer()
timerHandshakeTimeout := newTimer()
timerHandshakeNew := newTimer()
timerZeroAllKeys := newTimer()
timerKeepalivePersistent := newTimer()
interval := peer.persistentKeepaliveInterval interval := peer.persistentKeepaliveInterval
if interval > 0 { if interval > 0 {
duration := time.Duration(interval) * time.Second duration := time.Duration(interval) * time.Second
peer.timer.keepalivePersistent.Reset(duration) timerKeepalivePersistent.Reset(duration)
} }
// signal synchronised setup complete // signal synchronised setup complete
@ -228,34 +179,56 @@ func (peer *Peer) RoutineTimerHandler() {
case <-peer.routines.stop.Wait(): case <-peer.routines.stop.Wait():
return return
/* events */
case <-peer.event.dataSent.C:
timerKeepalivePassive.Stop()
if !pendingHandshakeNew {
timerHandshakeNew.Reset(NewHandshakeTime)
}
case <-peer.event.dataReceived.C:
if pendingKeepalivePassive {
peer.timer.needAnotherKeepalive.Set(true) // TODO: make local
} else {
timerKeepalivePassive.Reset(KeepaliveTimeout)
}
case <-peer.event.anyAuthenticatedPacketTraversal.C:
interval := peer.persistentKeepaliveInterval
if interval > 0 {
duration := time.Duration(interval) * time.Second
timerKeepalivePersistent.Reset(duration)
}
/* timers */ /* timers */
// keep-alive // keep-alive
case <-peer.timer.keepalivePersistent.Wait(): case <-timerKeepalivePersistent.C:
interval := peer.persistentKeepaliveInterval interval := peer.persistentKeepaliveInterval
if interval > 0 { if interval > 0 {
logDebug.Println(peer.String() + ": Send keep-alive (persistent)") logDebug.Println(peer, ": Send keep-alive (persistent)")
peer.timer.keepalivePassive.Stop() timerKeepalivePassive.Stop()
peer.SendKeepAlive() peer.SendKeepAlive()
} }
case <-peer.timer.keepalivePassive.Wait(): case <-timerKeepalivePassive.C:
logDebug.Println(peer.String() + ": Send keep-alive (passive)") logDebug.Println(peer, ": Send keep-alive (passive)")
peer.SendKeepAlive() peer.SendKeepAlive()
if peer.timer.needAnotherKeepalive.Swap(false) { if peer.timer.needAnotherKeepalive.Swap(false) {
peer.timer.keepalivePassive.Reset(KeepaliveTimeout) timerKeepalivePassive.Reset(KeepaliveTimeout)
} }
// clear key material timer // clear key material timer
case <-peer.timer.zeroAllKeys.Wait(): case <-timerZeroAllKeys.C:
logDebug.Println(peer.String() + ": Clear all key-material (timer event)") logDebug.Println(peer, ": Clear all key-material (timer event)")
hs := &peer.handshake hs := &peer.handshake
hs.mutex.Lock() hs.mutex.Lock()
@ -287,11 +260,11 @@ func (peer *Peer) RoutineTimerHandler() {
// handshake timers // handshake timers
case <-peer.timer.handshakeNew.Wait(): case <-timerHandshakeTimeout.C:
logInfo.Println(peer.String() + ": Retrying handshake (timer event)")
peer.signal.handshakeBegin.Send()
case <-peer.timer.handshakeTimeout.Wait(): // allow new handshake to be send
enableHandshake = true
// clear source (in case this is causing problems) // clear source (in case this is causing problems)
@ -305,52 +278,84 @@ func (peer *Peer) RoutineTimerHandler() {
err := peer.sendNewHandshake() err := peer.sendNewHandshake()
// set timeout
jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
timerKeepalivePassive.Stop()
timerHandshakeTimeout.Reset(RekeyTimeout + jitter)
if err != nil { if err != nil {
logInfo.Println(peer.String()+": Failed to send handshake initiation", err) logInfo.Println(peer, ": Failed to send handshake initiation", err)
} else { } else {
logDebug.Println(peer.String() + ": Send handshake initiation (subsequent)") logDebug.Println(peer, ": Send handshake initiation (subsequent)")
} }
case <-peer.timer.handshakeDeadline.Wait(): // disable further handshakes
peer.event.handshakeBegin.Clear()
enableHandshake = false
case <-timerHandshakeDeadline.C:
// clear all queued packets and stop keep-alive // clear all queued packets and stop keep-alive
logInfo.Println(peer.String() + ": Handshake negotiation timed-out") logInfo.Println(peer, ": Handshake negotiation timed-out")
peer.signal.flushNonceQueue.Send() peer.flushNonceQueue()
peer.timer.keepalivePersistent.Stop() signalSend(peer.signal.flushNonceQueue)
peer.signal.handshakeBegin.Enable() timerKeepalivePersistent.Stop()
/* signals */ // disable further handshakes
case <-peer.signal.handshakeBegin.Wait(): peer.event.handshakeBegin.Clear()
enableHandshake = true
peer.signal.handshakeBegin.Disable() case <-peer.event.handshakeBegin.C:
if !enableHandshake {
continue
}
logDebug.Println(peer, ": Event, Handshake Begin")
err := peer.sendNewHandshake() err := peer.sendNewHandshake()
// set timeout
jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
timerKeepalivePassive.Stop()
timerHandshakeTimeout.Reset(RekeyTimeout + jitter)
if err != nil { if err != nil {
logInfo.Println(peer.String()+": Failed to send handshake initiation", err) logInfo.Println(peer, ": Failed to send handshake initiation", err)
} else { } else {
logDebug.Println(peer.String() + ": Send handshake initiation (initial)") logDebug.Println(peer, ": Send handshake initiation (initial)")
} }
peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) timerHandshakeDeadline.Reset(RekeyAttemptTime)
case <-peer.signal.handshakeCompleted.Wait(): // disable further handshakes
logInfo.Println(peer.String() + ": Handshake completed") peer.event.handshakeBegin.Clear()
enableHandshake = false
case <-peer.event.handshakeCompleted.C:
logInfo.Println(peer, ": Handshake completed")
atomic.StoreInt64( atomic.StoreInt64(
&peer.stats.lastHandshakeNano, &peer.stats.lastHandshakeNano,
time.Now().UnixNano(), time.Now().UnixNano(),
) )
peer.timer.handshakeTimeout.Stop() timerHandshakeTimeout.Stop()
peer.timer.handshakeDeadline.Stop() timerHandshakeDeadline.Stop()
peer.signal.handshakeBegin.Enable()
peer.timer.sendLastMinuteHandshake.Set(false) peer.timer.sendLastMinuteHandshake.Set(false)
// allow further handshakes
peer.event.handshakeBegin.Clear()
enableHandshake = true
} }
} }
} }

20
uapi.go
View file

@ -253,12 +253,10 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logError.Println("Failed to create new peer:", err) logError.Println("Failed to create new peer:", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
logDebug.Println("UAPI: Created new peer:", peer.String()) logDebug.Println("UAPI: Created new peer:", peer)
} }
peer.mutex.Lock() peer.event.handshakePushDeadline.Fire()
peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
peer.mutex.Unlock()
case "remove": case "remove":
@ -269,7 +267,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
if !dummy { if !dummy {
logDebug.Println("UAPI: Removing peer:", peer.String()) logDebug.Println("UAPI: Removing peer:", peer)
device.RemovePeer(peer.handshake.remoteStatic) device.RemovePeer(peer.handshake.remoteStatic)
} }
peer = &Peer{} peer = &Peer{}
@ -279,7 +277,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// update PSK // update PSK
logDebug.Println("UAPI: Updating pre-shared key for peer:", peer.String()) logDebug.Println("UAPI: Updating pre-shared key for peer:", peer)
peer.handshake.mutex.Lock() peer.handshake.mutex.Lock()
err := peer.handshake.presharedKey.FromHex(value) err := peer.handshake.presharedKey.FromHex(value)
@ -294,7 +292,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// set endpoint destination // set endpoint destination
logDebug.Println("UAPI: Updating endpoint for peer:", peer.String()) logDebug.Println("UAPI: Updating endpoint for peer:", peer)
err := func() error { err := func() error {
peer.mutex.Lock() peer.mutex.Lock()
@ -304,7 +302,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return err return err
} }
peer.endpoint = endpoint peer.endpoint = endpoint
peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) peer.event.handshakePushDeadline.Fire()
return nil return nil
}() }()
@ -317,7 +315,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// update keep-alive interval // update keep-alive interval
logDebug.Println("UAPI: Updating persistent_keepalive_interval for peer:", peer.String()) logDebug.Println("UAPI: Updating persistent_keepalive_interval for peer:", peer)
secs, err := strconv.ParseUint(value, 10, 16) secs, err := strconv.ParseUint(value, 10, 16)
if err != nil { if err != nil {
@ -342,7 +340,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "replace_allowed_ips": case "replace_allowed_ips":
logDebug.Println("UAPI: Removing all allowed IPs for peer:", peer.String()) logDebug.Println("UAPI: Removing all allowed IPs for peer:", peer)
if value != "true" { if value != "true" {
logError.Println("Failed to set replace_allowed_ips, invalid value:", value) logError.Println("Failed to set replace_allowed_ips, invalid value:", value)
@ -359,7 +357,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "allowed_ip": case "allowed_ip":
logDebug.Println("UAPI: Adding allowed_ip to peer:", peer.String()) logDebug.Println("UAPI: Adding allowed_ip to peer:", peer)
_, network, err := net.ParseCIDR(value) _, network, err := net.ParseCIDR(value)
if err != nil { if err != nil {