Completed noise handshake

This commit is contained in:
Mathias Hall-Andersen 2017-06-24 22:03:52 +02:00
parent 25190e4336
commit cf3a5130d3
5 changed files with 191 additions and 44 deletions

View file

@ -6,13 +6,15 @@ import (
) )
/* Index=0 is reserved for unset indecies /* Index=0 is reserved for unset indecies
*
* TODO: Rethink map[id] -> peer VS map[id] -> handshake and handshake <ref> peer
* *
*/ */
type IndexTable struct { type IndexTable struct {
mutex sync.RWMutex mutex sync.RWMutex
keypairs map[uint32]*KeyPair keypairs map[uint32]*KeyPair
handshakes map[uint32]*Handshake handshakes map[uint32]*Peer
} }
func randUint32() (uint32, error) { func randUint32() (uint32, error) {
@ -32,10 +34,10 @@ func (table *IndexTable) Init() {
table.mutex.Lock() table.mutex.Lock()
defer table.mutex.Unlock() defer table.mutex.Unlock()
table.keypairs = make(map[uint32]*KeyPair) table.keypairs = make(map[uint32]*KeyPair)
table.handshakes = make(map[uint32]*Handshake) table.handshakes = make(map[uint32]*Peer)
} }
func (table *IndexTable) NewIndex(handshake *Handshake) (uint32, error) { func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
table.mutex.Lock() table.mutex.Lock()
defer table.mutex.Unlock() defer table.mutex.Unlock()
for { for {
@ -60,11 +62,10 @@ func (table *IndexTable) NewIndex(handshake *Handshake) (uint32, error) {
continue continue
} }
// update the index // clean old index
delete(table.handshakes, handshake.localIndex) delete(table.handshakes, peer.handshake.localIndex)
handshake.localIndex = id table.handshakes[id] = peer
table.handshakes[id] = handshake
return id, nil return id, nil
} }
} }
@ -75,7 +76,7 @@ func (table *IndexTable) LookupKeyPair(id uint32) *KeyPair {
return table.keypairs[id] return table.keypairs[id]
} }
func (table *IndexTable) LookupHandshake(id uint32) *Handshake { func (table *IndexTable) LookupHandshake(id uint32) *Peer {
table.mutex.RLock() table.mutex.RLock()
defer table.mutex.RUnlock() defer table.mutex.RUnlock()
return table.handshakes[id] return table.handshakes[id]

View file

@ -5,8 +5,8 @@ import (
) )
type KeyPair struct { type KeyPair struct {
recieveKey cipher.AEAD recv cipher.AEAD
recieveNonce NoiseNonce recvNonce NoiseNonce
sendKey cipher.AEAD send cipher.AEAD
sendNonce NoiseNonce sendNonce NoiseNonce
} }

View file

@ -45,22 +45,7 @@ func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byt
return return
} }
/* /* curve25519 wrappers */
*
*/
func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
return KDF1(c[:], data)
}
func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
return blake2s.Sum256(append(h[:], data...))
}
/* Curve25519 wrappers
*
* TODO: Rethink this
*/
func newPrivateKey() (sk NoisePrivateKey, err error) { func newPrivateKey() (sk NoisePrivateKey, err error) {
// clamping: https://cr.yp.to/ecdh.html // clamping: https://cr.yp.to/ecdh.html

View file

@ -9,9 +9,11 @@ import (
) )
const ( const (
HandshakeInitialCreated = iota HandshakeReset = iota
HandshakeInitialCreated
HandshakeInitialConsumed HandshakeInitialConsumed
HandshakeResponseCreated HandshakeResponseCreated
HandshakeResponseConsumed
) )
const ( const (
@ -71,7 +73,6 @@ type Handshake struct {
} }
var ( var (
EmptyMessage []byte
ZeroNonce [chacha20poly1305.NonceSize]byte ZeroNonce [chacha20poly1305.NonceSize]byte
InitalChainKey [blake2s.Size]byte InitalChainKey [blake2s.Size]byte
InitalHash [blake2s.Size]byte InitalHash [blake2s.Size]byte
@ -82,6 +83,14 @@ func init() {
InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...)) InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
} }
func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
return KDF1(c[:], data)
}
func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
return blake2s.Sum256(append(h[:], data...))
}
func (h *Handshake) addToHash(data []byte) { func (h *Handshake) addToHash(data []byte) {
h.hash = addToHash(h.hash, data) h.hash = addToHash(h.hash, data)
} }
@ -90,11 +99,6 @@ func (h *Handshake) addToChainKey(data []byte) {
h.chainKey = addToChainKey(h.chainKey, data) h.chainKey = addToChainKey(h.chainKey, data)
} }
func (device *Device) Precompute(peer *Peer) {
h := &peer.handshake
h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
}
func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) { func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
@ -116,16 +120,17 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
msg.Type = MessageInitalType msg.Type = MessageInitalType
msg.Ephemeral = handshake.localEphemeral.publicKey() msg.Ephemeral = handshake.localEphemeral.publicKey()
msg.Sender, err = device.indices.NewIndex(handshake) handshake.localIndex, err = device.indices.NewIndex(peer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
msg.Sender = handshake.localIndex
handshake.addToChainKey(msg.Ephemeral[:]) handshake.addToChainKey(msg.Ephemeral[:])
handshake.addToHash(msg.Ephemeral[:]) handshake.addToHash(msg.Ephemeral[:])
// encrypt long-term "identity key" // encrypt identity key
func() { func() {
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
@ -221,6 +226,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
handshake.chainKey = chainKey handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender handshake.remoteIndex = msg.Sender
handshake.remoteEphemeral = msg.Ephemeral handshake.remoteEphemeral = msg.Ephemeral
handshake.lastTimestamp = timestamp
handshake.state = HandshakeInitialConsumed handshake.state = HandshakeInitialConsumed
return peer return peer
} }
@ -237,14 +243,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
// assign index // assign index
var err error var err error
var msg MessageResponse handshake.localIndex, err = device.indices.NewIndex(peer)
msg.Type = MessageResponseType
msg.Sender, err = device.indices.NewIndex(handshake)
msg.Reciever = handshake.remoteIndex
if err != nil { if err != nil {
return nil, err return nil, err
} }
var msg MessageResponse
msg.Type = MessageResponseType
msg.Sender = handshake.localIndex
msg.Reciever = handshake.remoteIndex
// create ephemeral key // create ephemeral key
handshake.localEphemeral, err = newPrivateKey() handshake.localEphemeral, err = newPrivateKey()
@ -252,6 +260,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
return nil, err return nil, err
} }
msg.Ephemeral = handshake.localEphemeral.publicKey() msg.Ephemeral = handshake.localEphemeral.publicKey()
handshake.addToHash(msg.Ephemeral[:])
func() { func() {
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
@ -269,9 +278,97 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
func() { func() {
aead, _ := chacha20poly1305.New(key[:]) aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Empty[:0], ZeroNonce[:], EmptyMessage, handshake.hash[:]) aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
handshake.addToHash(msg.Empty[:]) handshake.addToHash(msg.Empty[:])
}() }()
handshake.state = HandshakeResponseCreated
return &msg, nil return &msg, nil
} }
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
if msg.Type != MessageResponseType {
panic(errors.New("bug: invalid message type"))
}
// lookup handshake by reciever
peer := device.indices.LookupHandshake(msg.Reciever)
if peer == nil {
return nil
}
handshake := &peer.handshake
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
if handshake.state != HandshakeInitialCreated {
return nil
}
// finish 3-way DH
hash := addToHash(handshake.hash, msg.Ephemeral[:])
chainKey := handshake.chainKey
func() {
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
chainKey = addToChainKey(chainKey, ss[:])
ss = device.privateKey.sharedSecret(msg.Ephemeral)
chainKey = addToChainKey(chainKey, ss[:])
}()
// add preshared key (psk)
var tau [blake2s.Size]byte
var key [chacha20poly1305.KeySize]byte
chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
hash = addToHash(hash, tau[:])
// authenticate
aead, _ := chacha20poly1305.New(key[:])
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
if err != nil {
return nil
}
hash = addToHash(hash, msg.Empty[:])
// update handshake state
handshake.hash = hash
handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender
handshake.state = HandshakeResponseConsumed
return peer
}
func (peer *Peer) NewKeyPair() *KeyPair {
handshake := &peer.handshake
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
// derive keys
var sendKey [chacha20poly1305.KeySize]byte
var recvKey [chacha20poly1305.KeySize]byte
if handshake.state == HandshakeResponseConsumed {
sendKey, recvKey = KDF2(handshake.chainKey[:], nil)
} else if handshake.state == HandshakeResponseCreated {
recvKey, sendKey = KDF2(handshake.chainKey[:], nil)
} else {
return nil
}
// create AEAD instances
var keyPair KeyPair
keyPair.send, _ = chacha20poly1305.New(sendKey[:])
keyPair.recv, _ = chacha20poly1305.New(recvKey[:])
keyPair.sendNonce = 0
keyPair.recvNonce = 0
peer.handshake.state = HandshakeReset
return &keyPair
}

