Refactor timers.go

This commit is contained in:
Mathias Hall-Andersen 2017-11-30 23:22:40 +01:00
parent 479a6f240e
commit 02ce67294c
8 changed files with 258 additions and 172 deletions

View file

@ -532,7 +532,6 @@ func (peer *Peer) NewKeyPair() *KeyPair {
kp := &peer.keyPairs
kp.mutex.Lock()
// TODO: Adapt kernel behavior noise.c:161
if isInitiator {
if kp.previous != nil {
device.DeleteKeyPair(kp.previous)
@ -545,7 +544,7 @@ func (peer *Peer) NewKeyPair() *KeyPair {
} else {
kp.previous = kp.current
kp.current = keyPair
signalSend(peer.signal.newKeyPair) // TODO: This more places (after confirming the key)
peer.signal.newKeyPair.Send()
}
} else {

View file

@ -28,30 +28,26 @@ type Peer struct {
nextKeepalive time.Time
}
signal struct {
newKeyPair chan struct{} // (size 1) : a new key pair was generated
handshakeBegin chan struct{} // (size 1) : request that a new handshake be started ("queue handshake")
handshakeCompleted chan struct{} // (size 1) : handshake completed
handshakeReset chan struct{} // (size 1) : reset handshake negotiation state
flushNonceQueue chan struct{} // (size 1) : empty queued packets
messageSend chan struct{} // (size 1) : a message was send to the peer
messageReceived chan struct{} // (size 1) : an authenticated message was received
stop chan struct{} // (size 0) : close to stop all goroutines for peer
newKeyPair Signal // size 1, new key pair was generated
handshakeCompleted Signal // size 1, handshake completed
handshakeBegin Signal // size 1, begin new handshake begin
flushNonceQueue Signal // size 1, empty queued packets
messageSend Signal // size 1, message was send to peer
messageReceived Signal // size 1, authenticated message recv
stop Signal // size 0, stop all goroutines
}
timer struct {
// state related to WireGuard timers
keepalivePersistent *time.Timer // set for persistent keepalives
keepalivePassive *time.Timer // set upon recieving messages
newHandshake *time.Timer // begin a new handshake (after Keepalive + RekeyTimeout)
zeroAllKeys *time.Timer // zero all key material (after RejectAfterTime*3)
handshakeDeadline *time.Timer // Current handshake must be completed
keepalivePersistent Timer // set for persistent keepalives
keepalivePassive Timer // set upon recieving messages
newHandshake Timer // begin a new handshake (stale)
zeroAllKeys Timer // zero all key material
handshakeDeadline Timer // complete handshake timeout
handshakeTimeout Timer // current handshake message timeout
pendingKeepalivePassive bool
pendingNewHandshake bool
pendingZeroAllKeys bool
needAnotherKeepalive bool
sendLastMinuteHandshake bool
needAnotherKeepalive bool
}
queue struct {
nonce chan *QueueOutboundElement // nonce / pre-handshake queue
@ -71,10 +67,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.mac.Init(pk)
peer.device = device
peer.timer.keepalivePersistent = NewStoppedTimer()
peer.timer.keepalivePassive = NewStoppedTimer()
peer.timer.newHandshake = NewStoppedTimer()
peer.timer.zeroAllKeys = NewStoppedTimer()
peer.timer.keepalivePersistent = NewTimer()
peer.timer.keepalivePassive = NewTimer()
peer.timer.newHandshake = NewTimer()
peer.timer.zeroAllKeys = NewTimer()
peer.timer.handshakeDeadline = NewTimer()
peer.timer.handshakeTimeout = NewTimer()
// assign id for debugging
@ -102,7 +100,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
handshake := &peer.handshake
handshake.mutex.Lock()
handshake.remoteStatic = pk
handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic)
handshake.precomputedStaticStatic =
device.privateKey.sharedSecret(handshake.remoteStatic)
handshake.mutex.Unlock()
// reset endpoint
@ -117,16 +116,14 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// prepare signaling & routines
peer.signal.stop = make(chan struct{})
peer.signal.newKeyPair = make(chan struct{}, 1)
peer.signal.handshakeBegin = make(chan struct{}, 1)
peer.signal.handshakeReset = make(chan struct{}, 1)
peer.signal.handshakeCompleted = make(chan struct{}, 1)
peer.signal.flushNonceQueue = make(chan struct{}, 1)
peer.signal.stop = NewSignal()
peer.signal.newKeyPair = NewSignal()
peer.signal.handshakeBegin = NewSignal()
peer.signal.handshakeCompleted = NewSignal()
peer.signal.flushNonceQueue = NewSignal()
go peer.RoutineNonce()
go peer.RoutineTimerHandler()
go peer.RoutineHandshakeInitiator()
go peer.RoutineSequentialSender()
go peer.RoutineSequentialReceiver()
@ -163,5 +160,5 @@ func (peer *Peer) String() string {
}
func (peer *Peer) Close() {
close(peer.signal.stop)
peer.signal.stop.Broadcast()
}

View file

@ -482,7 +482,8 @@ func (peer *Peer) RoutineSequentialReceiver() {
for {
select {
case <-peer.signal.stop:
case <-peer.signal.stop.Wait():
logDebug.Println("Routine, sequential receiver, stopped for peer", peer.id)
return

View file

@ -164,7 +164,7 @@ func (device *Device) RoutineReadFromTUN() {
// insert into nonce/pre-handshake queue
signalSend(peer.signal.handshakeReset)
peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
addToOutboundQueue(peer.queue.nonce, elem)
elem = device.NewOutboundElement()
}
@ -186,7 +186,7 @@ func (peer *Peer) RoutineNonce() {
for {
NextPacket:
select {
case <-peer.signal.stop:
case <-peer.signal.stop.Wait():
return
case elem := <-peer.queue.nonce:
@ -201,16 +201,17 @@ func (peer *Peer) RoutineNonce() {
}
}
signalSend(peer.signal.handshakeBegin)
peer.signal.handshakeBegin.Send()
logDebug.Println("Awaiting key-pair for", peer.String())
select {
case <-peer.signal.newKeyPair:
case <-peer.signal.flushNonceQueue:
case <-peer.signal.newKeyPair.Wait():
case <-peer.signal.flushNonceQueue.Wait():
logDebug.Println("Clearing queue for", peer.String())
peer.FlushNonceQueue()
goto NextPacket
case <-peer.signal.stop:
case <-peer.signal.stop.Wait():
return
}
}
@ -309,8 +310,10 @@ func (peer *Peer) RoutineSequentialSender() {
for {
select {
case <-peer.signal.stop:
logDebug.Println("Routine, sequential sender, stopped for", peer.String())
case <-peer.signal.stop.Wait():
logDebug.Println(
"Routine, sequential sender, stopped for", peer.String())
return
case elem := <-peer.queue.outbound:

45
src/signal.go Normal file
View file

@ -0,0 +1,45 @@
package main
type Signal struct {
enabled AtomicBool
C chan struct{}
}
func NewSignal() (s Signal) {
s.C = make(chan struct{}, 1)
s.Enable()
return
}
func (s *Signal) Disable() {
s.enabled.Set(false)
s.Clear()
}
func (s *Signal) Enable() {
s.enabled.Set(true)
}
func (s *Signal) Send() {
if s.enabled.Get() {
select {
case s.C <- struct{}{}:
default:
}
}
}
func (s Signal) Clear() {
select {
case <-s.C:
default:
}
}
func (s Signal) Broadcast() {
close(s.C) // unblocks all selectors
}
func (s Signal) Wait() chan struct{} {
return s.C
}

65
src/timer.go Normal file
View file

@ -0,0 +1,65 @@
package main
import (
"time"
)
type Timer struct {
pending AtomicBool
timer *time.Timer
}
/* Starts the timer if not already pending
*/
func (t *Timer) Start(dur time.Duration) bool {
set := t.pending.Swap(true)
if !set {
t.timer.Reset(dur)
return true
}
return false
}
/* Stops the timer
*/
func (t *Timer) Stop() {
set := t.pending.Swap(true)
if set {
t.timer.Stop()
select {
case <-t.timer.C:
default:
}
}
t.pending.Set(false)
}
func (t *Timer) Pending() bool {
return t.pending.Get()
}
func (t *Timer) Reset(dur time.Duration) {
t.pending.Set(false)
t.Start(dur)
}
func (t *Timer) Push(dur time.Duration) {
if t.pending.Get() {
t.Reset(dur)
}
}
func (t *Timer) Wait() <-chan time.Time {
return t.timer.C
}
func NewTimer() (t Timer) {
t.pending.Set(false)
t.timer = time.NewTimer(0)
t.timer.Stop()
select {
case <-t.timer.C:
default:
}
return
}

View file

@ -18,10 +18,10 @@ func (peer *Peer) KeepKeyFreshSending() {
}
nonce := atomic.LoadUint64(&kp.sendNonce)
if nonce > RekeyAfterMessages {
signalSend(peer.signal.handshakeBegin)
peer.signal.handshakeBegin.Send()
}
if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime {
signalSend(peer.signal.handshakeBegin)
peer.signal.handshakeBegin.Send()
}
}
@ -44,7 +44,7 @@ func (peer *Peer) KeepKeyFreshReceiving() {
send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
if send {
// do a last minute attempt at initiating a new handshake
signalSend(peer.signal.handshakeBegin)
peer.signal.handshakeBegin.Send()
peer.timer.sendLastMinuteHandshake = true
}
}
@ -69,34 +69,36 @@ func (peer *Peer) SendKeepAlive() bool {
* Sent non-empty (authenticated) transport message
*/
func (peer *Peer) TimerDataSent() {
timerStop(peer.timer.keepalivePassive)
if !peer.timer.pendingNewHandshake {
peer.timer.pendingNewHandshake = true
peer.timer.keepalivePassive.Stop()
if peer.timer.newHandshake.Pending() {
peer.timer.newHandshake.Reset(NewHandshakeTime)
}
}
/* Event:
* Received non-empty (authenticated) transport message
*
* Action:
* Set a timer to confirm the message using a keep-alive (if not already set)
*/
func (peer *Peer) TimerDataReceived() {
if peer.timer.pendingKeepalivePassive {
if !peer.timer.keepalivePassive.Start(KeepaliveTimeout) {
peer.timer.needAnotherKeepalive = true
return
}
peer.timer.pendingKeepalivePassive = false
peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
}
/* Event:
* Any (authenticated) packet received
*/
func (peer *Peer) TimerAnyAuthenticatedPacketReceived() {
timerStop(peer.timer.newHandshake)
peer.timer.newHandshake.Stop()
}
/* Event:
* Any authenticated packet send / received.
*
* Action:
* Push persistent keep-alive into the future
*/
func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {
interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
@ -117,7 +119,7 @@ func (peer *Peer) TimerHandshakeComplete() {
&peer.stats.lastHandshakeNano,
time.Now().UnixNano(),
)
signalSend(peer.signal.handshakeCompleted)
peer.signal.handshakeCompleted.Send()
peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
}
@ -129,7 +131,8 @@ func (peer *Peer) TimerHandshakeComplete() {
* CreateMessageInitiation
* CreateMessageResponse
*
* Schedules the deletion of all key material
* Action:
* Schedule the deletion of all key material
* upon failure to complete a handshake
*/
func (peer *Peer) TimerEphemeralKeyCreated() {
@ -139,18 +142,18 @@ func (peer *Peer) TimerEphemeralKeyCreated() {
func (peer *Peer) RoutineTimerHandler() {
device := peer.device
logInfo := device.log.Info
logDebug := device.log.Debug
logDebug.Println("Routine, timer handler, started for peer", peer.String())
for {
select {
case <-peer.signal.stop:
return
/* timers */
// keep-alives
// keep-alive
case <-peer.timer.keepalivePersistent.C:
case <-peer.timer.keepalivePersistent.Wait():
interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
if interval > 0 {
@ -158,7 +161,7 @@ func (peer *Peer) RoutineTimerHandler() {
peer.SendKeepAlive()
}
case <-peer.timer.keepalivePassive.C:
case <-peer.timer.keepalivePassive.Wait():
logDebug.Println("Sending keep-alive to", peer.String())
@ -169,17 +172,9 @@ func (peer *Peer) RoutineTimerHandler() {
peer.timer.needAnotherKeepalive = false
}
// unresponsive session
// clear key material timer
case <-peer.timer.newHandshake.C:
logDebug.Println("Retrying handshake with", peer.String(), "due to lack of reply")
signalSend(peer.signal.handshakeBegin)
// clear key material
case <-peer.timer.zeroAllKeys.C:
case <-peer.timer.zeroAllKeys.Wait():
logDebug.Println("Clearing all key material for", peer.String())
@ -215,125 +210,106 @@ func (peer *Peer) RoutineTimerHandler() {
setZero(hs.chainKey[:])
setZero(hs.hash[:])
hs.mutex.Unlock()
}
}
}
/* This is the state machine for handshake initiation
*
* Associated with this routine is the signal "handshakeBegin"
* The routine will read from the "handshakeBegin" channel
* at most every RekeyTimeout seconds
*/
func (peer *Peer) RoutineHandshakeInitiator() {
device := peer.device
// handshake timers
logInfo := device.log.Info
logError := device.log.Error
logDebug := device.log.Debug
logDebug.Println("Routine, handshake initiator, started for", peer.String())
case <-peer.timer.newHandshake.Wait():
logInfo.Println("Retrying handshake with", peer.String())
peer.signal.handshakeBegin.Send()
var temp [256]byte
case <-peer.timer.handshakeTimeout.Wait():
for {
// clear source (in case this is causing problems)
// wait for signal
select {
case <-peer.signal.handshakeBegin:
case <-peer.signal.stop:
return
}
// set deadline
BeginHandshakes:
signalClear(peer.signal.handshakeReset)
deadline := time.NewTimer(RekeyAttemptTime)
AttemptHandshakes:
for attempts := uint(1); ; attempts++ {
// check if deadline reached
select {
case <-deadline.C:
logInfo.Println("Handshake negotiation timed out for:", peer.String())
signalSend(peer.signal.flushNonceQueue)
timerStop(peer.timer.keepalivePersistent)
break
case <-peer.signal.stop:
return
default:
peer.mutex.Lock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.mutex.Unlock()
signalClear(peer.signal.handshakeCompleted)
// send new handshake
// create initiation message
msg, err := peer.device.CreateMessageInitiation(peer)
err := peer.sendNewHandshake()
if err != nil {
logError.Println("Failed to create handshake initiation message:", err)
break AttemptHandshakes
logInfo.Println(
"Failed to send handshake to peer:", peer.String())
}
// marshal handshake message
case <-peer.timer.handshakeDeadline.Wait():
writer := bytes.NewBuffer(temp[:0])
binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes()
peer.mac.AddMacs(packet)
// clear all queued packets and stop keep-alive
// send to endpoint
logInfo.Println(
"Handshake negotiation timed out for:", peer.String())
err = peer.SendBuffer(packet)
jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
timeout := time.NewTimer(RekeyTimeout + jitter)
if err == nil {
peer.TimerAnyAuthenticatedPacketTraversal()
logDebug.Println(
"Handshake initiation attempt",
attempts, "sent to", peer.String(),
)
} else {
logError.Println(
"Failed to send handshake initiation message to",
peer.String(), ":", err,
)
peer.signal.flushNonceQueue.Send()
peer.timer.keepalivePersistent.Stop()
peer.signal.handshakeBegin.Enable()
/* signals */
case <-peer.signal.stop.Wait():
return
case <-peer.signal.handshakeBegin.Wait():
peer.signal.handshakeBegin.Disable()
err := peer.sendNewHandshake()
if err != nil {
logInfo.Println(
"Failed to send handshake to peer:", peer.String())
}
// wait for handshake or timeout
peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
select {
case <-peer.signal.handshakeCompleted.Wait():
case <-peer.signal.stop:
return
logInfo.Println(
"Handshake completed for:", peer.String())
case <-peer.signal.handshakeCompleted:
<-timeout.C
peer.timer.sendLastMinuteHandshake = false
break AttemptHandshakes
case <-peer.signal.handshakeReset:
<-timeout.C
goto BeginHandshakes
case <-timeout.C:
// clear source address of peer
peer.mutex.Lock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.mutex.Unlock()
}
peer.timer.handshakeTimeout.Stop()
peer.timer.handshakeDeadline.Stop()
peer.signal.handshakeBegin.Enable()
}
// clear signal set in the meantime
signalClear(peer.signal.handshakeBegin)
}
}
/* Sends a new handshake initiation message to the peer (endpoint)
*/
func (peer *Peer) sendNewHandshake() error {
// temporarily disable the handshake complete signal
peer.signal.handshakeCompleted.Disable()
// create initiation message
msg, err := peer.device.CreateMessageInitiation(peer)
if err != nil {
return err
}
// marshal handshake message
var buff [MessageInitiationSize]byte
writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes()
peer.mac.AddMacs(packet)
// send to endpoint
err = peer.SendBuffer(packet)
if err == nil {
peer.TimerAnyAuthenticatedPacketTraversal()
peer.signal.handshakeCompleted.Enable()
}
// set timeout
jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter)
return err
}

View file

@ -221,7 +221,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorInvalid}
}
}
signalSend(peer.signal.handshakeReset)
peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
dummy = false
}
@ -265,7 +265,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return err
}
peer.endpoint = endpoint
signalSend(peer.signal.handshakeReset)
peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
return nil
}()