70861686d3
Access keypair.sendNonce atomically. Eliminate one unnecessary initialization to zero. Mutate handshake.lastSentHandshake with the mutex held. Co-authored-by: David Anderson <danderson@tailscale.com> Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
627 lines
15 KiB
Go
627 lines
15 KiB
Go
/* SPDX-License-Identifier: MIT
|
|
*
|
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
|
*/
|
|
|
|
package device
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/blake2s"
|
|
"golang.org/x/crypto/chacha20poly1305"
|
|
"golang.org/x/crypto/poly1305"
|
|
|
|
"golang.zx2c4.com/wireguard/tai64n"
|
|
)
|
|
|
|
// handshakeState records how far a Handshake has progressed.
//
// Transitions performed in this file:
//
//	any state                   -> handshakeInitiationCreated   CreateMessageInitiation (we initiate)
//	any state                   -> handshakeInitiationConsumed  ConsumeMessageInitiation (peer initiates)
//	handshakeInitiationConsumed -> handshakeResponseCreated     CreateMessageResponse (we respond)
//	handshakeInitiationCreated  -> handshakeResponseConsumed    ConsumeMessageResponse (peer responded)
//	handshakeResponseCreated    -> handshakeZeroed              BeginSymmetricSession (responder keypair)
//	handshakeResponseConsumed   -> handshakeZeroed              BeginSymmetricSession (initiator keypair)
//	any state                   -> handshakeZeroed              Clear
type handshakeState int

const (
	// handshakeZeroed is the zero value: no handshake in progress and the
	// handshake's key material has been wiped (by Clear or
	// BeginSymmetricSession).
	handshakeZeroed = handshakeState(iota)

	// handshakeInitiationCreated: we built an initiation and are waiting for
	// the peer's response. Set by CreateMessageInitiation.
	handshakeInitiationCreated

	// handshakeInitiationConsumed: we accepted the peer's initiation and may
	// now build a response. Set by ConsumeMessageInitiation.
	handshakeInitiationConsumed

	// handshakeResponseCreated: we built a response (responder role); the
	// next step is BeginSymmetricSession. Set by CreateMessageResponse, which
	// requires handshakeInitiationConsumed.
	handshakeResponseCreated

	// handshakeResponseConsumed: the peer's response to our initiation was
	// authenticated (initiator role); the next step is BeginSymmetricSession.
	// Set by ConsumeMessageResponse, which requires handshakeInitiationCreated.
	handshakeResponseConsumed
)

// String returns the constant's identifier, or "Handshake(UNKNOWN:n)" for a
// value outside the defined range.
func (hs handshakeState) String() string {
	switch hs {
	case handshakeZeroed:
		return "handshakeZeroed"
	case handshakeInitiationCreated:
		return "handshakeInitiationCreated"
	case handshakeInitiationConsumed:
		return "handshakeInitiationConsumed"
	case handshakeResponseCreated:
		return "handshakeResponseCreated"
	case handshakeResponseConsumed:
		return "handshakeResponseConsumed"
	default:
		return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
	}
}
|
|
|
|
// Protocol identifiers and labels.
//
// NoiseConstruction seeds InitialChainKey and WGIdentifier is mixed into
// InitialHash (both in init). WGLabelMAC1 and WGLabelCookie are presumably
// used elsewhere to derive the MAC1 and cookie keys.
const (
	NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
	WGIdentifier      = "WireGuard v1 zx2c4 Jason@zx2c4.com"
	WGLabelMAC1       = "mac1----"
	WGLabelCookie     = "cookie--"
)

// Message type codes, carried in the Type field of every message.
const (
	MessageInitiationType  = iota + 1 // 1
	MessageResponseType               // 2
	MessageCookieReplyType            // 3
	MessageTransportType              // 4
)
|
|
|
|
const (
|
|
MessageInitiationSize = 148 // size of handshake initiation message
|
|
MessageResponseSize = 92 // size of response message
|
|
MessageCookieReplySize = 64 // size of cookie reply message
|
|
MessageTransportHeaderSize = 16 // size of data preceding content in transport message
|
|
MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
|
|
MessageKeepaliveSize = MessageTransportSize // size of keepalive
|
|
MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message
|
|
)
|
|
|
|
const (
|
|
MessageTransportOffsetReceiver = 4
|
|
MessageTransportOffsetCounter = 8
|
|
MessageTransportOffsetContent = 16
|
|
)
|
|
|
|
/* Type is an 8-bit field followed by 3 NUL bytes. Because the messages are
 * marshalled in little-endian byte order, the type and its padding can be
 * treated as a single 32-bit unsigned integer (for now).
 */
