diff --git a/src/index.go b/src/index.go index 83a7e29..81f71e9 100644 --- a/src/index.go +++ b/src/index.go @@ -6,13 +6,15 @@ import ( ) /* Index=0 is reserved for unset indecies + * + * TODO: Rethink map[id] -> peer VS map[id] -> handshake and handshake peer * */ type IndexTable struct { mutex sync.RWMutex keypairs map[uint32]*KeyPair - handshakes map[uint32]*Handshake + handshakes map[uint32]*Peer } func randUint32() (uint32, error) { @@ -32,10 +34,10 @@ func (table *IndexTable) Init() { table.mutex.Lock() defer table.mutex.Unlock() table.keypairs = make(map[uint32]*KeyPair) - table.handshakes = make(map[uint32]*Handshake) + table.handshakes = make(map[uint32]*Peer) } -func (table *IndexTable) NewIndex(handshake *Handshake) (uint32, error) { +func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) { table.mutex.Lock() defer table.mutex.Unlock() for { @@ -60,11 +62,10 @@ func (table *IndexTable) NewIndex(handshake *Handshake) (uint32, error) { continue } - // update the index + // clean old index - delete(table.handshakes, handshake.localIndex) - handshake.localIndex = id - table.handshakes[id] = handshake + delete(table.handshakes, peer.handshake.localIndex) + table.handshakes[id] = peer return id, nil } } @@ -75,7 +76,7 @@ func (table *IndexTable) LookupKeyPair(id uint32) *KeyPair { return table.keypairs[id] } -func (table *IndexTable) LookupHandshake(id uint32) *Handshake { +func (table *IndexTable) LookupHandshake(id uint32) *Peer { table.mutex.RLock() defer table.mutex.RUnlock() return table.handshakes[id] diff --git a/src/keypair.go b/src/keypair.go index 22a8244..e434c74 100644 --- a/src/keypair.go +++ b/src/keypair.go @@ -5,8 +5,8 @@ import ( ) type KeyPair struct { - recieveKey cipher.AEAD - recieveNonce NoiseNonce - sendKey cipher.AEAD - sendNonce NoiseNonce + recv cipher.AEAD + recvNonce NoiseNonce + send cipher.AEAD + sendNonce NoiseNonce } diff --git a/src/noise_helpers.go b/src/noise_helpers.go index eadbc07..e163ace 100644 --- a/src/noise_helpers.go +++ b/src/noise_helpers.go @@ -45,22 +45,7 @@ func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byt return } -/* - * - */ - -func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte { - return KDF1(c[:], data) -} - -func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte { - return blake2s.Sum256(append(h[:], data...)) -} - -/* Curve25519 wrappers - * - * TODO: Rethink this - */ +/* curve25519 wrappers */ func newPrivateKey() (sk NoisePrivateKey, err error) { // clamping: https://cr.yp.to/ecdh.html diff --git a/src/noise_protocol.go b/src/noise_protocol.go index b9c8981..7f26cf1 100644 --- a/src/noise_protocol.go +++ b/src/noise_protocol.go @@ -9,9 +9,11 @@ import ( ) const ( - HandshakeInitialCreated = iota + HandshakeReset = iota + HandshakeInitialCreated HandshakeInitialConsumed HandshakeResponseCreated + HandshakeResponseConsumed ) const ( @@ -71,7 +73,6 @@ type Handshake struct { } var ( - EmptyMessage []byte ZeroNonce [chacha20poly1305.NonceSize]byte InitalChainKey [blake2s.Size]byte InitalHash [blake2s.Size]byte @@ -82,6 +83,14 @@ func init() { InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...)) } +func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte { + return KDF1(c[:], data) +} + +func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte { + return blake2s.Sum256(append(h[:], data...)) +} + func (h *Handshake) addToHash(data []byte) { h.hash = addToHash(h.hash, data) } @@ -90,11 +99,6 @@ func (h *Handshake) addToChainKey(data []byte) { h.chainKey = addToChainKey(h.chainKey, data) } -func (device *Device) Precompute(peer *Peer) { - h := &peer.handshake - h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic) -} - func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) { handshake := &peer.handshake handshake.mutex.Lock() @@ -116,16 +120,17 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) { msg.Type = MessageInitalType msg.Ephemeral = handshake.localEphemeral.publicKey() - msg.Sender, err = device.indices.NewIndex(handshake) + handshake.localIndex, err = device.indices.NewIndex(peer) if err != nil { return nil, err } + msg.Sender = handshake.localIndex handshake.addToChainKey(msg.Ephemeral[:]) handshake.addToHash(msg.Ephemeral[:]) - // encrypt long-term "identity key" + // encrypt identity key func() { var key [chacha20poly1305.KeySize]byte @@ -221,6 +226,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer { handshake.chainKey = chainKey handshake.remoteIndex = msg.Sender handshake.remoteEphemeral = msg.Ephemeral + handshake.lastTimestamp = timestamp handshake.state = HandshakeInitialConsumed return peer } @@ -237,14 +243,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error // assign index var err error - var msg MessageResponse - msg.Type = MessageResponseType - msg.Sender, err = device.indices.NewIndex(handshake) - msg.Reciever = handshake.remoteIndex + handshake.localIndex, err = device.indices.NewIndex(peer) if err != nil { return nil, err } + var msg MessageResponse + msg.Type = MessageResponseType + msg.Sender = handshake.localIndex + msg.Reciever = handshake.remoteIndex + // create ephemeral key handshake.localEphemeral, err = newPrivateKey() @@ -252,6 +260,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error return nil, err } msg.Ephemeral = handshake.localEphemeral.publicKey() + handshake.addToHash(msg.Ephemeral[:]) func() { ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) @@ -269,9 +278,97 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error func() { aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Empty[:0], ZeroNonce[:], EmptyMessage, handshake.hash[:]) + aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) handshake.addToHash(msg.Empty[:]) }() + handshake.state = HandshakeResponseCreated return &msg, nil } + +func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { + if msg.Type != MessageResponseType { + panic(errors.New("bug: invalid message type")) + } + + // lookup handshake by reciever + + peer := device.indices.LookupHandshake(msg.Reciever) + if peer == nil { + return nil + } + handshake := &peer.handshake + handshake.mutex.Lock() + defer handshake.mutex.Unlock() + if handshake.state != HandshakeInitialCreated { + return nil + } + + // finish 3-way DH + + hash := addToHash(handshake.hash, msg.Ephemeral[:]) + chainKey := handshake.chainKey + + func() { + ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) + chainKey = addToChainKey(chainKey, ss[:]) + ss = device.privateKey.sharedSecret(msg.Ephemeral) + chainKey = addToChainKey(chainKey, ss[:]) + }() + + // add preshared key (psk) + + var tau [blake2s.Size]byte + var key [chacha20poly1305.KeySize]byte + chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:]) + hash = addToHash(hash, tau[:]) + + // authenticate + + aead, _ := chacha20poly1305.New(key[:]) + _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) + if err != nil { + return nil + } + hash = addToHash(hash, msg.Empty[:]) + + // update handshake state + + handshake.hash = hash + handshake.chainKey = chainKey + handshake.remoteIndex = msg.Sender + handshake.state = HandshakeResponseConsumed + + return peer +} + +func (peer *Peer) NewKeyPair() *KeyPair { + handshake := &peer.handshake + handshake.mutex.Lock() + defer handshake.mutex.Unlock() + + // derive keys + + var sendKey [chacha20poly1305.KeySize]byte + var recvKey [chacha20poly1305.KeySize]byte + + if handshake.state == HandshakeResponseConsumed { + sendKey, recvKey = KDF2(handshake.chainKey[:], nil) + } else if handshake.state == HandshakeResponseCreated { + recvKey, sendKey = KDF2(handshake.chainKey[:], nil) + } else { + return nil + } + + // create AEAD instances + + var keyPair KeyPair + keyPair.send, _ = chacha20poly1305.New(sendKey[:]) + keyPair.recv, _ = chacha20poly1305.New(recvKey[:]) + keyPair.sendNonce = 0 + keyPair.recvNonce = 0 + + peer.handshake.state = HandshakeReset + + return &keyPair +} diff --git a/src/noise_test.go b/src/noise_test.go index 8d6a0fa..ddabf8e 100644 --- a/src/noise_test.go +++ b/src/noise_test.go @@ -63,7 +63,9 @@ func TestNoiseHandshake(t *testing.T) { /* simulate handshake */ - // Initiation message + // initiation message + + t.Log("exchange initiation message") msg1, err := dev1.CreateMessageInitial(peer2) assertNil(t, err) @@ -88,6 +90,68 @@ func TestNoiseHandshake(t *testing.T) { peer2.handshake.hash[:], ) - // Response message + // response message + t.Log("exchange response message") + + msg2, err := dev2.CreateMessageResponse(peer1) + assertNil(t, err) + + peer = dev1.ConsumeMessageResponse(msg2) + if peer == nil { + t.Fatal("handshake failed at response message") + } + + assertEqual( + t, + peer1.handshake.chainKey[:], + peer2.handshake.chainKey[:], + ) + + assertEqual( + t, + peer1.handshake.hash[:], + peer2.handshake.hash[:], + ) + + // key pairs + + t.Log("deriving keys") + + key1 := peer1.NewKeyPair() + key2 := peer2.NewKeyPair() + + if key1 == nil { + t.Fatal("failed to dervice key-pair for peer 1") + } + + if key2 == nil { + t.Fatal("failed to dervice key-pair for peer 2") + } + + // encrypting / decryption test + + t.Log("test key pairs") + + func() { + testMsg := []byte("wireguard test message 1") + var err error + var out []byte + var nonce [12]byte + out = key1.send.Seal(out, nonce[:], testMsg, nil) + out, err = key2.recv.Open(out[:0], nonce[:], out, nil) + assertNil(t, err) + assertEqual(t, out, testMsg) + }() + + func() { + testMsg := []byte("wireguard test message 2") + var err error + var out []byte + var nonce [12]byte + out = key2.send.Seal(out, nonce[:], testMsg, nil) + out, err = key1.recv.Open(out[:0], nonce[:], out, nil) + assertNil(t, err) + assertEqual(t, out, testMsg) + }() }