device: fix races from changing private_key

Access keypair.sendNonce atomically.
Eliminate one unnecessary initialization to zero.

Mutate handshake.lastSentHandshake with the mutex held.

Co-authored-by: David Anderson <danderson@tailscale.com>
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
This commit is contained in:
Josh Bleecher Snyder 2020-12-15 15:02:13 -08:00 committed by Jason A. Donenfeld
parent c8faa34cde
commit 70861686d3
5 changed files with 32 additions and 11 deletions

View file

@ -215,6 +215,14 @@ func TestConcurrencySafety(t *testing.T) {
}() }()
warmup.Wait() warmup.Wait()
applyCfg := func(cfg io.ReadSeeker) {
cfg.Seek(0, io.SeekStart)
err := pair[0].dev.IpcSetOperation(cfg)
if err != nil {
t.Fatal(err)
}
}
// Change persistent_keepalive_interval concurrently with tunnel use. // Change persistent_keepalive_interval concurrently with tunnel use.
t.Run("persistentKeepaliveInterval", func(t *testing.T) { t.Run("persistentKeepaliveInterval", func(t *testing.T) {
cfg := uapiCfg( cfg := uapiCfg(
@ -222,11 +230,24 @@ func TestConcurrencySafety(t *testing.T) {
"persistent_keepalive_interval", "1", "persistent_keepalive_interval", "1",
) )
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
cfg.Seek(0, io.SeekStart) applyCfg(cfg)
err := pair[0].dev.IpcSetOperation(cfg)
if err != nil {
t.Fatal(err)
} }
})
// Change private keys concurrently with tunnel use.
t.Run("privateKey", func(t *testing.T) {
bad := uapiCfg("private_key", "7777777777777777777777777777777777777777777777777777777777777777")
good := uapiCfg("private_key", "481eb0d8113a4a5da532d2c3e9c14b53c8454b34ab109676f6b58c2245e37b58")
// Set iters to a large number like 1000 to flush out data races quickly.
// Don't leave it large. That can cause logical races
// in which the handshake is interleaved with key changes
// such that the private key appears to be unchanging but
// other state gets reset, which can cause handshake failures like
// "Received packet with invalid mac1".
const iters = 1
for i := 0; i < iters; i++ {
applyCfg(bad)
applyCfg(good)
} }
}) })

View file

@ -23,7 +23,7 @@ import (
*/ */
type Keypair struct { type Keypair struct {
sendNonce uint64 sendNonce uint64 // accessed atomically
send cipher.AEAD send cipher.AEAD
receive cipher.AEAD receive cipher.AEAD
replayFilter replay.Filter replayFilter replay.Filter

View file

@ -566,7 +566,6 @@ func (peer *Peer) BeginSymmetricSession() error {
setZero(recvKey[:]) setZero(recvKey[:])
keypair.created = time.Now() keypair.created = time.Now()
keypair.sendNonce = 0
keypair.replayFilter.Reset() keypair.replayFilter.Reset()
keypair.isInitiator = isInitiator keypair.isInitiator = isInitiator
keypair.localIndex = peer.handshake.localIndex keypair.localIndex = peer.handshake.localIndex

View file

@ -249,16 +249,17 @@ func (peer *Peer) ExpireCurrentKeypairs() {
handshake.mutex.Lock() handshake.mutex.Lock()
peer.device.indexTable.Delete(handshake.localIndex) peer.device.indexTable.Delete(handshake.localIndex)
handshake.Clear() handshake.Clear()
handshake.mutex.Unlock()
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
handshake.mutex.Unlock()
keypairs := &peer.keypairs keypairs := &peer.keypairs
keypairs.Lock() keypairs.Lock()
if keypairs.current != nil { if keypairs.current != nil {
keypairs.current.sendNonce = RejectAfterMessages atomic.StoreUint64(&keypairs.current.sendNonce, RejectAfterMessages)
} }
if keypairs.next != nil { if keypairs.next != nil {
keypairs.loadNext().sendNonce = RejectAfterMessages next := keypairs.loadNext()
atomic.StoreUint64(&next.sendNonce, RejectAfterMessages)
} }
keypairs.Unlock() keypairs.Unlock()
} }

View file

@ -403,7 +403,7 @@ NextPacket:
// check validity of newest key pair // check validity of newest key pair
keypair = peer.keypairs.Current() keypair = peer.keypairs.Current()
if keypair != nil && keypair.sendNonce < RejectAfterMessages { if keypair != nil && atomic.LoadUint64(&keypair.sendNonce) < RejectAfterMessages {
if time.Since(keypair.created) < RejectAfterTime { if time.Since(keypair.created) < RejectAfterTime {
break break
} }