|
|
|
|
type MessageInitiation struct {
|
|
Type uint32
|
|
Sender uint32
|
|
Ephemeral NoisePublicKey
|
|
Static [NoisePublicKeySize + poly1305.TagSize]byte
|
|
Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte
|
|
MAC1 [blake2s.Size128]byte
|
|
MAC2 [blake2s.Size128]byte
|
|
}
|
|
|
|
type MessageResponse struct {
|
|
Type uint32
|
|
Sender uint32
|
|
Receiver uint32
|
|
Ephemeral NoisePublicKey
|
|
Empty [poly1305.TagSize]byte
|
|
MAC1 [blake2s.Size128]byte
|
|
MAC2 [blake2s.Size128]byte
|
|
}
|
|
|
|
type MessageTransport struct {
|
|
Type uint32
|
|
Receiver uint32
|
|
Counter uint64
|
|
Content []byte
|
|
}
|
|
|
|
type MessageCookieReply struct {
|
|
Type uint32
|
|
Receiver uint32
|
|
Nonce [chacha20poly1305.NonceSizeX]byte
|
|
Cookie [blake2s.Size128 + poly1305.TagSize]byte
|
|
}
|
|
|
|
type Handshake struct {
|
|
state handshakeState
|
|
mutex sync.RWMutex
|
|
hash [blake2s.Size]byte // hash value
|
|
chainKey [blake2s.Size]byte // chain key
|
|
presharedKey NoiseSymmetricKey // psk
|
|
localEphemeral NoisePrivateKey // ephemeral secret key
|
|
localIndex uint32 // used to clear hash-table
|
|
remoteIndex uint32 // index for sending
|
|
remoteStatic NoisePublicKey // long term key
|
|
remoteEphemeral NoisePublicKey // ephemeral public key
|
|
precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret
|
|
lastTimestamp tai64n.Timestamp
|
|
lastInitiationConsumption time.Time
|
|
lastSentHandshake time.Time
|
|
}
|
|
|
|
var (
|
|
InitialChainKey [blake2s.Size]byte
|
|
InitialHash [blake2s.Size]byte
|
|
ZeroNonce [chacha20poly1305.NonceSize]byte
|
|
)
|
|
|
|
func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) {
|
|
KDF1(dst, c[:], data)
|
|
}
|
|
|
|
func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) {
|
|
hash, _ := blake2s.New256(nil)
|
|
hash.Write(h[:])
|
|
hash.Write(data)
|
|
hash.Sum(dst[:0])
|
|
hash.Reset()
|
|
}
|
|
|
|
func (h *Handshake) Clear() {
|
|
setZero(h.localEphemeral[:])
|
|
setZero(h.remoteEphemeral[:])
|
|
setZero(h.chainKey[:])
|
|
setZero(h.hash[:])
|
|
h.localIndex = 0
|
|
h.state = handshakeZeroed
|
|
}
|
|
|
|
func (h *Handshake) mixHash(data []byte) {
|
|
mixHash(&h.hash, &h.hash, data)
|
|
}
|
|
|
|
func (h *Handshake) mixKey(data []byte) {
|
|
mixKey(&h.chainKey, &h.chainKey, data)
|
|
}
|
|
|
|
/* Do basic precomputations
|
|
*/
|
|
func init() {
|
|
InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction))
|
|
mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier))
|
|
}
|
|
|
|
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
|
|
var errZeroECDHResult = errors.New("ECDH returned all zeros")
|
|
|
|
device.staticIdentity.RLock()
|
|
defer device.staticIdentity.RUnlock()
|
|
|
|
handshake := &peer.handshake
|
|
handshake.mutex.Lock()
|
|
defer handshake.mutex.Unlock()
|
|
|
|
// create ephemeral key
|
|
var err error
|
|
handshake.hash = InitialHash
|
|
handshake.chainKey = InitialChainKey
|
|
handshake.localEphemeral, err = newPrivateKey()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
handshake.mixHash(handshake.remoteStatic[:])
|
|
|
|
msg := MessageInitiation{
|
|
Type: MessageInitiationType,
|
|
Ephemeral: handshake.localEphemeral.publicKey(),
|
|
}
|
|
|
|
handshake.mixKey(msg.Ephemeral[:])
|
|
handshake.mixHash(msg.Ephemeral[:])
|
|
|
|
// encrypt static key
|
|
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
|
if isZero(ss[:]) {
|
|
return nil, errZeroECDHResult
|
|
}
|
|
var key [chacha20poly1305.KeySize]byte
|
|
KDF2(
|
|
&handshake.chainKey,
|
|
&key,
|
|
handshake.chainKey[:],
|
|
ss[:],
|
|
)
|
|
aead, _ := chacha20poly1305.New(key[:])
|
|
aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
|
|
handshake.mixHash(msg.Static[:])
|
|
|
|
// encrypt timestamp
|
|
if isZero(handshake.precomputedStaticStatic[:]) {
|
|
return nil, errZeroECDHResult
|
|
}
|
|
KDF2(
|
|
&handshake.chainKey,
|
|
&key,
|
|
handshake.chainKey[:],
|
|
handshake.precomputedStaticStatic[:],
|
|
)
|
|
timestamp := tai64n.Now()
|
|
aead, _ = chacha20poly1305.New(key[:])
|
|
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
|
|
|
|
// assign index
|
|
device.indexTable.Delete(handshake.localIndex)
|
|
msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
handshake.localIndex = msg.Sender
|
|
|
|
handshake.mixHash(msg.Timestamp[:])
|
|
handshake.state = handshakeInitiationCreated
|
|
return &msg, nil
|
|
}
|
|
|
|
func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
|
var (
|
|
hash [blake2s.Size]byte
|
|
chainKey [blake2s.Size]byte
|
|
)
|
|
|
|
if msg.Type != MessageInitiationType {
|
|
return nil
|
|
}
|
|
|
|
device.staticIdentity.RLock()
|
|
defer device.staticIdentity.RUnlock()
|
|
|
|
mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:])
|
|
mixHash(&hash, &hash, msg.Ephemeral[:])
|
|
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
|
|
|
|
// decrypt static key
|
|
var err error
|
|
var peerPK NoisePublicKey
|
|
var key [chacha20poly1305.KeySize]byte
|
|
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
|
if isZero(ss[:]) {
|
|
return nil
|
|
}
|
|
KDF2(&chainKey, &key, chainKey[:], ss[:])
|
|
aead, _ := chacha20poly1305.New(key[:])
|
|
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
mixHash(&hash, &hash, msg.Static[:])
|
|
|
|
// lookup peer
|
|
|
|
peer := device.LookupPeer(peerPK)
|
|
if peer == nil {
|
|
return nil
|
|
}
|
|
|
|
handshake := &peer.handshake
|
|
|
|
// verify identity
|
|
|
|
var timestamp tai64n.Timestamp
|
|
|
|
handshake.mutex.RLock()
|
|
|
|
if isZero(handshake.precomputedStaticStatic[:]) {
|
|
handshake.mutex.RUnlock()
|
|
return nil
|
|
}
|
|
KDF2(
|
|
&chainKey,
|
|
&key,
|
|
chainKey[:],
|
|
handshake.precomputedStaticStatic[:],
|
|
)
|
|
aead, _ = chacha20poly1305.New(key[:])
|
|
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
|
|
if err != nil {
|
|
handshake.mutex.RUnlock()
|
|
return nil
|
|
}
|
|
mixHash(&hash, &hash, msg.Timestamp[:])
|
|
|
|
// protect against replay & flood
|
|
|
|
replay := !timestamp.After(handshake.lastTimestamp)
|
|
flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate
|
|
handshake.mutex.RUnlock()
|
|
if replay {
|
|
device.log.Debug.Printf("%v - ConsumeMessageInitiation: handshake replay @ %v\n", peer, timestamp)
|
|
return nil
|
|
}
|
|
if flood {
|
|
device.log.Debug.Printf("%v - ConsumeMessageInitiation: handshake flood\n", peer)
|
|
return nil
|
|
}
|
|
|
|
// update handshake state
|
|
|
|
handshake.mutex.Lock()
|
|
|
|
handshake.hash = hash
|
|
handshake.chainKey = chainKey
|
|
handshake.remoteIndex = msg.Sender
|
|
handshake.remoteEphemeral = msg.Ephemeral
|
|
if timestamp.After(handshake.lastTimestamp) {
|
|
handshake.lastTimestamp = timestamp
|
|
}
|
|
now := time.Now()
|
|
if now.After(handshake.lastInitiationConsumption) {
|
|
handshake.lastInitiationConsumption = now
|
|
}
|
|
handshake.state = handshakeInitiationConsumed
|
|
|
|
handshake.mutex.Unlock()
|
|
|
|
setZero(hash[:])
|
|
setZero(chainKey[:])
|
|
|
|
return peer
|
|
}
|
|
|
|
func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) {
|
|
handshake := &peer.handshake
|
|
handshake.mutex.Lock()
|
|
defer handshake.mutex.Unlock()
|
|
|
|
if handshake.state != handshakeInitiationConsumed {
|
|
return nil, errors.New("handshake initiation must be consumed first")
|
|
}
|
|
|
|
// assign index
|
|
|
|
var err error
|
|
device.indexTable.Delete(handshake.localIndex)
|
|
handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var msg MessageResponse
|
|
msg.Type = MessageResponseType
|
|
msg.Sender = handshake.localIndex
|
|
msg.Receiver = handshake.remoteIndex
|
|
|
|
// create ephemeral key
|
|
|
|
handshake.localEphemeral, err = newPrivateKey()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
msg.Ephemeral = handshake.localEphemeral.publicKey()
|
|
handshake.mixHash(msg.Ephemeral[:])
|
|
handshake.mixKey(msg.Ephemeral[:])
|
|
|
|
func() {
|
|
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
|
|
handshake.mixKey(ss[:])
|
|
ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
|
handshake.mixKey(ss[:])
|
|
}()
|
|
|
|
// add preshared key
|
|
|
|
var tau [blake2s.Size]byte
|
|
var key [chacha20poly1305.KeySize]byte
|
|
|
|
KDF3(
|
|
&handshake.chainKey,
|
|
&tau,
|
|
&key,
|
|
handshake.chainKey[:],
|
|
handshake.presharedKey[:],
|
|
)
|
|
|
|
handshake.mixHash(tau[:])
|
|
|
|
func() {
|
|
aead, _ := chacha20poly1305.New(key[:])
|
|
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
|
|
handshake.mixHash(msg.Empty[:])
|
|
}()
|
|
|
|
handshake.state = handshakeResponseCreated
|
|
|
|
return &msg, nil
|
|
}
|
|
|
|
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
|
if msg.Type != MessageResponseType {
|
|
return nil
|
|
}
|
|
|
|
// lookup handshake by receiver
|
|
|
|
lookup := device.indexTable.Lookup(msg.Receiver)
|
|
handshake := lookup.handshake
|
|
if handshake == nil {
|
|
return nil
|
|
}
|
|
|
|
var (
|
|
hash [blake2s.Size]byte
|
|
chainKey [blake2s.Size]byte
|
|
)
|
|
|
|
ok := func() bool {
|
|
|
|
// lock handshake state
|
|
|
|
handshake.mutex.RLock()
|
|
defer handshake.mutex.RUnlock()
|
|
|
|
if handshake.state != handshakeInitiationCreated {
|
|
return false
|
|
}
|
|
|
|
// lock private key for reading
|
|
|
|
device.staticIdentity.RLock()
|
|
defer device.staticIdentity.RUnlock()
|
|
|
|
// finish 3-way DH
|
|
|
|
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
|
|
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
|
|
|
|
func() {
|
|
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
|
|
mixKey(&chainKey, &chainKey, ss[:])
|
|
setZero(ss[:])
|
|
}()
|
|
|
|
func() {
|
|
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
|
mixKey(&chainKey, &chainKey, ss[:])
|
|
setZero(ss[:])
|
|
}()
|
|
|
|
// add preshared key (psk)
|
|
|
|
var tau [blake2s.Size]byte
|
|
var key [chacha20poly1305.KeySize]byte
|
|
KDF3(
|
|
&chainKey,
|
|
&tau,
|
|
&key,
|
|
chainKey[:],
|
|
handshake.presharedKey[:],
|
|
)
|
|
mixHash(&hash, &hash, tau[:])
|
|
|
|
// authenticate transcript
|
|
|
|
aead, _ := chacha20poly1305.New(key[:])
|
|
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
|
|
if err != nil {
|
|
return false
|
|
}
|
|
mixHash(&hash, &hash, msg.Empty[:])
|
|
return true
|
|
}()
|
|
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
// update handshake state
|
|
|
|
handshake.mutex.Lock()
|
|
|
|
handshake.hash = hash
|
|
handshake.chainKey = chainKey
|
|
handshake.remoteIndex = msg.Sender
|
|
handshake.state = handshakeResponseConsumed
|
|
|
|
handshake.mutex.Unlock()
|
|
|
|
setZero(hash[:])
|
|
setZero(chainKey[:])
|
|
|
|
return lookup.peer
|
|
}
|
|
|
|
/* Derives a new keypair from the current handshake state
|
|
*
|
|
*/
|
|
func (peer *Peer) BeginSymmetricSession() error {
|
|
device := peer.device
|
|
handshake := &peer.handshake
|
|
handshake.mutex.Lock()
|
|
defer handshake.mutex.Unlock()
|
|
|
|
// derive keys
|
|
|
|
var isInitiator bool
|
|
var sendKey [chacha20poly1305.KeySize]byte
|
|
var recvKey [chacha20poly1305.KeySize]byte
|
|
|
|
if handshake.state == handshakeResponseConsumed {
|
|
KDF2(
|
|
&sendKey,
|
|
&recvKey,
|
|
handshake.chainKey[:],
|
|
nil,
|
|
)
|
|
isInitiator = true
|
|
} else if handshake.state == handshakeResponseCreated {
|
|
KDF2(
|
|
&recvKey,
|
|
&sendKey,
|
|
handshake.chainKey[:],
|
|
nil,
|
|
)
|
|
isInitiator = false
|
|
} else {
|
|
return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
|
|
}
|
|
|
|
// zero handshake
|
|
|
|
setZero(handshake.chainKey[:])
|
|
setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
|
|
setZero(handshake.localEphemeral[:])
|
|
peer.handshake.state = handshakeZeroed
|
|
|
|
// create AEAD instances
|
|
|
|
keypair := new(Keypair)
|
|
keypair.send, _ = chacha20poly1305.New(sendKey[:])
|
|
keypair.receive, _ = chacha20poly1305.New(recvKey[:])
|
|
|
|
setZero(sendKey[:])
|
|
setZero(recvKey[:])
|
|
|
|
keypair.created = time.Now()
|
|
keypair.replayFilter.Reset()
|
|
keypair.isInitiator = isInitiator
|
|
keypair.localIndex = peer.handshake.localIndex
|
|
keypair.remoteIndex = peer.handshake.remoteIndex
|
|
|
|
// remap index
|
|
|
|
device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair)
|
|
handshake.localIndex = 0
|
|
|
|
// rotate key pairs
|
|
|
|
keypairs := &peer.keypairs
|
|
keypairs.Lock()
|
|
defer keypairs.Unlock()
|
|
|
|
previous := keypairs.previous
|
|
next := keypairs.loadNext()
|
|
current := keypairs.current
|
|
|
|
if isInitiator {
|
|
if next != nil {
|
|
keypairs.storeNext(nil)
|
|
keypairs.previous = next
|
|
device.DeleteKeypair(current)
|
|
} else {
|
|
keypairs.previous = current
|
|
}
|
|
device.DeleteKeypair(previous)
|
|
keypairs.current = keypair
|
|
} else {
|
|
keypairs.storeNext(keypair)
|
|
device.DeleteKeypair(next)
|
|
keypairs.previous = nil
|
|
device.DeleteKeypair(previous)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
|
|
keypairs := &peer.keypairs
|
|
|
|
if keypairs.loadNext() != receivedKeypair {
|
|
return false
|
|
}
|
|
keypairs.Lock()
|
|
defer keypairs.Unlock()
|
|
if keypairs.loadNext() != receivedKeypair {
|
|
return false
|
|
}
|
|
old := keypairs.previous
|
|
keypairs.previous = keypairs.current
|
|
peer.device.DeleteKeypair(old)
|
|
keypairs.current = keypairs.loadNext()
|
|
keypairs.storeNext(nil)
|
|
return true
|
|
}
|