View file

@ -63,7 +63,9 @@ func TestNoiseHandshake(t *testing.T) {
/* simulate handshake */ /* simulate handshake */
// Initiation message // initiation message
t.Log("exchange initiation message")
msg1, err := dev1.CreateMessageInitial(peer2) msg1, err := dev1.CreateMessageInitial(peer2)
assertNil(t, err) assertNil(t, err)
@ -88,6 +90,68 @@ func TestNoiseHandshake(t *testing.T) {
peer2.handshake.hash[:], peer2.handshake.hash[:],
) )
// Response message // response message
t.Log("exchange response message")
msg2, err := dev2.CreateMessageResponse(peer1)
assertNil(t, err)
peer = dev1.ConsumeMessageResponse(msg2)
if peer == nil {
t.Fatal("handshake failed at response message")
}
assertEqual(
t,
peer1.handshake.chainKey[:],
peer2.handshake.chainKey[:],
)
assertEqual(
t,
peer1.handshake.hash[:],
peer2.handshake.hash[:],
)
// key pairs
t.Log("deriving keys")
key1 := peer1.NewKeyPair()
key2 := peer2.NewKeyPair()
if key1 == nil {
t.Fatal("failed to dervice key-pair for peer 1")
}
if key2 == nil {
t.Fatal("failed to dervice key-pair for peer 2")
}
// encrypting / decryption test
t.Log("test key pairs")
func() {
testMsg := []byte("wireguard test message 1")
var err error
var out []byte
var nonce [12]byte
out = key1.send.Seal(out, nonce[:], testMsg, nil)
out, err = key2.recv.Open(out[:0], nonce[:], out, nil)
assertNil(t, err)
assertEqual(t, out, testMsg)
}()
func() {
testMsg := []byte("wireguard test message 2")
var err error
var out []byte
var nonce [12]byte
out = key2.send.Seal(out, nonce[:], testMsg, nil)
out, err = key1.recv.Open(out[:0], nonce[:], out, nil)
assertNil(t, err)
assertEqual(t, out, testMsg)
}()
} }