device: give handshake state a type

And unexport handshake constants.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
This commit is contained in:
David Crawshaw 2020-03-04 20:58:39 -05:00 committed by Jason A. Donenfeld
parent 1a1c3d0968
commit de374bfb44

View file

@ -7,6 +7,7 @@ package device
import ( import (
"errors" "errors"
"fmt"
"sync" "sync"
"time" "time"
@ -16,14 +17,34 @@ import (
"golang.zx2c4.com/wireguard/tai64n" "golang.zx2c4.com/wireguard/tai64n"
) )
type handshakeState int
// TODO(crawshaw): add commentary describing each state and the transitions
const ( const (
HandshakeZeroed = iota handshakeZeroed = handshakeState(iota)
HandshakeInitiationCreated handshakeInitiationCreated
HandshakeInitiationConsumed handshakeInitiationConsumed
HandshakeResponseCreated handshakeResponseCreated
HandshakeResponseConsumed handshakeResponseConsumed
) )
func (hs handshakeState) String() string {
switch hs {
case handshakeZeroed:
return "handshakeZeroed"
case handshakeInitiationCreated:
return "handshakeInitiationCreated"
case handshakeInitiationConsumed:
return "handshakeInitiationConsumed"
case handshakeResponseCreated:
return "handshakeResponseCreated"
case handshakeResponseConsumed:
return "handshakeResponseConsumed"
default:
return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
}
}
const ( const (
NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
@ -95,7 +116,7 @@ type MessageCookieReply struct {
} }
type Handshake struct { type Handshake struct {
state int state handshakeState
mutex sync.RWMutex mutex sync.RWMutex
hash [blake2s.Size]byte // hash value hash [blake2s.Size]byte // hash value
chainKey [blake2s.Size]byte // chain key chainKey [blake2s.Size]byte // chain key
@ -135,7 +156,7 @@ func (h *Handshake) Clear() {
setZero(h.chainKey[:]) setZero(h.chainKey[:])
setZero(h.hash[:]) setZero(h.hash[:])
h.localIndex = 0 h.localIndex = 0
h.state = HandshakeZeroed h.state = handshakeZeroed
} }
func (h *Handshake) mixHash(data []byte) { func (h *Handshake) mixHash(data []byte) {
@ -221,7 +242,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.localIndex = msg.Sender handshake.localIndex = msg.Sender
handshake.mixHash(msg.Timestamp[:]) handshake.mixHash(msg.Timestamp[:])
handshake.state = HandshakeInitiationCreated handshake.state = handshakeInitiationCreated
return &msg, nil return &msg, nil
} }
@ -316,7 +337,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
if now.After(handshake.lastInitiationConsumption) { if now.After(handshake.lastInitiationConsumption) {
handshake.lastInitiationConsumption = now handshake.lastInitiationConsumption = now
} }
handshake.state = HandshakeInitiationConsumed handshake.state = handshakeInitiationConsumed
handshake.mutex.Unlock() handshake.mutex.Unlock()
@ -331,7 +352,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mutex.Lock() handshake.mutex.Lock()
defer handshake.mutex.Unlock() defer handshake.mutex.Unlock()
if handshake.state != HandshakeInitiationConsumed { if handshake.state != handshakeInitiationConsumed {
return nil, errors.New("handshake initiation must be consumed first") return nil, errors.New("handshake initiation must be consumed first")
} }
@ -387,7 +408,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(msg.Empty[:]) handshake.mixHash(msg.Empty[:])
}() }()
handshake.state = HandshakeResponseCreated handshake.state = handshakeResponseCreated
return &msg, nil return &msg, nil
} }
@ -417,7 +438,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
handshake.mutex.RLock() handshake.mutex.RLock()
defer handshake.mutex.RUnlock() defer handshake.mutex.RUnlock()
if handshake.state != HandshakeInitiationCreated { if handshake.state != handshakeInitiationCreated {
return false return false
} }
@ -478,7 +499,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
handshake.hash = hash handshake.hash = hash
handshake.chainKey = chainKey handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender handshake.remoteIndex = msg.Sender
handshake.state = HandshakeResponseConsumed handshake.state = handshakeResponseConsumed
handshake.mutex.Unlock() handshake.mutex.Unlock()
@ -503,7 +524,7 @@ func (peer *Peer) BeginSymmetricSession() error {
var sendKey [chacha20poly1305.KeySize]byte var sendKey [chacha20poly1305.KeySize]byte
var recvKey [chacha20poly1305.KeySize]byte var recvKey [chacha20poly1305.KeySize]byte
if handshake.state == HandshakeResponseConsumed { if handshake.state == handshakeResponseConsumed {
KDF2( KDF2(
&sendKey, &sendKey,
&recvKey, &recvKey,
@ -511,7 +532,7 @@ func (peer *Peer) BeginSymmetricSession() error {
nil, nil,
) )
isInitiator = true isInitiator = true
} else if handshake.state == HandshakeResponseCreated { } else if handshake.state == handshakeResponseCreated {
KDF2( KDF2(
&recvKey, &recvKey,
&sendKey, &sendKey,
@ -520,7 +541,7 @@ func (peer *Peer) BeginSymmetricSession() error {
) )
isInitiator = false isInitiator = false
} else { } else {
return errors.New("invalid state for keypair derivation") return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
} }
// zero handshake // zero handshake
@ -528,7 +549,7 @@ func (peer *Peer) BeginSymmetricSession() error {
setZero(handshake.chainKey[:]) setZero(handshake.chainKey[:])
setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line. setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
setZero(handshake.localEphemeral[:]) setZero(handshake.localEphemeral[:])
peer.handshake.state = HandshakeZeroed peer.handshake.state = handshakeZeroed
// create AEAD instances // create AEAD instances