device: give handshake state a type
And unexport handshake constants. Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
This commit is contained in:
		
							parent
							
								
									1a1c3d0968
								
							
						
					
					
						commit
						de374bfb44
					
				
					 1 changed files with 38 additions and 17 deletions
				
			
		|  | @ -7,6 +7,7 @@ package device | |||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| 
 | ||||
|  | @ -16,14 +17,34 @@ import ( | |||
| 	"golang.zx2c4.com/wireguard/tai64n" | ||||
| ) | ||||
| 
 | ||||
| type handshakeState int | ||||
| 
 | ||||
| // TODO(crawshaw): add commentary describing each state and the transitions
 | ||||
| const ( | ||||
| 	HandshakeZeroed = iota | ||||
| 	HandshakeInitiationCreated | ||||
| 	HandshakeInitiationConsumed | ||||
| 	HandshakeResponseCreated | ||||
| 	HandshakeResponseConsumed | ||||
| 	handshakeZeroed = handshakeState(iota) | ||||
| 	handshakeInitiationCreated | ||||
| 	handshakeInitiationConsumed | ||||
| 	handshakeResponseCreated | ||||
| 	handshakeResponseConsumed | ||||
| ) | ||||
| 
 | ||||
| func (hs handshakeState) String() string { | ||||
| 	switch hs { | ||||
| 	case handshakeZeroed: | ||||
| 		return "handshakeZeroed" | ||||
| 	case handshakeInitiationCreated: | ||||
| 		return "handshakeInitiationCreated" | ||||
| 	case handshakeInitiationConsumed: | ||||
| 		return "handshakeInitiationConsumed" | ||||
| 	case handshakeResponseCreated: | ||||
| 		return "handshakeResponseCreated" | ||||
| 	case handshakeResponseConsumed: | ||||
| 		return "handshakeResponseConsumed" | ||||
| 	default: | ||||
| 		return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| const ( | ||||
| 	NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" | ||||
| 	WGIdentifier      = "WireGuard v1 zx2c4 Jason@zx2c4.com" | ||||
|  | @ -95,7 +116,7 @@ type MessageCookieReply struct { | |||
| } | ||||
| 
 | ||||
| type Handshake struct { | ||||
| 	state                     int | ||||
| 	state                     handshakeState | ||||
| 	mutex                     sync.RWMutex | ||||
| 	hash                      [blake2s.Size]byte       // hash value
 | ||||
| 	chainKey                  [blake2s.Size]byte       // chain key
 | ||||
|  | @ -135,7 +156,7 @@ func (h *Handshake) Clear() { | |||
| 	setZero(h.chainKey[:]) | ||||
| 	setZero(h.hash[:]) | ||||
| 	h.localIndex = 0 | ||||
| 	h.state = HandshakeZeroed | ||||
| 	h.state = handshakeZeroed | ||||
| } | ||||
| 
 | ||||
| func (h *Handshake) mixHash(data []byte) { | ||||
|  | @ -221,7 +242,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e | |||
| 	handshake.localIndex = msg.Sender | ||||
| 
 | ||||
| 	handshake.mixHash(msg.Timestamp[:]) | ||||
| 	handshake.state = HandshakeInitiationCreated | ||||
| 	handshake.state = handshakeInitiationCreated | ||||
| 	return &msg, nil | ||||
| } | ||||
| 
 | ||||
|  | @ -316,7 +337,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { | |||
| 	if now.After(handshake.lastInitiationConsumption) { | ||||
| 		handshake.lastInitiationConsumption = now | ||||
| 	} | ||||
| 	handshake.state = HandshakeInitiationConsumed | ||||
| 	handshake.state = handshakeInitiationConsumed | ||||
| 
 | ||||
| 	handshake.mutex.Unlock() | ||||
| 
 | ||||
|  | @ -331,7 +352,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error | |||
| 	handshake.mutex.Lock() | ||||
| 	defer handshake.mutex.Unlock() | ||||
| 
 | ||||
| 	if handshake.state != HandshakeInitiationConsumed { | ||||
| 	if handshake.state != handshakeInitiationConsumed { | ||||
| 		return nil, errors.New("handshake initiation must be consumed first") | ||||
| 	} | ||||
| 
 | ||||
|  | @ -387,7 +408,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error | |||
| 		handshake.mixHash(msg.Empty[:]) | ||||
| 	}() | ||||
| 
 | ||||
| 	handshake.state = HandshakeResponseCreated | ||||
| 	handshake.state = handshakeResponseCreated | ||||
| 
 | ||||
| 	return &msg, nil | ||||
| } | ||||
|  | @ -417,7 +438,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { | |||
| 		handshake.mutex.RLock() | ||||
| 		defer handshake.mutex.RUnlock() | ||||
| 
 | ||||
| 		if handshake.state != HandshakeInitiationCreated { | ||||
| 		if handshake.state != handshakeInitiationCreated { | ||||
| 			return false | ||||
| 		} | ||||
| 
 | ||||
|  | @ -478,7 +499,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { | |||
| 	handshake.hash = hash | ||||
| 	handshake.chainKey = chainKey | ||||
| 	handshake.remoteIndex = msg.Sender | ||||
| 	handshake.state = HandshakeResponseConsumed | ||||
| 	handshake.state = handshakeResponseConsumed | ||||
| 
 | ||||
| 	handshake.mutex.Unlock() | ||||
| 
 | ||||
|  | @ -503,7 +524,7 @@ func (peer *Peer) BeginSymmetricSession() error { | |||
| 	var sendKey [chacha20poly1305.KeySize]byte | ||||
| 	var recvKey [chacha20poly1305.KeySize]byte | ||||
| 
 | ||||
| 	if handshake.state == HandshakeResponseConsumed { | ||||
| 	if handshake.state == handshakeResponseConsumed { | ||||
| 		KDF2( | ||||
| 			&sendKey, | ||||
| 			&recvKey, | ||||
|  | @ -511,7 +532,7 @@ func (peer *Peer) BeginSymmetricSession() error { | |||
| 			nil, | ||||
| 		) | ||||
| 		isInitiator = true | ||||
| 	} else if handshake.state == HandshakeResponseCreated { | ||||
| 	} else if handshake.state == handshakeResponseCreated { | ||||
| 		KDF2( | ||||
| 			&recvKey, | ||||
| 			&sendKey, | ||||
|  | @ -520,7 +541,7 @@ func (peer *Peer) BeginSymmetricSession() error { | |||
| 		) | ||||
| 		isInitiator = false | ||||
| 	} else { | ||||
| 		return errors.New("invalid state for keypair derivation") | ||||
| 		return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state) | ||||
| 	} | ||||
| 
 | ||||
| 	// zero handshake
 | ||||
|  | @ -528,7 +549,7 @@ func (peer *Peer) BeginSymmetricSession() error { | |||
| 	setZero(handshake.chainKey[:]) | ||||
| 	setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
 | ||||
| 	setZero(handshake.localEphemeral[:]) | ||||
| 	peer.handshake.state = HandshakeZeroed | ||||
| 	peer.handshake.state = handshakeZeroed | ||||
| 
 | ||||
| 	// create AEAD instances
 | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue