diff --git a/device/device.go b/device/device.go index 0b909a7..8c08f1c 100644 --- a/device/device.go +++ b/device/device.go @@ -240,9 +240,6 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { for _, peer := range device.peers.keyMap { handshake := &peer.handshake handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) - if isZero(handshake.precomputedStaticStatic[:]) { - panic("an invalid peer public key made it into the configuration") - } expiredPeers = append(expiredPeers, peer) } diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 1c08e0a..ee327d2 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -154,6 +154,7 @@ func init() { } func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { + var errZeroECDHResult = errors.New("ECDH returned all zeros") device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() @@ -162,12 +163,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mutex.Lock() defer handshake.mutex.Unlock() - if isZero(handshake.precomputedStaticStatic[:]) { - return nil, errors.New("static shared secret is zero") - } - // create ephemeral key - var err error handshake.hash = InitialHash handshake.chainKey = InitialChainKey @@ -176,56 +172,53 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e return nil, err } - // assign index - - device.indexTable.Delete(handshake.localIndex) - handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake) - - if err != nil { - return nil, err - } - handshake.mixHash(handshake.remoteStatic[:]) msg := MessageInitiation{ Type: MessageInitiationType, Ephemeral: handshake.localEphemeral.publicKey(), - Sender: handshake.localIndex, } handshake.mixKey(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:]) // encrypt static key - - func() { - var key [chacha20poly1305.KeySize]byte - ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) - KDF2( - &handshake.chainKey, - &key, - handshake.chainKey[:], - ss[:], - ) - aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) - }() + ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + if isZero(ss[:]) { + return nil, errZeroECDHResult + } + var key [chacha20poly1305.KeySize]byte + KDF2( + &handshake.chainKey, + &key, + handshake.chainKey[:], + ss[:], + ) + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) handshake.mixHash(msg.Static[:]) // encrypt timestamp - + if isZero(handshake.precomputedStaticStatic[:]) { + return nil, errZeroECDHResult + } + KDF2( + &handshake.chainKey, + &key, + handshake.chainKey[:], + handshake.precomputedStaticStatic[:], + ) timestamp := tai64n.Now() - func() { - var key [chacha20poly1305.KeySize]byte - KDF2( - &handshake.chainKey, - &key, - handshake.chainKey[:], - handshake.precomputedStaticStatic[:], - ) - aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) - }() + aead, _ = chacha20poly1305.New(key[:]) + aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) + + // assign index + device.indexTable.Delete(handshake.localIndex) + msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake) + if err != nil { + return nil, err + } + handshake.localIndex = msg.Sender handshake.mixHash(msg.Timestamp[:]) handshake.state = HandshakeInitiationCreated @@ -250,16 +243,16 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) // decrypt static key - var err error var peerPK NoisePublicKey - func() { - var key [chacha20poly1305.KeySize]byte - ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) - KDF2(&chainKey, &key, chainKey[:], ss[:]) - aead, _ := chacha20poly1305.New(key[:]) - _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) - }() + var key [chacha20poly1305.KeySize]byte + ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) + if isZero(ss[:]) { + return nil + } + KDF2(&chainKey, &key, chainKey[:], ss[:]) + aead, _ := chacha20poly1305.New(key[:]) + _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) if err != nil { return nil } @@ -273,23 +266,24 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { } handshake := &peer.handshake - if isZero(handshake.precomputedStaticStatic[:]) { - return nil - } // verify identity var timestamp tai64n.Timestamp - var key [chacha20poly1305.KeySize]byte handshake.mutex.RLock() + + if isZero(handshake.precomputedStaticStatic[:]) { + handshake.mutex.RUnlock() + return nil + } KDF2( &chainKey, &key, chainKey[:], handshake.precomputedStaticStatic[:], ) - aead, _ := chacha20poly1305.New(key[:]) + aead, _ = chacha20poly1305.New(key[:]) _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) if err != nil { handshake.mutex.RUnlock() diff --git a/device/peer.go b/device/peer.go index 91d975a..8a8224c 100644 --- a/device/peer.go +++ b/device/peer.go @@ -108,7 +108,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { handshake := &peer.handshake handshake.mutex.Lock() handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk) - ssIsZero := isZero(handshake.precomputedStaticStatic[:]) handshake.remoteStatic = pk handshake.mutex.Unlock() @@ -116,13 +115,9 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { peer.endpoint = nil - // conditionally add + // add - if !ssIsZero { - device.peers.keyMap[pk] = peer - } else { - return nil, nil - } + device.peers.keyMap[pk] = peer // start peer