device: give handshake state a type
And unexport handshake constants. Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
This commit is contained in:
parent
1a1c3d0968
commit
de374bfb44
|
@ -7,6 +7,7 @@ package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -16,14 +17,34 @@ import (
|
||||||
"golang.zx2c4.com/wireguard/tai64n"
|
"golang.zx2c4.com/wireguard/tai64n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type handshakeState int
|
||||||
|
|
||||||
|
// TODO(crawshaw): add commentary describing each state and the transitions
|
||||||
const (
|
const (
|
||||||
HandshakeZeroed = iota
|
handshakeZeroed = handshakeState(iota)
|
||||||
HandshakeInitiationCreated
|
handshakeInitiationCreated
|
||||||
HandshakeInitiationConsumed
|
handshakeInitiationConsumed
|
||||||
HandshakeResponseCreated
|
handshakeResponseCreated
|
||||||
HandshakeResponseConsumed
|
handshakeResponseConsumed
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (hs handshakeState) String() string {
|
||||||
|
switch hs {
|
||||||
|
case handshakeZeroed:
|
||||||
|
return "handshakeZeroed"
|
||||||
|
case handshakeInitiationCreated:
|
||||||
|
return "handshakeInitiationCreated"
|
||||||
|
case handshakeInitiationConsumed:
|
||||||
|
return "handshakeInitiationConsumed"
|
||||||
|
case handshakeResponseCreated:
|
||||||
|
return "handshakeResponseCreated"
|
||||||
|
case handshakeResponseConsumed:
|
||||||
|
return "handshakeResponseConsumed"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
|
NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
|
||||||
WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
|
WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
|
||||||
|
@ -95,7 +116,7 @@ type MessageCookieReply struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Handshake struct {
|
type Handshake struct {
|
||||||
state int
|
state handshakeState
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
hash [blake2s.Size]byte // hash value
|
hash [blake2s.Size]byte // hash value
|
||||||
chainKey [blake2s.Size]byte // chain key
|
chainKey [blake2s.Size]byte // chain key
|
||||||
|
@ -135,7 +156,7 @@ func (h *Handshake) Clear() {
|
||||||
setZero(h.chainKey[:])
|
setZero(h.chainKey[:])
|
||||||
setZero(h.hash[:])
|
setZero(h.hash[:])
|
||||||
h.localIndex = 0
|
h.localIndex = 0
|
||||||
h.state = HandshakeZeroed
|
h.state = handshakeZeroed
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshake) mixHash(data []byte) {
|
func (h *Handshake) mixHash(data []byte) {
|
||||||
|
@ -221,7 +242,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
|
||||||
handshake.localIndex = msg.Sender
|
handshake.localIndex = msg.Sender
|
||||||
|
|
||||||
handshake.mixHash(msg.Timestamp[:])
|
handshake.mixHash(msg.Timestamp[:])
|
||||||
handshake.state = HandshakeInitiationCreated
|
handshake.state = handshakeInitiationCreated
|
||||||
return &msg, nil
|
return &msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -316,7 +337,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
||||||
if now.After(handshake.lastInitiationConsumption) {
|
if now.After(handshake.lastInitiationConsumption) {
|
||||||
handshake.lastInitiationConsumption = now
|
handshake.lastInitiationConsumption = now
|
||||||
}
|
}
|
||||||
handshake.state = HandshakeInitiationConsumed
|
handshake.state = handshakeInitiationConsumed
|
||||||
|
|
||||||
handshake.mutex.Unlock()
|
handshake.mutex.Unlock()
|
||||||
|
|
||||||
|
@ -331,7 +352,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
||||||
handshake.mutex.Lock()
|
handshake.mutex.Lock()
|
||||||
defer handshake.mutex.Unlock()
|
defer handshake.mutex.Unlock()
|
||||||
|
|
||||||
if handshake.state != HandshakeInitiationConsumed {
|
if handshake.state != handshakeInitiationConsumed {
|
||||||
return nil, errors.New("handshake initiation must be consumed first")
|
return nil, errors.New("handshake initiation must be consumed first")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -387,7 +408,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
||||||
handshake.mixHash(msg.Empty[:])
|
handshake.mixHash(msg.Empty[:])
|
||||||
}()
|
}()
|
||||||
|
|
||||||
handshake.state = HandshakeResponseCreated
|
handshake.state = handshakeResponseCreated
|
||||||
|
|
||||||
return &msg, nil
|
return &msg, nil
|
||||||
}
|
}
|
||||||
|
@ -417,7 +438,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||||
handshake.mutex.RLock()
|
handshake.mutex.RLock()
|
||||||
defer handshake.mutex.RUnlock()
|
defer handshake.mutex.RUnlock()
|
||||||
|
|
||||||
if handshake.state != HandshakeInitiationCreated {
|
if handshake.state != handshakeInitiationCreated {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -478,7 +499,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||||
handshake.hash = hash
|
handshake.hash = hash
|
||||||
handshake.chainKey = chainKey
|
handshake.chainKey = chainKey
|
||||||
handshake.remoteIndex = msg.Sender
|
handshake.remoteIndex = msg.Sender
|
||||||
handshake.state = HandshakeResponseConsumed
|
handshake.state = handshakeResponseConsumed
|
||||||
|
|
||||||
handshake.mutex.Unlock()
|
handshake.mutex.Unlock()
|
||||||
|
|
||||||
|
@ -503,7 +524,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||||
var sendKey [chacha20poly1305.KeySize]byte
|
var sendKey [chacha20poly1305.KeySize]byte
|
||||||
var recvKey [chacha20poly1305.KeySize]byte
|
var recvKey [chacha20poly1305.KeySize]byte
|
||||||
|
|
||||||
if handshake.state == HandshakeResponseConsumed {
|
if handshake.state == handshakeResponseConsumed {
|
||||||
KDF2(
|
KDF2(
|
||||||
&sendKey,
|
&sendKey,
|
||||||
&recvKey,
|
&recvKey,
|
||||||
|
@ -511,7 +532,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
isInitiator = true
|
isInitiator = true
|
||||||
} else if handshake.state == HandshakeResponseCreated {
|
} else if handshake.state == handshakeResponseCreated {
|
||||||
KDF2(
|
KDF2(
|
||||||
&recvKey,
|
&recvKey,
|
||||||
&sendKey,
|
&sendKey,
|
||||||
|
@ -520,7 +541,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||||
)
|
)
|
||||||
isInitiator = false
|
isInitiator = false
|
||||||
} else {
|
} else {
|
||||||
return errors.New("invalid state for keypair derivation")
|
return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
|
||||||
}
|
}
|
||||||
|
|
||||||
// zero handshake
|
// zero handshake
|
||||||
|
@ -528,7 +549,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||||
setZero(handshake.chainKey[:])
|
setZero(handshake.chainKey[:])
|
||||||
setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
|
setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
|
||||||
setZero(handshake.localEphemeral[:])
|
setZero(handshake.localEphemeral[:])
|
||||||
peer.handshake.state = HandshakeZeroed
|
peer.handshake.state = handshakeZeroed
|
||||||
|
|
||||||
// create AEAD instances
|
// create AEAD instances
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue