device: use atomic access for unlocked keypair.next

Go's GC semantics might not always guarantee the safety of this, and the
race detector gets upset too, so instead we wrap this all in atomic
accessors.

Reported-by: David Anderson <danderson@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2020-05-02 01:30:23 -06:00
parent fdba6c183a
commit 28c4d04304
4 changed files with 23 additions and 11 deletions

View file

@ -8,7 +8,9 @@ package device
import ( import (
"crypto/cipher" "crypto/cipher"
"sync" "sync"
"sync/atomic"
"time" "time"
"unsafe"
"golang.zx2c4.com/wireguard/replay" "golang.zx2c4.com/wireguard/replay"
) )
@ -38,6 +40,14 @@ type Keypairs struct {
next *Keypair next *Keypair
} }
func (kp *Keypairs) storeNext(next *Keypair) {
atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
}
func (kp *Keypairs) loadNext() *Keypair {
return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
}
func (kp *Keypairs) Current() *Keypair { func (kp *Keypairs) Current() *Keypair {
kp.RLock() kp.RLock()
defer kp.RUnlock() defer kp.RUnlock()

View file

@ -14,6 +14,7 @@ import (
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305" "golang.org/x/crypto/poly1305"
"golang.zx2c4.com/wireguard/tai64n" "golang.zx2c4.com/wireguard/tai64n"
) )
@ -583,12 +584,12 @@ func (peer *Peer) BeginSymmetricSession() error {
defer keypairs.Unlock() defer keypairs.Unlock()
previous := keypairs.previous previous := keypairs.previous
next := keypairs.next next := keypairs.loadNext()
current := keypairs.current current := keypairs.current
if isInitiator { if isInitiator {
if next != nil { if next != nil {
keypairs.next = nil keypairs.storeNext(nil)
keypairs.previous = next keypairs.previous = next
device.DeleteKeypair(current) device.DeleteKeypair(current)
} else { } else {
@ -597,7 +598,7 @@ func (peer *Peer) BeginSymmetricSession() error {
device.DeleteKeypair(previous) device.DeleteKeypair(previous)
keypairs.current = keypair keypairs.current = keypair
} else { } else {
keypairs.next = keypair keypairs.storeNext(keypair)
device.DeleteKeypair(next) device.DeleteKeypair(next)
keypairs.previous = nil keypairs.previous = nil
device.DeleteKeypair(previous) device.DeleteKeypair(previous)
@ -608,18 +609,19 @@ func (peer *Peer) BeginSymmetricSession() error {
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
keypairs := &peer.keypairs keypairs := &peer.keypairs
if keypairs.next != receivedKeypair {
if keypairs.loadNext() != receivedKeypair {
return false return false
} }
keypairs.Lock() keypairs.Lock()
defer keypairs.Unlock() defer keypairs.Unlock()
if keypairs.next != receivedKeypair { if keypairs.loadNext() != receivedKeypair {
return false return false
} }
old := keypairs.previous old := keypairs.previous
keypairs.previous = keypairs.current keypairs.previous = keypairs.current
peer.device.DeleteKeypair(old) peer.device.DeleteKeypair(old)
keypairs.current = keypairs.next keypairs.current = keypairs.loadNext()
keypairs.next = nil keypairs.storeNext(nil)
return true return true
} }

View file

@ -113,7 +113,7 @@ func TestNoiseHandshake(t *testing.T) {
t.Fatal("failed to derive keypair for peer 2", err) t.Fatal("failed to derive keypair for peer 2", err)
} }
key1 := peer1.keypairs.next key1 := peer1.keypairs.loadNext()
key2 := peer2.keypairs.current key2 := peer2.keypairs.current
// encrypting / decryption test // encrypting / decryption test

View file

@ -223,10 +223,10 @@ func (peer *Peer) ZeroAndFlushAll() {
keypairs.Lock() keypairs.Lock()
device.DeleteKeypair(keypairs.previous) device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current) device.DeleteKeypair(keypairs.current)
device.DeleteKeypair(keypairs.next) device.DeleteKeypair(keypairs.loadNext())
keypairs.previous = nil keypairs.previous = nil
keypairs.current = nil keypairs.current = nil
keypairs.next = nil keypairs.storeNext(nil)
keypairs.Unlock() keypairs.Unlock()
// clear handshake state // clear handshake state
@ -254,7 +254,7 @@ func (peer *Peer) ExpireCurrentKeypairs() {
keypairs.current.sendNonce = RejectAfterMessages keypairs.current.sendNonce = RejectAfterMessages
} }
if keypairs.next != nil { if keypairs.next != nil {
keypairs.next.sendNonce = RejectAfterMessages keypairs.loadNext().sendNonce = RejectAfterMessages
} }
keypairs.Unlock() keypairs.Unlock()
} }