device: use atomic access for unlocked keypair.next
Go's GC semantics might not always guarantee the safety of this, and the race detector gets upset too, so instead we wrap this all in atomic accessors. Reported-by: David Anderson <danderson@tailscale.com> Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
parent
fdba6c183a
commit
28c4d04304
|
@ -8,7 +8,9 @@ package device
|
||||||
import (
|
import (
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/replay"
|
"golang.zx2c4.com/wireguard/replay"
|
||||||
)
|
)
|
||||||
|
@ -38,6 +40,14 @@ type Keypairs struct {
|
||||||
next *Keypair
|
next *Keypair
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (kp *Keypairs) storeNext(next *Keypair) {
|
||||||
|
atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kp *Keypairs) loadNext() *Keypair {
|
||||||
|
return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
|
||||||
|
}
|
||||||
|
|
||||||
func (kp *Keypairs) Current() *Keypair {
|
func (kp *Keypairs) Current() *Keypair {
|
||||||
kp.RLock()
|
kp.RLock()
|
||||||
defer kp.RUnlock()
|
defer kp.RUnlock()
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"golang.org/x/crypto/blake2s"
|
"golang.org/x/crypto/blake2s"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/crypto/poly1305"
|
"golang.org/x/crypto/poly1305"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tai64n"
|
"golang.zx2c4.com/wireguard/tai64n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -583,12 +584,12 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||||
defer keypairs.Unlock()
|
defer keypairs.Unlock()
|
||||||
|
|
||||||
previous := keypairs.previous
|
previous := keypairs.previous
|
||||||
next := keypairs.next
|
next := keypairs.loadNext()
|
||||||
current := keypairs.current
|
current := keypairs.current
|
||||||
|
|
||||||
if isInitiator {
|
if isInitiator {
|
||||||
if next != nil {
|
if next != nil {
|
||||||
keypairs.next = nil
|
keypairs.storeNext(nil)
|
||||||
keypairs.previous = next
|
keypairs.previous = next
|
||||||
device.DeleteKeypair(current)
|
device.DeleteKeypair(current)
|
||||||
} else {
|
} else {
|
||||||
|
@ -597,7 +598,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||||
device.DeleteKeypair(previous)
|
device.DeleteKeypair(previous)
|
||||||
keypairs.current = keypair
|
keypairs.current = keypair
|
||||||
} else {
|
} else {
|
||||||
keypairs.next = keypair
|
keypairs.storeNext(keypair)
|
||||||
device.DeleteKeypair(next)
|
device.DeleteKeypair(next)
|
||||||
keypairs.previous = nil
|
keypairs.previous = nil
|
||||||
device.DeleteKeypair(previous)
|
device.DeleteKeypair(previous)
|
||||||
|
@ -608,18 +609,19 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||||
|
|
||||||
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
|
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
|
||||||
keypairs := &peer.keypairs
|
keypairs := &peer.keypairs
|
||||||
if keypairs.next != receivedKeypair {
|
|
||||||
|
if keypairs.loadNext() != receivedKeypair {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
keypairs.Lock()
|
keypairs.Lock()
|
||||||
defer keypairs.Unlock()
|
defer keypairs.Unlock()
|
||||||
if keypairs.next != receivedKeypair {
|
if keypairs.loadNext() != receivedKeypair {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
old := keypairs.previous
|
old := keypairs.previous
|
||||||
keypairs.previous = keypairs.current
|
keypairs.previous = keypairs.current
|
||||||
peer.device.DeleteKeypair(old)
|
peer.device.DeleteKeypair(old)
|
||||||
keypairs.current = keypairs.next
|
keypairs.current = keypairs.loadNext()
|
||||||
keypairs.next = nil
|
keypairs.storeNext(nil)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
@ -113,7 +113,7 @@ func TestNoiseHandshake(t *testing.T) {
|
||||||
t.Fatal("failed to derive keypair for peer 2", err)
|
t.Fatal("failed to derive keypair for peer 2", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
key1 := peer1.keypairs.next
|
key1 := peer1.keypairs.loadNext()
|
||||||
key2 := peer2.keypairs.current
|
key2 := peer2.keypairs.current
|
||||||
|
|
||||||
// encrypting / decryption test
|
// encrypting / decryption test
|
||||||
|
|
|
@ -223,10 +223,10 @@ func (peer *Peer) ZeroAndFlushAll() {
|
||||||
keypairs.Lock()
|
keypairs.Lock()
|
||||||
device.DeleteKeypair(keypairs.previous)
|
device.DeleteKeypair(keypairs.previous)
|
||||||
device.DeleteKeypair(keypairs.current)
|
device.DeleteKeypair(keypairs.current)
|
||||||
device.DeleteKeypair(keypairs.next)
|
device.DeleteKeypair(keypairs.loadNext())
|
||||||
keypairs.previous = nil
|
keypairs.previous = nil
|
||||||
keypairs.current = nil
|
keypairs.current = nil
|
||||||
keypairs.next = nil
|
keypairs.storeNext(nil)
|
||||||
keypairs.Unlock()
|
keypairs.Unlock()
|
||||||
|
|
||||||
// clear handshake state
|
// clear handshake state
|
||||||
|
@ -254,7 +254,7 @@ func (peer *Peer) ExpireCurrentKeypairs() {
|
||||||
keypairs.current.sendNonce = RejectAfterMessages
|
keypairs.current.sendNonce = RejectAfterMessages
|
||||||
}
|
}
|
||||||
if keypairs.next != nil {
|
if keypairs.next != nil {
|
||||||
keypairs.next.sendNonce = RejectAfterMessages
|
keypairs.loadNext().sendNonce = RejectAfterMessages
|
||||||
}
|
}
|
||||||
keypairs.Unlock()
|
keypairs.Unlock()
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue