Completed noise handshake
This commit is contained in:
parent
25190e4336
commit
cf3a5130d3
17
src/index.go
17
src/index.go
|
@ -6,13 +6,15 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
/* Index=0 is reserved for unset indecies
|
/* Index=0 is reserved for unset indecies
|
||||||
|
*
|
||||||
|
* TODO: Rethink map[id] -> peer VS map[id] -> handshake and handshake <ref> peer
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
|
||||||
type IndexTable struct {
|
type IndexTable struct {
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
keypairs map[uint32]*KeyPair
|
keypairs map[uint32]*KeyPair
|
||||||
handshakes map[uint32]*Handshake
|
handshakes map[uint32]*Peer
|
||||||
}
|
}
|
||||||
|
|
||||||
func randUint32() (uint32, error) {
|
func randUint32() (uint32, error) {
|
||||||
|
@ -32,10 +34,10 @@ func (table *IndexTable) Init() {
|
||||||
table.mutex.Lock()
|
table.mutex.Lock()
|
||||||
defer table.mutex.Unlock()
|
defer table.mutex.Unlock()
|
||||||
table.keypairs = make(map[uint32]*KeyPair)
|
table.keypairs = make(map[uint32]*KeyPair)
|
||||||
table.handshakes = make(map[uint32]*Handshake)
|
table.handshakes = make(map[uint32]*Peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *IndexTable) NewIndex(handshake *Handshake) (uint32, error) {
|
func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
|
||||||
table.mutex.Lock()
|
table.mutex.Lock()
|
||||||
defer table.mutex.Unlock()
|
defer table.mutex.Unlock()
|
||||||
for {
|
for {
|
||||||
|
@ -60,11 +62,10 @@ func (table *IndexTable) NewIndex(handshake *Handshake) (uint32, error) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// update the index
|
// clean old index
|
||||||
|
|
||||||
delete(table.handshakes, handshake.localIndex)
|
delete(table.handshakes, peer.handshake.localIndex)
|
||||||
handshake.localIndex = id
|
table.handshakes[id] = peer
|
||||||
table.handshakes[id] = handshake
|
|
||||||
return id, nil
|
return id, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -75,7 +76,7 @@ func (table *IndexTable) LookupKeyPair(id uint32) *KeyPair {
|
||||||
return table.keypairs[id]
|
return table.keypairs[id]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *IndexTable) LookupHandshake(id uint32) *Handshake {
|
func (table *IndexTable) LookupHandshake(id uint32) *Peer {
|
||||||
table.mutex.RLock()
|
table.mutex.RLock()
|
||||||
defer table.mutex.RUnlock()
|
defer table.mutex.RUnlock()
|
||||||
return table.handshakes[id]
|
return table.handshakes[id]
|
||||||
|
|
|
@ -5,8 +5,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type KeyPair struct {
|
type KeyPair struct {
|
||||||
recieveKey cipher.AEAD
|
recv cipher.AEAD
|
||||||
recieveNonce NoiseNonce
|
recvNonce NoiseNonce
|
||||||
sendKey cipher.AEAD
|
send cipher.AEAD
|
||||||
sendNonce NoiseNonce
|
sendNonce NoiseNonce
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,22 +45,7 @@ func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byt
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/* curve25519 wrappers */
|
||||||
*
|
|
||||||
*/
|
|
||||||
|
|
||||||
func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
|
|
||||||
return KDF1(c[:], data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
|
|
||||||
return blake2s.Sum256(append(h[:], data...))
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Curve25519 wrappers
|
|
||||||
*
|
|
||||||
* TODO: Rethink this
|
|
||||||
*/
|
|
||||||
|
|
||||||
func newPrivateKey() (sk NoisePrivateKey, err error) {
|
func newPrivateKey() (sk NoisePrivateKey, err error) {
|
||||||
// clamping: https://cr.yp.to/ecdh.html
|
// clamping: https://cr.yp.to/ecdh.html
|
||||||
|
|
|
@ -9,9 +9,11 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
HandshakeInitialCreated = iota
|
HandshakeReset = iota
|
||||||
|
HandshakeInitialCreated
|
||||||
HandshakeInitialConsumed
|
HandshakeInitialConsumed
|
||||||
HandshakeResponseCreated
|
HandshakeResponseCreated
|
||||||
|
HandshakeResponseConsumed
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -71,7 +73,6 @@ type Handshake struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
EmptyMessage []byte
|
|
||||||
ZeroNonce [chacha20poly1305.NonceSize]byte
|
ZeroNonce [chacha20poly1305.NonceSize]byte
|
||||||
InitalChainKey [blake2s.Size]byte
|
InitalChainKey [blake2s.Size]byte
|
||||||
InitalHash [blake2s.Size]byte
|
InitalHash [blake2s.Size]byte
|
||||||
|
@ -82,6 +83,14 @@ func init() {
|
||||||
InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
|
InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
|
||||||
|
return KDF1(c[:], data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
|
||||||
|
return blake2s.Sum256(append(h[:], data...))
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Handshake) addToHash(data []byte) {
|
func (h *Handshake) addToHash(data []byte) {
|
||||||
h.hash = addToHash(h.hash, data)
|
h.hash = addToHash(h.hash, data)
|
||||||
}
|
}
|
||||||
|
@ -90,11 +99,6 @@ func (h *Handshake) addToChainKey(data []byte) {
|
||||||
h.chainKey = addToChainKey(h.chainKey, data)
|
h.chainKey = addToChainKey(h.chainKey, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) Precompute(peer *Peer) {
|
|
||||||
h := &peer.handshake
|
|
||||||
h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
|
func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
|
||||||
handshake := &peer.handshake
|
handshake := &peer.handshake
|
||||||
handshake.mutex.Lock()
|
handshake.mutex.Lock()
|
||||||
|
@ -116,16 +120,17 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
|
||||||
|
|
||||||
msg.Type = MessageInitalType
|
msg.Type = MessageInitalType
|
||||||
msg.Ephemeral = handshake.localEphemeral.publicKey()
|
msg.Ephemeral = handshake.localEphemeral.publicKey()
|
||||||
msg.Sender, err = device.indices.NewIndex(handshake)
|
handshake.localIndex, err = device.indices.NewIndex(peer)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
msg.Sender = handshake.localIndex
|
||||||
handshake.addToChainKey(msg.Ephemeral[:])
|
handshake.addToChainKey(msg.Ephemeral[:])
|
||||||
handshake.addToHash(msg.Ephemeral[:])
|
handshake.addToHash(msg.Ephemeral[:])
|
||||||
|
|
||||||
// encrypt long-term "identity key"
|
// encrypt identity key
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
var key [chacha20poly1305.KeySize]byte
|
var key [chacha20poly1305.KeySize]byte
|
||||||
|
@ -221,6 +226,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
|
||||||
handshake.chainKey = chainKey
|
handshake.chainKey = chainKey
|
||||||
handshake.remoteIndex = msg.Sender
|
handshake.remoteIndex = msg.Sender
|
||||||
handshake.remoteEphemeral = msg.Ephemeral
|
handshake.remoteEphemeral = msg.Ephemeral
|
||||||
|
handshake.lastTimestamp = timestamp
|
||||||
handshake.state = HandshakeInitialConsumed
|
handshake.state = HandshakeInitialConsumed
|
||||||
return peer
|
return peer
|
||||||
}
|
}
|
||||||
|
@ -237,14 +243,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
||||||
// assign index
|
// assign index
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
var msg MessageResponse
|
handshake.localIndex, err = device.indices.NewIndex(peer)
|
||||||
msg.Type = MessageResponseType
|
|
||||||
msg.Sender, err = device.indices.NewIndex(handshake)
|
|
||||||
msg.Reciever = handshake.remoteIndex
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var msg MessageResponse
|
||||||
|
msg.Type = MessageResponseType
|
||||||
|
msg.Sender = handshake.localIndex
|
||||||
|
msg.Reciever = handshake.remoteIndex
|
||||||
|
|
||||||
// create ephemeral key
|
// create ephemeral key
|
||||||
|
|
||||||
handshake.localEphemeral, err = newPrivateKey()
|
handshake.localEphemeral, err = newPrivateKey()
|
||||||
|
@ -252,6 +260,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
msg.Ephemeral = handshake.localEphemeral.publicKey()
|
msg.Ephemeral = handshake.localEphemeral.publicKey()
|
||||||
|
handshake.addToHash(msg.Ephemeral[:])
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
|
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
|
||||||
|
@ -269,9 +278,97 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
aead, _ := chacha20poly1305.New(key[:])
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
aead.Seal(msg.Empty[:0], ZeroNonce[:], EmptyMessage, handshake.hash[:])
|
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
|
||||||
handshake.addToHash(msg.Empty[:])
|
handshake.addToHash(msg.Empty[:])
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
handshake.state = HandshakeResponseCreated
|
||||||
return &msg, nil
|
return &msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||||
|
if msg.Type != MessageResponseType {
|
||||||
|
panic(errors.New("bug: invalid message type"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookup handshake by reciever
|
||||||
|
|
||||||
|
peer := device.indices.LookupHandshake(msg.Reciever)
|
||||||
|
if peer == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
handshake := &peer.handshake
|
||||||
|
handshake.mutex.Lock()
|
||||||
|
defer handshake.mutex.Unlock()
|
||||||
|
if handshake.state != HandshakeInitialCreated {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// finish 3-way DH
|
||||||
|
|
||||||
|
hash := addToHash(handshake.hash, msg.Ephemeral[:])
|
||||||
|
chainKey := handshake.chainKey
|
||||||
|
|
||||||
|
func() {
|
||||||
|
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
|
||||||
|
chainKey = addToChainKey(chainKey, ss[:])
|
||||||
|
ss = device.privateKey.sharedSecret(msg.Ephemeral)
|
||||||
|
chainKey = addToChainKey(chainKey, ss[:])
|
||||||
|
}()
|
||||||
|
|
||||||
|
// add preshared key (psk)
|
||||||
|
|
||||||
|
var tau [blake2s.Size]byte
|
||||||
|
var key [chacha20poly1305.KeySize]byte
|
||||||
|
chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
|
||||||
|
hash = addToHash(hash, tau[:])
|
||||||
|
|
||||||
|
// authenticate
|
||||||
|
|
||||||
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
|
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
hash = addToHash(hash, msg.Empty[:])
|
||||||
|
|
||||||
|
// update handshake state
|
||||||
|
|
||||||
|
handshake.hash = hash
|
||||||
|
handshake.chainKey = chainKey
|
||||||
|
handshake.remoteIndex = msg.Sender
|
||||||
|
handshake.state = HandshakeResponseConsumed
|
||||||
|
|
||||||
|
return peer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) NewKeyPair() *KeyPair {
|
||||||
|
handshake := &peer.handshake
|
||||||
|
handshake.mutex.Lock()
|
||||||
|
defer handshake.mutex.Unlock()
|
||||||
|
|
||||||
|
// derive keys
|
||||||
|
|
||||||
|
var sendKey [chacha20poly1305.KeySize]byte
|
||||||
|
var recvKey [chacha20poly1305.KeySize]byte
|
||||||
|
|
||||||
|
if handshake.state == HandshakeResponseConsumed {
|
||||||
|
sendKey, recvKey = KDF2(handshake.chainKey[:], nil)
|
||||||
|
} else if handshake.state == HandshakeResponseCreated {
|
||||||
|
recvKey, sendKey = KDF2(handshake.chainKey[:], nil)
|
||||||
|
} else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// create AEAD instances
|
||||||
|
|
||||||
|
var keyPair KeyPair
|
||||||
|
keyPair.send, _ = chacha20poly1305.New(sendKey[:])
|
||||||
|
keyPair.recv, _ = chacha20poly1305.New(recvKey[:])
|
||||||
|
keyPair.sendNonce = 0
|
||||||
|
keyPair.recvNonce = 0
|
||||||
|
|
||||||
|
peer.handshake.state = HandshakeReset
|
||||||
|
|
||||||
|
return &keyPair
|
||||||
|
}
|
||||||
|
|
|
@ -63,7 +63,9 @@ func TestNoiseHandshake(t *testing.T) {
|
||||||
|
|
||||||
/* simulate handshake */
|
/* simulate handshake */
|
||||||
|
|
||||||
// Initiation message
|
// initiation message
|
||||||
|
|
||||||
|
t.Log("exchange initiation message")
|
||||||
|
|
||||||
msg1, err := dev1.CreateMessageInitial(peer2)
|
msg1, err := dev1.CreateMessageInitial(peer2)
|
||||||
assertNil(t, err)
|
assertNil(t, err)
|
||||||
|
@ -88,6 +90,68 @@ func TestNoiseHandshake(t *testing.T) {
|
||||||
peer2.handshake.hash[:],
|
peer2.handshake.hash[:],
|
||||||
)
|
)
|
||||||
|
|
||||||
// Response message
|
// response message
|
||||||
|
|
||||||
|
t.Log("exchange response message")
|
||||||
|
|
||||||
|
msg2, err := dev2.CreateMessageResponse(peer1)
|
||||||
|
assertNil(t, err)
|
||||||
|
|
||||||
|
peer = dev1.ConsumeMessageResponse(msg2)
|
||||||
|
if peer == nil {
|
||||||
|
t.Fatal("handshake failed at response message")
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEqual(
|
||||||
|
t,
|
||||||
|
peer1.handshake.chainKey[:],
|
||||||
|
peer2.handshake.chainKey[:],
|
||||||
|
)
|
||||||
|
|
||||||
|
assertEqual(
|
||||||
|
t,
|
||||||
|
peer1.handshake.hash[:],
|
||||||
|
peer2.handshake.hash[:],
|
||||||
|
)
|
||||||
|
|
||||||
|
// key pairs
|
||||||
|
|
||||||
|
t.Log("deriving keys")
|
||||||
|
|
||||||
|
key1 := peer1.NewKeyPair()
|
||||||
|
key2 := peer2.NewKeyPair()
|
||||||
|
|
||||||
|
if key1 == nil {
|
||||||
|
t.Fatal("failed to dervice key-pair for peer 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
if key2 == nil {
|
||||||
|
t.Fatal("failed to dervice key-pair for peer 2")
|
||||||
|
}
|
||||||
|
|
||||||
|
// encrypting / decryption test
|
||||||
|
|
||||||
|
t.Log("test key pairs")
|
||||||
|
|
||||||
|
func() {
|
||||||
|
testMsg := []byte("wireguard test message 1")
|
||||||
|
var err error
|
||||||
|
var out []byte
|
||||||
|
var nonce [12]byte
|
||||||
|
out = key1.send.Seal(out, nonce[:], testMsg, nil)
|
||||||
|
out, err = key2.recv.Open(out[:0], nonce[:], out, nil)
|
||||||
|
assertNil(t, err)
|
||||||
|
assertEqual(t, out, testMsg)
|
||||||
|
}()
|
||||||
|
|
||||||
|
func() {
|
||||||
|
testMsg := []byte("wireguard test message 2")
|
||||||
|
var err error
|
||||||
|
var out []byte
|
||||||
|
var nonce [12]byte
|
||||||
|
out = key2.send.Seal(out, nonce[:], testMsg, nil)
|
||||||
|
out, err = key1.recv.Open(out[:0], nonce[:], out, nil)
|
||||||
|
assertNil(t, err)
|
||||||
|
assertEqual(t, out, testMsg)
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue