From 4f97b52ea60ce4f2448d8617853aa44759727197 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Fri, 2 Feb 2018 17:24:29 +0100 Subject: [PATCH] Clear cryptographic state when interface down Attempts to clear the cryptographic state for every peer when the device goes down. --- src/device.go | 22 ---------------------- src/keypair.go | 4 +++- src/noise_protocol.go | 22 ++++++++++++++++++---- src/peer.go | 44 +++++++++++++++++++++++++++++++++++-------- src/timers.go | 7 +------ 5 files changed, 58 insertions(+), 41 deletions(-) diff --git a/src/device.go b/src/device.go index 0317b60..c041987 100644 --- a/src/device.go +++ b/src/device.go @@ -88,28 +88,6 @@ func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { device.routing.table.RemovePeer(peer) peer.Stop() - // clean index table - - kp := &peer.keyPairs - kp.mutex.Lock() - - if kp.previous != nil { - device.indices.Delete(kp.previous.localIndex) - } - - if kp.current != nil { - device.indices.Delete(kp.current.localIndex) - } - - if kp.next != nil { - device.indices.Delete(kp.next.localIndex) - } - - kp.previous = nil - kp.current = nil - kp.next = nil - kp.mutex.Unlock() - // remove from peer map delete(device.peers.keyMap, key) diff --git a/src/keypair.go b/src/keypair.go index 7e5297b..283cb92 100644 --- a/src/keypair.go +++ b/src/keypair.go @@ -38,5 +38,7 @@ func (kp *KeyPairs) Current() *KeyPair { } func (device *Device) DeleteKeyPair(key *KeyPair) { - device.indices.Delete(key.localIndex) + if key != nil { + device.indices.Delete(key.localIndex) + } } diff --git a/src/noise_protocol.go b/src/noise_protocol.go index d620a0d..c9713c0 100644 --- a/src/noise_protocol.go +++ b/src/noise_protocol.go @@ -121,6 +121,15 @@ func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) { hsh.Reset() } +func (h *Handshake) Clear() { + setZero(h.localEphemeral[:]) + setZero(h.remoteEphemeral[:]) + setZero(h.chainKey[:]) + setZero(h.hash[:]) + h.localIndex = 0 + h.state = HandshakeZeroed +} + func (h *Handshake) mixHash(data []byte) { mixHash(&h.hash, &h.hash, data) } @@ -138,8 +147,8 @@ func init() { func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { - device.noise.mutex.Lock() - defer device.noise.mutex.Unlock() + device.noise.mutex.RLock() + defer device.noise.mutex.RUnlock() handshake := &peer.handshake handshake.mutex.Lock() @@ -393,7 +402,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { ok := func() bool { - // read lock handshake + // lock handshake state handshake.mutex.RLock() defer handshake.mutex.RUnlock() @@ -402,6 +411,11 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { return false } + // lock private key for reading + + device.noise.mutex.RLock() + defer device.noise.mutex.RUnlock() + // finish 3-way DH mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) @@ -432,7 +446,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { ) mixHash(&hash, &hash, tau[:]) - // authenticate + // authenticate transcript aead, _ := chacha20poly1305.New(key[:]) _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) diff --git a/src/peer.go b/src/peer.go index 3b8f7cc..7776b71 100644 --- a/src/peer.go +++ b/src/peer.go @@ -48,8 +48,8 @@ type Peer struct { // state related to WireGuard timers - keepalivePersistent Timer // set for persistent keepalives - keepalivePassive Timer // set upon recieving messages + keepalivePersistent Timer // set for persistent keep-alive + keepalivePassive Timer // set upon receiving messages zeroAllKeys Timer // zero all key material handshakeNew Timer // begin a new handshake (stale) handshakeDeadline Timer // complete handshake timeout @@ -69,7 +69,7 @@ type Peer struct { mutex deadlock.Mutex // held when stopping / starting routines starting sync.WaitGroup // routines pending start stopping sync.WaitGroup // routines pending stop - stop Signal // size 0, stop all goroutines in peer + stop Signal // size 0, stop all go-routines in peer } mac CookieGenerator @@ -123,7 +123,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { } device.peers.keyMap[pk] = peer - // precompute DH + // pre-compute DH handshake := &peer.handshake handshake.mutex.Lock() @@ -186,16 +186,19 @@ func (peer *Peer) String() string { func (peer *Peer) Start() { + // should never start a peer on a closed device + if peer.device.isClosed.Get() { return } + // prevent simultaneous start/stop operations + peer.routines.mutex.Lock() defer peer.routines.mutex.Unlock() - peer.device.log.Debug.Println("Starting:", peer.String()) - // stop & wait for ungoing routines (if any) + // stop & wait for ongoing routines (if any) peer.isRunning.Set(false) peer.routines.stop.Broadcast() @@ -230,12 +233,15 @@ func (peer *Peer) Start() { func (peer *Peer) Stop() { + // prevent simultaneous start/stop operations + peer.routines.mutex.Lock() defer peer.routines.mutex.Unlock() - peer.device.log.Debug.Println("Stopping:", peer.String()) + device := peer.device + device.log.Debug.Println("Stopping:", peer.String()) - // stop & wait for ungoing routines (if any) + // stop & wait for ongoing peer routines (if any) peer.routines.stop.Broadcast() peer.routines.starting.Wait() @@ -247,6 +253,28 @@ func (peer *Peer) Stop() { close(peer.queue.outbound) close(peer.queue.inbound) + // clear key pairs + + kp := &peer.keyPairs + kp.mutex.Lock() + + device.DeleteKeyPair(kp.previous) + device.DeleteKeyPair(kp.current) + device.DeleteKeyPair(kp.next) + + kp.previous = nil + kp.current = nil + kp.next = nil + kp.mutex.Unlock() + + // clear handshake state + + hs := &peer.handshake + hs.mutex.Lock() + device.indices.Delete(hs.localIndex) + hs.Clear() + hs.mutex.Unlock() + // reset signal (to handle repeated stopping) peer.routines.stop = NewSignal() diff --git a/src/timers.go b/src/timers.go index 2ef105e..7092688 100644 --- a/src/timers.go +++ b/src/timers.go @@ -274,12 +274,7 @@ func (peer *Peer) RoutineTimerHandler() { // zero out handshake device.indices.Delete(hs.localIndex) - - hs.localIndex = 0 - setZero(hs.localEphemeral[:]) - setZero(hs.remoteEphemeral[:]) - setZero(hs.chainKey[:]) - setZero(hs.hash[:]) + hs.Clear() hs.mutex.Unlock() // handshake timers