diff --git a/device/device_test.go b/device/device_test.go index e143914..6b7980b 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -215,6 +215,14 @@ func TestConcurrencySafety(t *testing.T) { }() warmup.Wait() + applyCfg := func(cfg io.ReadSeeker) { + cfg.Seek(0, io.SeekStart) + err := pair[0].dev.IpcSetOperation(cfg) + if err != nil { + t.Fatal(err) + } + } + // Change persistent_keepalive_interval concurrently with tunnel use. t.Run("persistentKeepaliveInterval", func(t *testing.T) { cfg := uapiCfg( @@ -222,11 +230,24 @@ func TestConcurrencySafety(t *testing.T) { "persistent_keepalive_interval", "1", ) for i := 0; i < 1000; i++ { - cfg.Seek(0, io.SeekStart) - err := pair[0].dev.IpcSetOperation(cfg) - if err != nil { - t.Fatal(err) - } + applyCfg(cfg) + } + }) + + // Change private keys concurrently with tunnel use. + t.Run("privateKey", func(t *testing.T) { + bad := uapiCfg("private_key", "7777777777777777777777777777777777777777777777777777777777777777") + good := uapiCfg("private_key", "481eb0d8113a4a5da532d2c3e9c14b53c8454b34ab109676f6b58c2245e37b58") + // Set iters to a large number like 1000 to flush out data races quickly. + // Don't leave it large. That can cause logical races + // in which the handshake is interleaved with key changes + // such that the private key appears to be unchanging but + // other state gets reset, which can cause handshake failures like + // "Received packet with invalid mac1". + const iters = 1 + for i := 0; i < iters; i++ { + applyCfg(bad) + applyCfg(good) } }) diff --git a/device/keypair.go b/device/keypair.go index 254e696..27db779 100644 --- a/device/keypair.go +++ b/device/keypair.go @@ -23,7 +23,7 @@ import ( */ type Keypair struct { - sendNonce uint64 + sendNonce uint64 // accessed atomically send cipher.AEAD receive cipher.AEAD replayFilter replay.Filter diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 1dc854f..e34da83 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -566,7 +566,6 @@ func (peer *Peer) BeginSymmetricSession() error { setZero(recvKey[:]) keypair.created = time.Now() - keypair.sendNonce = 0 keypair.replayFilter.Reset() keypair.isInitiator = isInitiator keypair.localIndex = peer.handshake.localIndex diff --git a/device/peer.go b/device/peer.go index c094160..fe6de33 100644 --- a/device/peer.go +++ b/device/peer.go @@ -249,16 +249,17 @@ func (peer *Peer) ExpireCurrentKeypairs() { handshake.mutex.Lock() peer.device.indexTable.Delete(handshake.localIndex) handshake.Clear() - handshake.mutex.Unlock() peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) + handshake.mutex.Unlock() keypairs := &peer.keypairs keypairs.Lock() if keypairs.current != nil { - keypairs.current.sendNonce = RejectAfterMessages + atomic.StoreUint64(&keypairs.current.sendNonce, RejectAfterMessages) } if keypairs.next != nil { - keypairs.loadNext().sendNonce = RejectAfterMessages + next := keypairs.loadNext() + atomic.StoreUint64(&next.sendNonce, RejectAfterMessages) } keypairs.Unlock() } diff --git a/device/send.go b/device/send.go index 6b21708..bc51fa6 100644 --- a/device/send.go +++ b/device/send.go @@ -403,7 +403,7 @@ NextPacket: // check validity of newest key pair keypair = peer.keypairs.Current() - if keypair != nil && keypair.sendNonce < RejectAfterMessages { + if keypair != nil && atomic.LoadUint64(&keypair.sendNonce) < RejectAfterMessages { if time.Since(keypair.created) < RejectAfterTime { break }