Refactor timers.go

This commit is contained in:
Mathias Hall-Andersen 2017-11-30 23:22:40 +01:00
parent 479a6f240e
commit 02ce67294c
8 changed files with 258 additions and 172 deletions

View file

@ -532,7 +532,6 @@ func (peer *Peer) NewKeyPair() *KeyPair {
kp := &peer.keyPairs kp := &peer.keyPairs
kp.mutex.Lock() kp.mutex.Lock()
// TODO: Adapt kernel behavior noise.c:161
if isInitiator { if isInitiator {
if kp.previous != nil { if kp.previous != nil {
device.DeleteKeyPair(kp.previous) device.DeleteKeyPair(kp.previous)
@ -545,7 +544,7 @@ func (peer *Peer) NewKeyPair() *KeyPair {
} else { } else {
kp.previous = kp.current kp.previous = kp.current
kp.current = keyPair kp.current = keyPair
signalSend(peer.signal.newKeyPair) // TODO: This more places (after confirming the key) peer.signal.newKeyPair.Send()
} }
} else { } else {

View file

@ -28,30 +28,26 @@ type Peer struct {
nextKeepalive time.Time nextKeepalive time.Time
} }
signal struct { signal struct {
newKeyPair chan struct{} // (size 1) : a new key pair was generated newKeyPair Signal // size 1, new key pair was generated
handshakeBegin chan struct{} // (size 1) : request that a new handshake be started ("queue handshake") handshakeCompleted Signal // size 1, handshake completed
handshakeCompleted chan struct{} // (size 1) : handshake completed handshakeBegin Signal // size 1, begin new handshake begin
handshakeReset chan struct{} // (size 1) : reset handshake negotiation state flushNonceQueue Signal // size 1, empty queued packets
flushNonceQueue chan struct{} // (size 1) : empty queued packets messageSend Signal // size 1, message was send to peer
messageSend chan struct{} // (size 1) : a message was send to the peer messageReceived Signal // size 1, authenticated message recv
messageReceived chan struct{} // (size 1) : an authenticated message was received stop Signal // size 0, stop all goroutines
stop chan struct{} // (size 0) : close to stop all goroutines for peer
} }
timer struct { timer struct {
// state related to WireGuard timers // state related to WireGuard timers
keepalivePersistent *time.Timer // set for persistent keepalives keepalivePersistent Timer // set for persistent keepalives
keepalivePassive *time.Timer // set upon recieving messages keepalivePassive Timer // set upon recieving messages
newHandshake *time.Timer // begin a new handshake (after Keepalive + RekeyTimeout) newHandshake Timer // begin a new handshake (stale)
zeroAllKeys *time.Timer // zero all key material (after RejectAfterTime*3) zeroAllKeys Timer // zero all key material
handshakeDeadline *time.Timer // Current handshake must be completed handshakeDeadline Timer // complete handshake timeout
handshakeTimeout Timer // current handshake message timeout
pendingKeepalivePassive bool
pendingNewHandshake bool
pendingZeroAllKeys bool
needAnotherKeepalive bool
sendLastMinuteHandshake bool sendLastMinuteHandshake bool
needAnotherKeepalive bool
} }
queue struct { queue struct {
nonce chan *QueueOutboundElement // nonce / pre-handshake queue nonce chan *QueueOutboundElement // nonce / pre-handshake queue
@ -71,10 +67,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.mac.Init(pk) peer.mac.Init(pk)
peer.device = device peer.device = device
peer.timer.keepalivePersistent = NewStoppedTimer() peer.timer.keepalivePersistent = NewTimer()
peer.timer.keepalivePassive = NewStoppedTimer() peer.timer.keepalivePassive = NewTimer()
peer.timer.newHandshake = NewStoppedTimer() peer.timer.newHandshake = NewTimer()
peer.timer.zeroAllKeys = NewStoppedTimer() peer.timer.zeroAllKeys = NewTimer()
peer.timer.handshakeDeadline = NewTimer()
peer.timer.handshakeTimeout = NewTimer()
// assign id for debugging // assign id for debugging
@ -102,7 +100,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
handshake.remoteStatic = pk handshake.remoteStatic = pk
handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic) handshake.precomputedStaticStatic =
device.privateKey.sharedSecret(handshake.remoteStatic)
handshake.mutex.Unlock() handshake.mutex.Unlock()
// reset endpoint // reset endpoint
@ -117,16 +116,14 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// prepare signaling & routines // prepare signaling & routines
peer.signal.stop = make(chan struct{}) peer.signal.stop = NewSignal()
peer.signal.newKeyPair = make(chan struct{}, 1) peer.signal.newKeyPair = NewSignal()
peer.signal.handshakeBegin = make(chan struct{}, 1) peer.signal.handshakeBegin = NewSignal()
peer.signal.handshakeReset = make(chan struct{}, 1) peer.signal.handshakeCompleted = NewSignal()
peer.signal.handshakeCompleted = make(chan struct{}, 1) peer.signal.flushNonceQueue = NewSignal()
peer.signal.flushNonceQueue = make(chan struct{}, 1)
go peer.RoutineNonce() go peer.RoutineNonce()
go peer.RoutineTimerHandler() go peer.RoutineTimerHandler()
go peer.RoutineHandshakeInitiator()
go peer.RoutineSequentialSender() go peer.RoutineSequentialSender()
go peer.RoutineSequentialReceiver() go peer.RoutineSequentialReceiver()
@ -163,5 +160,5 @@ func (peer *Peer) String() string {
} }
func (peer *Peer) Close() { func (peer *Peer) Close() {
close(peer.signal.stop) peer.signal.stop.Broadcast()
} }

View file

@ -482,7 +482,8 @@ func (peer *Peer) RoutineSequentialReceiver() {
for { for {
select { select {
case <-peer.signal.stop:
case <-peer.signal.stop.Wait():
logDebug.Println("Routine, sequential receiver, stopped for peer", peer.id) logDebug.Println("Routine, sequential receiver, stopped for peer", peer.id)
return return

View file

@ -164,7 +164,7 @@ func (device *Device) RoutineReadFromTUN() {
// insert into nonce/pre-handshake queue // insert into nonce/pre-handshake queue
signalSend(peer.signal.handshakeReset) peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
addToOutboundQueue(peer.queue.nonce, elem) addToOutboundQueue(peer.queue.nonce, elem)
elem = device.NewOutboundElement() elem = device.NewOutboundElement()
} }
@ -186,7 +186,7 @@ func (peer *Peer) RoutineNonce() {
for { for {
NextPacket: NextPacket:
select { select {
case <-peer.signal.stop: case <-peer.signal.stop.Wait():
return return
case elem := <-peer.queue.nonce: case elem := <-peer.queue.nonce:
@ -201,16 +201,17 @@ func (peer *Peer) RoutineNonce() {
} }
} }
signalSend(peer.signal.handshakeBegin) peer.signal.handshakeBegin.Send()
logDebug.Println("Awaiting key-pair for", peer.String()) logDebug.Println("Awaiting key-pair for", peer.String())
select { select {
case <-peer.signal.newKeyPair: case <-peer.signal.newKeyPair.Wait():
case <-peer.signal.flushNonceQueue: case <-peer.signal.flushNonceQueue.Wait():
logDebug.Println("Clearing queue for", peer.String()) logDebug.Println("Clearing queue for", peer.String())
peer.FlushNonceQueue() peer.FlushNonceQueue()
goto NextPacket goto NextPacket
case <-peer.signal.stop: case <-peer.signal.stop.Wait():
return return
} }
} }
@ -309,8 +310,10 @@ func (peer *Peer) RoutineSequentialSender() {
for { for {
select { select {
case <-peer.signal.stop:
logDebug.Println("Routine, sequential sender, stopped for", peer.String()) case <-peer.signal.stop.Wait():
logDebug.Println(
"Routine, sequential sender, stopped for", peer.String())
return return
case elem := <-peer.queue.outbound: case elem := <-peer.queue.outbound:

45
src/signal.go Normal file
View file

@ -0,0 +1,45 @@
package main
type Signal struct {
enabled AtomicBool
C chan struct{}
}
func NewSignal() (s Signal) {
s.C = make(chan struct{}, 1)
s.Enable()
return
}
func (s *Signal) Disable() {
s.enabled.Set(false)
s.Clear()
}
func (s *Signal) Enable() {
s.enabled.Set(true)
}
func (s *Signal) Send() {
if s.enabled.Get() {
select {
case s.C <- struct{}{}:
default:
}
}
}
func (s Signal) Clear() {
select {
case <-s.C:
default:
}
}
func (s Signal) Broadcast() {
close(s.C) // unblocks all selectors
}
func (s Signal) Wait() chan struct{} {
return s.C
}

65
src/timer.go Normal file
View file

@ -0,0 +1,65 @@
package main
import (
"time"
)
type Timer struct {
pending AtomicBool
timer *time.Timer
}
/* Starts the timer if not already pending
*/
func (t *Timer) Start(dur time.Duration) bool {
set := t.pending.Swap(true)
if !set {
t.timer.Reset(dur)
return true
}
return false
}
/* Stops the timer
*/
func (t *Timer) Stop() {
set := t.pending.Swap(true)
if set {
t.timer.Stop()
select {
case <-t.timer.C:
default:
}
}
t.pending.Set(false)
}
func (t *Timer) Pending() bool {
return t.pending.Get()
}
func (t *Timer) Reset(dur time.Duration) {
t.pending.Set(false)
t.Start(dur)
}
func (t *Timer) Push(dur time.Duration) {
if t.pending.Get() {
t.Reset(dur)
}
}
func (t *Timer) Wait() <-chan time.Time {
return t.timer.C
}
func NewTimer() (t Timer) {
t.pending.Set(false)
t.timer = time.NewTimer(0)
t.timer.Stop()
select {
case <-t.timer.C:
default:
}
return
}

View file

@ -18,10 +18,10 @@ func (peer *Peer) KeepKeyFreshSending() {
} }
nonce := atomic.LoadUint64(&kp.sendNonce) nonce := atomic.LoadUint64(&kp.sendNonce)
if nonce > RekeyAfterMessages { if nonce > RekeyAfterMessages {
signalSend(peer.signal.handshakeBegin) peer.signal.handshakeBegin.Send()
} }
if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime { if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime {
signalSend(peer.signal.handshakeBegin) peer.signal.handshakeBegin.Send()
} }
} }
@ -44,7 +44,7 @@ func (peer *Peer) KeepKeyFreshReceiving() {
send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
if send { if send {
// do a last minute attempt at initiating a new handshake // do a last minute attempt at initiating a new handshake
signalSend(peer.signal.handshakeBegin) peer.signal.handshakeBegin.Send()
peer.timer.sendLastMinuteHandshake = true peer.timer.sendLastMinuteHandshake = true
} }
} }
@ -69,34 +69,36 @@ func (peer *Peer) SendKeepAlive() bool {
* Sent non-empty (authenticated) transport message * Sent non-empty (authenticated) transport message
*/ */
func (peer *Peer) TimerDataSent() { func (peer *Peer) TimerDataSent() {
timerStop(peer.timer.keepalivePassive) peer.timer.keepalivePassive.Stop()
if !peer.timer.pendingNewHandshake { if peer.timer.newHandshake.Pending() {
peer.timer.pendingNewHandshake = true
peer.timer.newHandshake.Reset(NewHandshakeTime) peer.timer.newHandshake.Reset(NewHandshakeTime)
} }
} }
/* Event: /* Event:
* Received non-empty (authenticated) transport message * Received non-empty (authenticated) transport message
*
* Action:
* Set a timer to confirm the message using a keep-alive (if not already set)
*/ */
func (peer *Peer) TimerDataReceived() { func (peer *Peer) TimerDataReceived() {
if peer.timer.pendingKeepalivePassive { if !peer.timer.keepalivePassive.Start(KeepaliveTimeout) {
peer.timer.needAnotherKeepalive = true peer.timer.needAnotherKeepalive = true
return
} }
peer.timer.pendingKeepalivePassive = false
peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
} }
/* Event: /* Event:
* Any (authenticated) packet received * Any (authenticated) packet received
*/ */
func (peer *Peer) TimerAnyAuthenticatedPacketReceived() { func (peer *Peer) TimerAnyAuthenticatedPacketReceived() {
timerStop(peer.timer.newHandshake) peer.timer.newHandshake.Stop()
} }
/* Event: /* Event:
* Any authenticated packet send / received. * Any authenticated packet send / received.
*
* Action:
* Push persistent keep-alive into the future
*/ */
func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() { func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {
interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
@ -117,7 +119,7 @@ func (peer *Peer) TimerHandshakeComplete() {
&peer.stats.lastHandshakeNano, &peer.stats.lastHandshakeNano,
time.Now().UnixNano(), time.Now().UnixNano(),
) )
signalSend(peer.signal.handshakeCompleted) peer.signal.handshakeCompleted.Send()
peer.device.log.Info.Println("Negotiated new handshake for", peer.String()) peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
} }
@ -129,7 +131,8 @@ func (peer *Peer) TimerHandshakeComplete() {
* CreateMessageInitiation * CreateMessageInitiation
* CreateMessageResponse * CreateMessageResponse
* *
* Schedules the deletion of all key material * Action:
* Schedule the deletion of all key material
* upon failure to complete a handshake * upon failure to complete a handshake
*/ */
func (peer *Peer) TimerEphemeralKeyCreated() { func (peer *Peer) TimerEphemeralKeyCreated() {
@ -139,18 +142,18 @@ func (peer *Peer) TimerEphemeralKeyCreated() {
func (peer *Peer) RoutineTimerHandler() { func (peer *Peer) RoutineTimerHandler() {
device := peer.device device := peer.device
logInfo := device.log.Info
logDebug := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, timer handler, started for peer", peer.String()) logDebug.Println("Routine, timer handler, started for peer", peer.String())
for { for {
select { select {
case <-peer.signal.stop: /* timers */
return
// keep-alives // keep-alive
case <-peer.timer.keepalivePersistent.C: case <-peer.timer.keepalivePersistent.Wait():
interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
if interval > 0 { if interval > 0 {
@ -158,7 +161,7 @@ func (peer *Peer) RoutineTimerHandler() {
peer.SendKeepAlive() peer.SendKeepAlive()
} }
case <-peer.timer.keepalivePassive.C: case <-peer.timer.keepalivePassive.Wait():
logDebug.Println("Sending keep-alive to", peer.String()) logDebug.Println("Sending keep-alive to", peer.String())
@ -169,17 +172,9 @@ func (peer *Peer) RoutineTimerHandler() {
peer.timer.needAnotherKeepalive = false peer.timer.needAnotherKeepalive = false
} }
// unresponsive session // clear key material timer
case <-peer.timer.newHandshake.C: case <-peer.timer.zeroAllKeys.Wait():
logDebug.Println("Retrying handshake with", peer.String(), "due to lack of reply")
signalSend(peer.signal.handshakeBegin)
// clear key material
case <-peer.timer.zeroAllKeys.C:
logDebug.Println("Clearing all key material for", peer.String()) logDebug.Println("Clearing all key material for", peer.String())
@ -215,125 +210,106 @@ func (peer *Peer) RoutineTimerHandler() {
setZero(hs.chainKey[:]) setZero(hs.chainKey[:])
setZero(hs.hash[:]) setZero(hs.hash[:])
hs.mutex.Unlock() hs.mutex.Unlock()
}
}
}
/* This is the state machine for handshake initiation // handshake timers
*
* Associated with this routine is the signal "handshakeBegin"
* The routine will read from the "handshakeBegin" channel
* at most every RekeyTimeout seconds
*/
func (peer *Peer) RoutineHandshakeInitiator() {
device := peer.device
logInfo := device.log.Info case <-peer.timer.newHandshake.Wait():
logError := device.log.Error logInfo.Println("Retrying handshake with", peer.String())
logDebug := device.log.Debug peer.signal.handshakeBegin.Send()
logDebug.Println("Routine, handshake initiator, started for", peer.String())
var temp [256]byte case <-peer.timer.handshakeTimeout.Wait():
for { // clear source (in case this is causing problems)
// wait for signal peer.mutex.Lock()
if peer.endpoint != nil {
select { peer.endpoint.ClearSrc()
case <-peer.signal.handshakeBegin:
case <-peer.signal.stop:
return
}
// set deadline
BeginHandshakes:
signalClear(peer.signal.handshakeReset)
deadline := time.NewTimer(RekeyAttemptTime)
AttemptHandshakes:
for attempts := uint(1); ; attempts++ {
// check if deadline reached
select {
case <-deadline.C:
logInfo.Println("Handshake negotiation timed out for:", peer.String())
signalSend(peer.signal.flushNonceQueue)
timerStop(peer.timer.keepalivePersistent)
break
case <-peer.signal.stop:
return
default:
} }
peer.mutex.Unlock()
signalClear(peer.signal.handshakeCompleted) // send new handshake
// create initiation message err := peer.sendNewHandshake()
msg, err := peer.device.CreateMessageInitiation(peer)
if err != nil { if err != nil {
logError.Println("Failed to create handshake initiation message:", err) logInfo.Println(
break AttemptHandshakes "Failed to send handshake to peer:", peer.String())
} }
// marshal handshake message case <-peer.timer.handshakeDeadline.Wait():
writer := bytes.NewBuffer(temp[:0]) // clear all queued packets and stop keep-alive
binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes()
peer.mac.AddMacs(packet)
// send to endpoint logInfo.Println(
"Handshake negotiation timed out for:", peer.String())
err = peer.SendBuffer(packet) peer.signal.flushNonceQueue.Send()
jitter := time.Millisecond * time.Duration(rand.Uint32()%334) peer.timer.keepalivePersistent.Stop()
timeout := time.NewTimer(RekeyTimeout + jitter) peer.signal.handshakeBegin.Enable()
if err == nil {
peer.TimerAnyAuthenticatedPacketTraversal() /* signals */
logDebug.Println(
"Handshake initiation attempt", case <-peer.signal.stop.Wait():
attempts, "sent to", peer.String(), return
)
} else { case <-peer.signal.handshakeBegin.Wait():
logError.Println(
"Failed to send handshake initiation message to", peer.signal.handshakeBegin.Disable()
peer.String(), ":", err,
) err := peer.sendNewHandshake()
if err != nil {
logInfo.Println(
"Failed to send handshake to peer:", peer.String())
} }
// wait for handshake or timeout peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
select { case <-peer.signal.handshakeCompleted.Wait():
case <-peer.signal.stop: logInfo.Println(
return "Handshake completed for:", peer.String())
case <-peer.signal.handshakeCompleted: peer.timer.handshakeTimeout.Stop()
<-timeout.C peer.timer.handshakeDeadline.Stop()
peer.timer.sendLastMinuteHandshake = false peer.signal.handshakeBegin.Enable()
break AttemptHandshakes
case <-peer.signal.handshakeReset:
<-timeout.C
goto BeginHandshakes
case <-timeout.C:
// clear source address of peer
peer.mutex.Lock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.mutex.Unlock()
}
} }
// clear signal set in the meantime
signalClear(peer.signal.handshakeBegin)
} }
} }
/* Sends a new handshake initiation message to the peer (endpoint)
*/
func (peer *Peer) sendNewHandshake() error {
// temporarily disable the handshake complete signal
peer.signal.handshakeCompleted.Disable()
// create initiation message
msg, err := peer.device.CreateMessageInitiation(peer)
if err != nil {
return err
}
// marshal handshake message
var buff [MessageInitiationSize]byte
writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes()
peer.mac.AddMacs(packet)
// send to endpoint
err = peer.SendBuffer(packet)
if err == nil {
peer.TimerAnyAuthenticatedPacketTraversal()
peer.signal.handshakeCompleted.Enable()
}
// set timeout
jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter)
return err
}

View file

@ -221,7 +221,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
} }
signalSend(peer.signal.handshakeReset) peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
dummy = false dummy = false
} }
@ -265,7 +265,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return err return err
} }
peer.endpoint = endpoint peer.endpoint = endpoint
signalSend(peer.signal.handshakeReset) peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
return nil return nil
}() }()