diff --git a/noise-protocol.go b/noise-protocol.go index 82d553e..f72dcc4 100644 --- a/noise-protocol.go +++ b/noise-protocol.go @@ -319,6 +319,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { handshake.mutex.Unlock() + setZero(hash[:]) + setZero(chainKey[:]) + return peer } @@ -362,7 +365,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error handshake.mixKey(ss[:]) }() - // add preshared key (psk) + // add preshared key var tau [blake2s.Size]byte var key [chacha20poly1305.KeySize]byte @@ -457,7 +460,6 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { aead, _ := chacha20poly1305.New(key[:]) _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) if err != nil { - device.log.Debug.Println("failed to open") return false } mixHash(&hash, &hash, msg.Empty[:]) @@ -485,10 +487,10 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { return lookup.peer } -/* Derives a new key-pair from the current handshake state +/* Derives a new keypair from the current handshake state * */ -func (peer *Peer) NewKeypair() *Keypair { +func (peer *Peer) DeriveNewKeypair() error { device := peer.device handshake := &peer.handshake handshake.mutex.Lock() @@ -517,12 +519,13 @@ func (peer *Peer) NewKeypair() *Keypair { ) isInitiator = false } else { - return nil + return errors.New("invalid state for keypair derivation") } // zero handshake setZero(handshake.chainKey[:]) + setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line. setZero(handshake.localEphemeral[:]) peer.handshake.state = HandshakeZeroed @@ -576,5 +579,23 @@ func (peer *Peer) NewKeypair() *Keypair { } kp.mutex.Unlock() - return keypair + return nil +} + +func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { + kp := &peer.keypairs + if kp.next != receivedKeypair { + return false + } + kp.mutex.Lock() + defer kp.mutex.Unlock() + if kp.next != receivedKeypair { + return false + } + old := kp.previous + kp.previous = kp.current + peer.device.DeleteKeypair(old) + kp.current = kp.next + kp.next = nil + return true } diff --git a/noise_test.go b/noise_test.go index 37bfb94..ce32097 100644 --- a/noise_test.go +++ b/noise_test.go @@ -102,15 +102,15 @@ func TestNoiseHandshake(t *testing.T) { t.Log("deriving keys") - key1 := peer1.NewKeypair() - key2 := peer2.NewKeypair() + key1 := peer1.DeriveNewKeypair() + key2 := peer2.DeriveNewKeypair() if key1 == nil { - t.Fatal("failed to dervice key-pair for peer 1") + t.Fatal("failed to dervice keypair for peer 1") } if key2 == nil { - t.Fatal("failed to dervice key-pair for peer 2") + t.Fatal("failed to dervice keypair for peer 2") } // encrypting / decryption test diff --git a/receive.go b/receive.go index 32ff512..64253e6 100644 --- a/receive.go +++ b/receive.go @@ -189,7 +189,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { continue } - // check key-pair expiry + // check keypair expiry if keypair.created.Add(RejectAfterTime).Before(time.Now()) { continue @@ -475,7 +475,7 @@ func (device *Device) RoutineHandshake() { continue } - if peer.NewKeypair() == nil { + if peer.DeriveNewKeypair() != nil { continue } @@ -532,9 +532,9 @@ func (device *Device) RoutineHandshake() { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketReceived() - // derive key-pair + // derive keypair - if peer.NewKeypair() == nil { + if peer.DeriveNewKeypair() != nil { continue } @@ -597,25 +597,12 @@ func (peer *Peer) RoutineSequentialReceiver() { peer.endpoint = elem.endpoint peer.mutex.Unlock() - // check if using new key-pair - - kp := &peer.keypairs - if kp.next == elem.keypair { - kp.mutex.Lock() - if kp.next != elem.keypair { - kp.mutex.Unlock() - } else { - old := kp.previous - kp.previous = kp.current - device.DeleteKeypair(old) - kp.current = kp.next - kp.next = nil - kp.mutex.Unlock() - peer.timersHandshakeComplete() - select { - case peer.signals.newKeypairArrived <- struct{}{}: - default: - } + // check if using new keypair + if peer.ReceivedWithKeypair(elem.keypair) { + peer.timersHandshakeComplete() + select { + case peer.signals.newKeypairArrived <- struct{}{}: + default: } } diff --git a/send.go b/send.go index 35e0d00..a8ec28c 100644 --- a/send.go +++ b/send.go @@ -47,7 +47,7 @@ type QueueOutboundElement struct { buffer *[MaxMessageSize]byte // slice holding the packet data packet []byte // slice of "buffer" (always!) nonce uint64 // nonce for encryption - keypair *Keypair // key-pair for encryption + keypair *Keypair // keypair for encryption peer *Peer // related peer } @@ -306,11 +306,11 @@ func (peer *Peer) RoutineNonce() { peer.SendHandshakeInitiation(false) - logDebug.Println(peer, ": Awaiting key-pair") + logDebug.Println(peer, ": Awaiting keypair") select { case <-peer.signals.newKeypairArrived: - logDebug.Println(peer, ": Obtained awaited key-pair") + logDebug.Println(peer, ": Obtained awaited keypair") case <-peer.signals.flushNonceQueue: for { select {