device: use atomic access for unlocked keypair.next
Go's GC semantics might not always guarantee the safety of this, and the race detector gets upset too, so instead we wrap this all in atomic accessors. Reported-by: David Anderson <danderson@tailscale.com> Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
parent
fdba6c183a
commit
28c4d04304
|
@ -8,7 +8,9 @@ package device
|
|||
import (
|
||||
"crypto/cipher"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.zx2c4.com/wireguard/replay"
|
||||
)
|
||||
|
@ -38,6 +40,14 @@ type Keypairs struct {
|
|||
next *Keypair
|
||||
}
|
||||
|
||||
func (kp *Keypairs) storeNext(next *Keypair) {
|
||||
atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
|
||||
}
|
||||
|
||||
func (kp *Keypairs) loadNext() *Keypair {
|
||||
return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
|
||||
}
|
||||
|
||||
func (kp *Keypairs) Current() *Keypair {
|
||||
kp.RLock()
|
||||
defer kp.RUnlock()
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"golang.org/x/crypto/blake2s"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/poly1305"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tai64n"
|
||||
)
|
||||
|
||||
|
@ -583,12 +584,12 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||
defer keypairs.Unlock()
|
||||
|
||||
previous := keypairs.previous
|
||||
next := keypairs.next
|
||||
next := keypairs.loadNext()
|
||||
current := keypairs.current
|
||||
|
||||
if isInitiator {
|
||||
if next != nil {
|
||||
keypairs.next = nil
|
||||
keypairs.storeNext(nil)
|
||||
keypairs.previous = next
|
||||
device.DeleteKeypair(current)
|
||||
} else {
|
||||
|
@ -597,7 +598,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||
device.DeleteKeypair(previous)
|
||||
keypairs.current = keypair
|
||||
} else {
|
||||
keypairs.next = keypair
|
||||
keypairs.storeNext(keypair)
|
||||
device.DeleteKeypair(next)
|
||||
keypairs.previous = nil
|
||||
device.DeleteKeypair(previous)
|
||||
|
@ -608,18 +609,19 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||
|
||||
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
|
||||
keypairs := &peer.keypairs
|
||||
if keypairs.next != receivedKeypair {
|
||||
|
||||
if keypairs.loadNext() != receivedKeypair {
|
||||
return false
|
||||
}
|
||||
keypairs.Lock()
|
||||
defer keypairs.Unlock()
|
||||
if keypairs.next != receivedKeypair {
|
||||
if keypairs.loadNext() != receivedKeypair {
|
||||
return false
|
||||
}
|
||||
old := keypairs.previous
|
||||
keypairs.previous = keypairs.current
|
||||
peer.device.DeleteKeypair(old)
|
||||
keypairs.current = keypairs.next
|
||||
keypairs.next = nil
|
||||
keypairs.current = keypairs.loadNext()
|
||||
keypairs.storeNext(nil)
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -113,7 +113,7 @@ func TestNoiseHandshake(t *testing.T) {
|
|||
t.Fatal("failed to derive keypair for peer 2", err)
|
||||
}
|
||||
|
||||
key1 := peer1.keypairs.next
|
||||
key1 := peer1.keypairs.loadNext()
|
||||
key2 := peer2.keypairs.current
|
||||
|
||||
// encrypting / decryption test
|
||||
|
|
|
@ -223,10 +223,10 @@ func (peer *Peer) ZeroAndFlushAll() {
|
|||
keypairs.Lock()
|
||||
device.DeleteKeypair(keypairs.previous)
|
||||
device.DeleteKeypair(keypairs.current)
|
||||
device.DeleteKeypair(keypairs.next)
|
||||
device.DeleteKeypair(keypairs.loadNext())
|
||||
keypairs.previous = nil
|
||||
keypairs.current = nil
|
||||
keypairs.next = nil
|
||||
keypairs.storeNext(nil)
|
||||
keypairs.Unlock()
|
||||
|
||||
// clear handshake state
|
||||
|
@ -254,7 +254,7 @@ func (peer *Peer) ExpireCurrentKeypairs() {
|
|||
keypairs.current.sendNonce = RejectAfterMessages
|
||||
}
|
||||
if keypairs.next != nil {
|
||||
keypairs.next.sendNonce = RejectAfterMessages
|
||||
keypairs.loadNext().sendNonce = RejectAfterMessages
|
||||
}
|
||||
keypairs.Unlock()
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue