device: give handshake state a type
And unexport handshake constants. Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
This commit is contained in:
		
							parent
							
								
									1a1c3d0968
								
							
						
					
					
						commit
						de374bfb44
					
				
					 1 changed files with 38 additions and 17 deletions
				
			
		|  | @ -7,6 +7,7 @@ package device | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
|  | 	"fmt" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | @ -16,14 +17,34 @@ import ( | ||||||
| 	"golang.zx2c4.com/wireguard/tai64n" | 	"golang.zx2c4.com/wireguard/tai64n" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | type handshakeState int | ||||||
|  | 
 | ||||||
|  | // TODO(crawshaw): add commentary describing each state and the transitions
 | ||||||
| const ( | const ( | ||||||
| 	HandshakeZeroed = iota | 	handshakeZeroed = handshakeState(iota) | ||||||
| 	HandshakeInitiationCreated | 	handshakeInitiationCreated | ||||||
| 	HandshakeInitiationConsumed | 	handshakeInitiationConsumed | ||||||
| 	HandshakeResponseCreated | 	handshakeResponseCreated | ||||||
| 	HandshakeResponseConsumed | 	handshakeResponseConsumed | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | func (hs handshakeState) String() string { | ||||||
|  | 	switch hs { | ||||||
|  | 	case handshakeZeroed: | ||||||
|  | 		return "handshakeZeroed" | ||||||
|  | 	case handshakeInitiationCreated: | ||||||
|  | 		return "handshakeInitiationCreated" | ||||||
|  | 	case handshakeInitiationConsumed: | ||||||
|  | 		return "handshakeInitiationConsumed" | ||||||
|  | 	case handshakeResponseCreated: | ||||||
|  | 		return "handshakeResponseCreated" | ||||||
|  | 	case handshakeResponseConsumed: | ||||||
|  | 		return "handshakeResponseConsumed" | ||||||
|  | 	default: | ||||||
|  | 		return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs)) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| const ( | const ( | ||||||
| 	NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" | 	NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" | ||||||
| 	WGIdentifier      = "WireGuard v1 zx2c4 Jason@zx2c4.com" | 	WGIdentifier      = "WireGuard v1 zx2c4 Jason@zx2c4.com" | ||||||
|  | @ -95,7 +116,7 @@ type MessageCookieReply struct { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type Handshake struct { | type Handshake struct { | ||||||
| 	state                     int | 	state                     handshakeState | ||||||
| 	mutex                     sync.RWMutex | 	mutex                     sync.RWMutex | ||||||
| 	hash                      [blake2s.Size]byte       // hash value
 | 	hash                      [blake2s.Size]byte       // hash value
 | ||||||
| 	chainKey                  [blake2s.Size]byte       // chain key
 | 	chainKey                  [blake2s.Size]byte       // chain key
 | ||||||
|  | @ -135,7 +156,7 @@ func (h *Handshake) Clear() { | ||||||
| 	setZero(h.chainKey[:]) | 	setZero(h.chainKey[:]) | ||||||
| 	setZero(h.hash[:]) | 	setZero(h.hash[:]) | ||||||
| 	h.localIndex = 0 | 	h.localIndex = 0 | ||||||
| 	h.state = HandshakeZeroed | 	h.state = handshakeZeroed | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (h *Handshake) mixHash(data []byte) { | func (h *Handshake) mixHash(data []byte) { | ||||||
|  | @ -221,7 +242,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e | ||||||
| 	handshake.localIndex = msg.Sender | 	handshake.localIndex = msg.Sender | ||||||
| 
 | 
 | ||||||
| 	handshake.mixHash(msg.Timestamp[:]) | 	handshake.mixHash(msg.Timestamp[:]) | ||||||
| 	handshake.state = HandshakeInitiationCreated | 	handshake.state = handshakeInitiationCreated | ||||||
| 	return &msg, nil | 	return &msg, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -316,7 +337,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { | ||||||
| 	if now.After(handshake.lastInitiationConsumption) { | 	if now.After(handshake.lastInitiationConsumption) { | ||||||
| 		handshake.lastInitiationConsumption = now | 		handshake.lastInitiationConsumption = now | ||||||
| 	} | 	} | ||||||
| 	handshake.state = HandshakeInitiationConsumed | 	handshake.state = handshakeInitiationConsumed | ||||||
| 
 | 
 | ||||||
| 	handshake.mutex.Unlock() | 	handshake.mutex.Unlock() | ||||||
| 
 | 
 | ||||||
|  | @ -331,7 +352,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error | ||||||
| 	handshake.mutex.Lock() | 	handshake.mutex.Lock() | ||||||
| 	defer handshake.mutex.Unlock() | 	defer handshake.mutex.Unlock() | ||||||
| 
 | 
 | ||||||
| 	if handshake.state != HandshakeInitiationConsumed { | 	if handshake.state != handshakeInitiationConsumed { | ||||||
| 		return nil, errors.New("handshake initiation must be consumed first") | 		return nil, errors.New("handshake initiation must be consumed first") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -387,7 +408,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error | ||||||
| 		handshake.mixHash(msg.Empty[:]) | 		handshake.mixHash(msg.Empty[:]) | ||||||
| 	}() | 	}() | ||||||
| 
 | 
 | ||||||
| 	handshake.state = HandshakeResponseCreated | 	handshake.state = handshakeResponseCreated | ||||||
| 
 | 
 | ||||||
| 	return &msg, nil | 	return &msg, nil | ||||||
| } | } | ||||||
|  | @ -417,7 +438,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { | ||||||
| 		handshake.mutex.RLock() | 		handshake.mutex.RLock() | ||||||
| 		defer handshake.mutex.RUnlock() | 		defer handshake.mutex.RUnlock() | ||||||
| 
 | 
 | ||||||
| 		if handshake.state != HandshakeInitiationCreated { | 		if handshake.state != handshakeInitiationCreated { | ||||||
| 			return false | 			return false | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | @ -478,7 +499,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { | ||||||
| 	handshake.hash = hash | 	handshake.hash = hash | ||||||
| 	handshake.chainKey = chainKey | 	handshake.chainKey = chainKey | ||||||
| 	handshake.remoteIndex = msg.Sender | 	handshake.remoteIndex = msg.Sender | ||||||
| 	handshake.state = HandshakeResponseConsumed | 	handshake.state = handshakeResponseConsumed | ||||||
| 
 | 
 | ||||||
| 	handshake.mutex.Unlock() | 	handshake.mutex.Unlock() | ||||||
| 
 | 
 | ||||||
|  | @ -503,7 +524,7 @@ func (peer *Peer) BeginSymmetricSession() error { | ||||||
| 	var sendKey [chacha20poly1305.KeySize]byte | 	var sendKey [chacha20poly1305.KeySize]byte | ||||||
| 	var recvKey [chacha20poly1305.KeySize]byte | 	var recvKey [chacha20poly1305.KeySize]byte | ||||||
| 
 | 
 | ||||||
| 	if handshake.state == HandshakeResponseConsumed { | 	if handshake.state == handshakeResponseConsumed { | ||||||
| 		KDF2( | 		KDF2( | ||||||
| 			&sendKey, | 			&sendKey, | ||||||
| 			&recvKey, | 			&recvKey, | ||||||
|  | @ -511,7 +532,7 @@ func (peer *Peer) BeginSymmetricSession() error { | ||||||
| 			nil, | 			nil, | ||||||
| 		) | 		) | ||||||
| 		isInitiator = true | 		isInitiator = true | ||||||
| 	} else if handshake.state == HandshakeResponseCreated { | 	} else if handshake.state == handshakeResponseCreated { | ||||||
| 		KDF2( | 		KDF2( | ||||||
| 			&recvKey, | 			&recvKey, | ||||||
| 			&sendKey, | 			&sendKey, | ||||||
|  | @ -520,7 +541,7 @@ func (peer *Peer) BeginSymmetricSession() error { | ||||||
| 		) | 		) | ||||||
| 		isInitiator = false | 		isInitiator = false | ||||||
| 	} else { | 	} else { | ||||||
| 		return errors.New("invalid state for keypair derivation") | 		return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// zero handshake
 | 	// zero handshake
 | ||||||
|  | @ -528,7 +549,7 @@ func (peer *Peer) BeginSymmetricSession() error { | ||||||
| 	setZero(handshake.chainKey[:]) | 	setZero(handshake.chainKey[:]) | ||||||
| 	setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
 | 	setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
 | ||||||
| 	setZero(handshake.localEphemeral[:]) | 	setZero(handshake.localEphemeral[:]) | ||||||
| 	peer.handshake.state = HandshakeZeroed | 	peer.handshake.state = handshakeZeroed | ||||||
| 
 | 
 | ||||||
| 	// create AEAD instances
 | 	// create AEAD instances
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in a new issue