Peer timer teardown

This commit is contained in:
Mathias Hall-Andersen 2017-12-29 17:42:09 +01:00
parent 996c7c4d8a
commit d73f960aab
7 changed files with 163 additions and 102 deletions

View file

@ -82,7 +82,7 @@ func updateBind(device *Device) error {
// open new sockets // open new sockets
if device.tun.isUp.Get() { if device.isUp.Get() {
device.log.Debug.Println("UDP bind updating") device.log.Debug.Println("UDP bind updating")

View file

@ -8,13 +8,13 @@ import (
) )
type Device struct { type Device struct {
closed AtomicBool // device is closed? (acting as guard) isUp AtomicBool // device is up (TUN interface up)?
isClosed AtomicBool // device is closed? (acting as guard)
log *Logger // collection of loggers for levels log *Logger // collection of loggers for levels
idCounter uint // for assigning debug ids to peers idCounter uint // for assigning debug ids to peers
fwMark uint32 fwMark uint32
tun struct { tun struct {
device TUNDevice device TUNDevice
isUp AtomicBool
mtu int32 mtu int32
} }
pool struct { pool struct {
@ -45,6 +45,28 @@ type Device struct {
mac CookieChecker mac CookieChecker
} }
func (device *Device) Up() {
device.mutex.Lock()
defer device.mutex.Unlock()
device.isUp.Set(true)
updateBind(device)
for _, peer := range device.peers {
peer.Start()
}
}
func (device *Device) Down() {
device.mutex.Lock()
defer device.mutex.Unlock()
device.isUp.Set(false)
closeBind(device)
for _, peer := range device.peers {
peer.Stop()
}
}
/* Warning: /* Warning:
* The caller must hold the device mutex (write lock) * The caller must hold the device mutex (write lock)
*/ */
@ -54,9 +76,9 @@ func removePeerUnsafe(device *Device, key NoisePublicKey) {
return return
} }
peer.mutex.Lock() peer.mutex.Lock()
peer.Stop()
device.routingTable.RemovePeer(peer) device.routingTable.RemovePeer(peer)
delete(device.peers, key) delete(device.peers, key)
peer.Close()
} }
func (device *Device) IsUnderLoad() bool { func (device *Device) IsUnderLoad() bool {
@ -98,7 +120,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
device.publicKey = publicKey device.publicKey = publicKey
device.mac.Init(publicKey) device.mac.Init(publicKey)
// do DH precomputations // do DH pre-computations
rmKey := device.privateKey.IsZero() rmKey := device.privateKey.IsZero()
@ -132,10 +154,12 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
device.mutex.Lock() device.mutex.Lock()
defer device.mutex.Unlock() defer device.mutex.Unlock()
device.isUp.Set(false)
device.isClosed.Set(false)
device.log = logger device.log = logger
device.peers = make(map[NoisePublicKey]*Peer) device.peers = make(map[NoisePublicKey]*Peer)
device.tun.device = tun device.tun.device = tun
device.tun.isUp.Set(false)
device.indices.Init() device.indices.Init()
device.ratelimiter.Init() device.ratelimiter.Init()
@ -196,17 +220,13 @@ func (device *Device) RemovePeer(key NoisePublicKey) {
func (device *Device) RemoveAllPeers() { func (device *Device) RemoveAllPeers() {
device.mutex.Lock() device.mutex.Lock()
defer device.mutex.Unlock() defer device.mutex.Unlock()
for key := range device.peers {
for key, peer := range device.peers { removePeerUnsafe(device, key)
peer.mutex.Lock()
delete(device.peers, key)
peer.Close()
peer.mutex.Unlock()
} }
} }
func (device *Device) Close() { func (device *Device) Close() {
if device.closed.Swap(true) { if device.isClosed.Swap(true) {
return return
} }
device.log.Info.Println("Closing device") device.log.Info.Println("Closing device")

View file

@ -34,15 +34,15 @@ type Peer struct {
flushNonceQueue Signal // size 1, empty queued packets flushNonceQueue Signal // size 1, empty queued packets
messageSend Signal // size 1, message was send to peer messageSend Signal // size 1, message was send to peer
messageReceived Signal // size 1, authenticated message recv messageReceived Signal // size 1, authenticated message recv
stop Signal // size 0, stop all goroutines stop Signal // size 0, stop all goroutines in peer
} }
timer struct { timer struct {
// state related to WireGuard timers // state related to WireGuard timers
keepalivePersistent Timer // set for persistent keepalives keepalivePersistent Timer // set for persistent keepalives
keepalivePassive Timer // set upon recieving messages keepalivePassive Timer // set upon recieving messages
newHandshake Timer // begin a new handshake (stale)
zeroAllKeys Timer // zero all key material zeroAllKeys Timer // zero all key material
handshakeNew Timer // begin a new handshake (stale)
handshakeDeadline Timer // complete handshake timeout handshakeDeadline Timer // complete handshake timeout
handshakeTimeout Timer // current handshake message timeout handshakeTimeout Timer // current handshake message timeout
@ -69,8 +69,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.timer.keepalivePersistent = NewTimer() peer.timer.keepalivePersistent = NewTimer()
peer.timer.keepalivePassive = NewTimer() peer.timer.keepalivePassive = NewTimer()
peer.timer.newHandshake = NewTimer()
peer.timer.zeroAllKeys = NewTimer() peer.timer.zeroAllKeys = NewTimer()
peer.timer.handshakeNew = NewTimer()
peer.timer.handshakeDeadline = NewTimer() peer.timer.handshakeDeadline = NewTimer()
peer.timer.handshakeTimeout = NewTimer() peer.timer.handshakeTimeout = NewTimer()
@ -116,32 +116,29 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// prepare signaling & routines // prepare signaling & routines
peer.signal.stop = NewSignal()
peer.signal.newKeyPair = NewSignal() peer.signal.newKeyPair = NewSignal()
peer.signal.handshakeBegin = NewSignal() peer.signal.handshakeBegin = NewSignal()
peer.signal.handshakeCompleted = NewSignal() peer.signal.handshakeCompleted = NewSignal()
peer.signal.flushNonceQueue = NewSignal() peer.signal.flushNonceQueue = NewSignal()
go peer.RoutineNonce()
go peer.RoutineTimerHandler()
go peer.RoutineSequentialSender()
go peer.RoutineSequentialReceiver()
return peer, nil return peer, nil
} }
func (peer *Peer) SendBuffer(buffer []byte) error { func (peer *Peer) SendBuffer(buffer []byte) error {
peer.device.net.mutex.RLock() peer.device.net.mutex.RLock()
defer peer.device.net.mutex.RUnlock() defer peer.device.net.mutex.RUnlock()
peer.mutex.RLock() peer.mutex.RLock()
defer peer.mutex.RUnlock() defer peer.mutex.RUnlock()
if peer.endpoint == nil { if peer.endpoint == nil {
return errors.New("No known endpoint for peer") return errors.New("No known endpoint for peer")
} }
return peer.device.net.bind.Send(buffer, peer.endpoint) return peer.device.net.bind.Send(buffer, peer.endpoint)
} }
/* Returns a short string identification for logging /* Returns a short string identifier for logging
*/ */
func (peer *Peer) String() string { func (peer *Peer) String() string {
if peer.endpoint == nil { if peer.endpoint == nil {
@ -159,6 +156,32 @@ func (peer *Peer) String() string {
) )
} }
func (peer *Peer) Close() { /* Starts all routines for a given peer
*
* Requires that the caller holds the exclusive peer lock!
*/
func unsafePeerStart(peer *Peer) {
peer.signal.stop.Broadcast()
peer.signal.stop = NewSignal()
var wait sync.WaitGroup
wait.Add(1)
go peer.RoutineNonce()
go peer.RoutineTimerHandler(&wait)
go peer.RoutineSequentialSender()
go peer.RoutineSequentialReceiver()
wait.Wait()
}
func (peer *Peer) Start() {
peer.mutex.Lock()
unsafePeerStart(peer)
peer.mutex.Unlock()
}
func (peer *Peer) Stop() {
peer.signal.stop.Broadcast() peer.signal.stop.Broadcast()
} }

View file

@ -43,12 +43,6 @@ func (t *Timer) Reset(dur time.Duration) {
t.Start(dur) t.Start(dur)
} }
func (t *Timer) Push(dur time.Duration) {
if t.pending.Get() {
t.Reset(dur)
}
}
func (t *Timer) Wait() <-chan time.Time { func (t *Timer) Wait() <-chan time.Time {
return t.timer.C return t.timer.C
} }

View file

@ -4,10 +4,17 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"math/rand" "math/rand"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
) )
/* NOTE:
* Notion of validity
*
*
*/
/* Called when a new authenticated message has been send /* Called when a new authenticated message has been send
* *
*/ */
@ -44,17 +51,19 @@ func (peer *Peer) KeepKeyFreshReceiving() {
send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
if send { if send {
// do a last minute attempt at initiating a new handshake // do a last minute attempt at initiating a new handshake
peer.signal.handshakeBegin.Send()
peer.timer.sendLastMinuteHandshake = true peer.timer.sendLastMinuteHandshake = true
peer.signal.handshakeBegin.Send()
} }
} }
/* Queues a keep-alive if no packets are queued for peer /* Queues a keep-alive if no packets are queued for peer
*/ */
func (peer *Peer) SendKeepAlive() bool { func (peer *Peer) SendKeepAlive() bool {
if len(peer.queue.nonce) != 0 {
return false
}
elem := peer.device.NewOutboundElement() elem := peer.device.NewOutboundElement()
elem.packet = nil elem.packet = nil
if len(peer.queue.nonce) == 0 {
select { select {
case peer.queue.nonce <- elem: case peer.queue.nonce <- elem:
return true return true
@ -62,17 +71,13 @@ func (peer *Peer) SendKeepAlive() bool {
return false return false
} }
} }
return true
}
/* Event: /* Event:
* Sent non-empty (authenticated) transport message * Sent non-empty (authenticated) transport message
*/ */
func (peer *Peer) TimerDataSent() { func (peer *Peer) TimerDataSent() {
peer.timer.keepalivePassive.Stop() peer.timer.keepalivePassive.Stop()
if peer.timer.newHandshake.Pending() { peer.timer.handshakeNew.Start(NewHandshakeTime)
peer.timer.newHandshake.Reset(NewHandshakeTime)
}
} }
/* Event: /* Event:
@ -91,7 +96,7 @@ func (peer *Peer) TimerDataReceived() {
* Any (authenticated) packet received * Any (authenticated) packet received
*/ */
func (peer *Peer) TimerAnyAuthenticatedPacketReceived() { func (peer *Peer) TimerAnyAuthenticatedPacketReceived() {
peer.timer.newHandshake.Stop() peer.timer.handshakeNew.Stop()
} }
/* Event: /* Event:
@ -115,10 +120,6 @@ func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {
* - First transport message under the "next" key * - First transport message under the "next" key
*/ */
func (peer *Peer) TimerHandshakeComplete() { func (peer *Peer) TimerHandshakeComplete() {
atomic.StoreInt64(
&peer.stats.lastHandshakeNano,
time.Now().UnixNano(),
)
peer.signal.handshakeCompleted.Send() peer.signal.handshakeCompleted.Send()
peer.device.log.Info.Println("Negotiated new handshake for", peer.String()) peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
} }
@ -139,13 +140,75 @@ func (peer *Peer) TimerEphemeralKeyCreated() {
peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3) peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
} }
func (peer *Peer) RoutineTimerHandler() { /* Sends a new handshake initiation message to the peer (endpoint)
*/
func (peer *Peer) sendNewHandshake() error {
// temporarily disable the handshake complete signal
peer.signal.handshakeCompleted.Disable()
// create initiation message
msg, err := peer.device.CreateMessageInitiation(peer)
if err != nil {
return err
}
// marshal handshake message
var buff [MessageInitiationSize]byte
writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes()
peer.mac.AddMacs(packet)
// send to endpoint
peer.TimerAnyAuthenticatedPacketTraversal()
err = peer.SendBuffer(packet)
if err == nil {
peer.signal.handshakeCompleted.Enable()
}
// set timeout
jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
peer.timer.keepalivePassive.Stop()
peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter)
return err
}
func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) {
device := peer.device device := peer.device
logInfo := device.log.Info logInfo := device.log.Info
logDebug := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, timer handler, started for peer", peer.String()) logDebug.Println("Routine, timer handler, started for peer", peer.String())
// reset all timers
peer.timer.keepalivePassive.Stop()
peer.timer.handshakeDeadline.Stop()
peer.timer.handshakeTimeout.Stop()
peer.timer.handshakeNew.Stop()
peer.timer.zeroAllKeys.Stop()
interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
if interval > 0 {
duration := time.Duration(interval) * time.Second
peer.timer.keepalivePersistent.Reset(duration)
}
// signal that timers are reset
ready.Done()
// handle timer events
for { for {
select { select {
@ -158,6 +221,7 @@ func (peer *Peer) RoutineTimerHandler() {
interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
if interval > 0 { if interval > 0 {
logDebug.Println("Sending keep-alive to", peer.String()) logDebug.Println("Sending keep-alive to", peer.String())
peer.timer.keepalivePassive.Stop()
peer.SendKeepAlive() peer.SendKeepAlive()
} }
@ -168,8 +232,8 @@ func (peer *Peer) RoutineTimerHandler() {
peer.SendKeepAlive() peer.SendKeepAlive()
if peer.timer.needAnotherKeepalive { if peer.timer.needAnotherKeepalive {
peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
peer.timer.needAnotherKeepalive = false peer.timer.needAnotherKeepalive = false
peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
} }
// clear key material timer // clear key material timer
@ -213,7 +277,7 @@ func (peer *Peer) RoutineTimerHandler() {
// handshake timers // handshake timers
case <-peer.timer.newHandshake.Wait(): case <-peer.timer.handshakeNew.Wait():
logInfo.Println("Retrying handshake with", peer.String()) logInfo.Println("Retrying handshake with", peer.String())
peer.signal.handshakeBegin.Send() peer.signal.handshakeBegin.Send()
@ -268,48 +332,16 @@ func (peer *Peer) RoutineTimerHandler() {
logInfo.Println( logInfo.Println(
"Handshake completed for:", peer.String()) "Handshake completed for:", peer.String())
atomic.StoreInt64(
&peer.stats.lastHandshakeNano,
time.Now().UnixNano(),
)
peer.timer.handshakeTimeout.Stop() peer.timer.handshakeTimeout.Stop()
peer.timer.handshakeDeadline.Stop() peer.timer.handshakeDeadline.Stop()
peer.signal.handshakeBegin.Enable() peer.signal.handshakeBegin.Enable()
peer.timer.sendLastMinuteHandshake = false
} }
} }
} }
/* Sends a new handshake initiation message to the peer (endpoint)
*/
func (peer *Peer) sendNewHandshake() error {
// temporarily disable the handshake complete signal
peer.signal.handshakeCompleted.Disable()
// create initiation message
msg, err := peer.device.CreateMessageInitiation(peer)
if err != nil {
return err
}
// marshal handshake message
var buff [MessageInitiationSize]byte
writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes()
peer.mac.AddMacs(packet)
// send to endpoint
err = peer.SendBuffer(packet)
if err == nil {
peer.TimerAnyAuthenticatedPacketTraversal()
peer.signal.handshakeCompleted.Enable()
}
// set timeout
jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter)
return err
}

View file

@ -46,21 +46,13 @@ func (device *Device) RoutineTUNEventReader() {
} }
if event&TUNEventUp != 0 { if event&TUNEventUp != 0 {
if !device.tun.isUp.Get() {
// begin listening for incomming datagrams
logInfo.Println("Interface set up") logInfo.Println("Interface set up")
device.tun.isUp.Set(true) device.Up()
updateBind(device)
}
} }
if event&TUNEventDown != 0 { if event&TUNEventDown != 0 {
if device.tun.isUp.Get() {
// stop listening for incomming datagrams
logInfo.Println("Interface set down") logInfo.Println("Interface set down")
device.tun.isUp.Set(false) device.Up()
closeBind(device)
}
} }
} }
} }

View file

@ -296,7 +296,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logError.Println("Failed to get tun device status:", err) logError.Println("Failed to get tun device status:", err)
return &IPCError{Code: ipcErrorIO} return &IPCError{Code: ipcErrorIO}
} }
if device.tun.isUp.Get() && !dummy { if device.isUp.Get() && !dummy {
peer.SendKeepAlive() peer.SendKeepAlive()
} }
} }