From 50aeefcb5198d99777e19f9a0100fe74af630dfb Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Fri, 23 Jun 2017 13:41:59 +0200 Subject: [PATCH] Beginning work noise handshake --- src/device.go | 18 ++++ src/kdf_test.go | 76 +++++++++++++ src/noise_helpers.go | 86 +++++++++++++++ src/noise_protocol.go | 179 +++++++++++++++++++++++++++++++ src/noise_test.go | 38 +++++++ src/{noise.go => noise_types.go} | 6 +- src/tai64.go | 23 ++++ 7 files changed, 422 insertions(+), 4 deletions(-) create mode 100644 src/kdf_test.go create mode 100644 src/noise_helpers.go create mode 100644 src/noise_protocol.go create mode 100644 src/noise_test.go rename src/{noise.go => noise_types.go} (83%) create mode 100644 src/tai64.go diff --git a/src/device.go b/src/device.go index d03057d..9f1daa6 100644 --- a/src/device.go +++ b/src/device.go @@ -1,12 +1,17 @@ package main import ( + "math/rand" "sync" ) +/* TODO: Locking may be a little broad here + */ + type Device struct { mutex sync.RWMutex peers map[NoisePublicKey]*Peer + sessions map[uint32]*Handshake privateKey NoisePrivateKey publicKey NoisePublicKey fwMark uint32 @@ -14,6 +19,19 @@ type Device struct { routingTable RoutingTable } +func (dev *Device) NewID(h *Handshake) uint32 { + dev.mutex.Lock() + defer dev.mutex.Unlock() + for { + id := rand.Uint32() + _, ok := dev.sessions[id] + if !ok { + dev.sessions[id] = h + return id + } + } +} + func (dev *Device) RemovePeer(key NoisePublicKey) { dev.mutex.Lock() defer dev.mutex.Unlock() diff --git a/src/kdf_test.go b/src/kdf_test.go new file mode 100644 index 0000000..0cce81d --- /dev/null +++ b/src/kdf_test.go @@ -0,0 +1,76 @@ +package main + +import ( + "encoding/hex" + "testing" +) + +type KDFTest struct { + key string + input string + t0 string + t1 string + t2 string +} + +func assertEquals(t *testing.T, a string, b string) { + if a != b { + t.Fatal("expected", a, "=", b) + } +} + +func TestKDF(t *testing.T) { + tests := []KDFTest{ + { + key: "746573742d6b6579", + input: "746573742d696e707574", + t0: "6f0e5ad38daba1bea8a0d213688736f19763239305e0f58aba697f9ffc41c633", + t1: "df1194df20802a4fe594cde27e92991c8cae66c366e8106aaa937a55fa371e8a", + t2: "fac6e2745a325f5dc5d11a5b165aad08b0ada28e7b4e666b7c077934a4d76c24", + }, + { + key: "776972656775617264", + input: "776972656775617264", + t0: "491d43bbfdaa8750aaf535e334ecbfe5129967cd64635101c566d4caefda96e8", + t1: "1e71a379baefd8a79aa4662212fcafe19a23e2b609a3db7d6bcba8f560e3d25f", + t2: "31e1ae48bddfbe5de38f295e5452b1909a1b4e38e183926af3780b0c1e1f0160", + }, + { + key: "", + input: "", + t0: "8387b46bf43eccfcf349552a095d8315c4055beb90208fb1be23b894bc2ed5d0", + t1: "58a0e5f6faefccf4807bff1f05fa8a9217945762040bcec2f4b4a62bdfe0e86e", + t2: "0ce6ea98ec548f8e281e93e32db65621c45eb18dc6f0a7ad94178610a2f7338e", + }, + } + + for _, test := range tests { + key, _ := hex.DecodeString(test.key) + input, _ := hex.DecodeString(test.input) + t0, t1, t2 := KDF3(key, input) + t0s := hex.EncodeToString(t0[:]) + t1s := hex.EncodeToString(t1[:]) + t2s := hex.EncodeToString(t2[:]) + assertEquals(t, t0s, test.t0) + assertEquals(t, t1s, test.t1) + assertEquals(t, t2s, test.t2) + } + + for _, test := range tests { + key, _ := hex.DecodeString(test.key) + input, _ := hex.DecodeString(test.input) + t0, t1 := KDF2(key, input) + t0s := hex.EncodeToString(t0[:]) + t1s := hex.EncodeToString(t1[:]) + assertEquals(t, t0s, test.t0) + assertEquals(t, t1s, test.t1) + } + + for _, test := range tests { + key, _ := hex.DecodeString(test.key) + input, _ := hex.DecodeString(test.input) + t0 := KDF1(key, input) + t0s := hex.EncodeToString(t0[:]) + assertEquals(t, t0s, test.t0) + } +} diff --git a/src/noise_helpers.go b/src/noise_helpers.go new file mode 100644 index 0000000..df25011 --- /dev/null +++ b/src/noise_helpers.go @@ -0,0 +1,86 @@ +package main + +import ( + "crypto/hmac" + "crypto/rand" + "golang.org/x/crypto/blake2s" + "golang.org/x/crypto/curve25519" + "hash" +) + +/* KDF related functions. + * HMAC-based Key Derivation Function (HKDF) + * https://tools.ietf.org/html/rfc5869 + */ + +func HMAC(sum *[blake2s.Size]byte, key []byte, input []byte) { + mac := hmac.New(func() hash.Hash { + h, _ := blake2s.New256(nil) + return h + }, key) + mac.Write(input) + mac.Sum(sum[:0]) +} + +func KDF1(key []byte, input []byte) (t0 [blake2s.Size]byte) { + HMAC(&t0, key, input) + HMAC(&t0, t0[:], []byte{0x1}) + return +} + +func KDF2(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byte) { + var prk [blake2s.Size]byte + HMAC(&prk, key, input) + HMAC(&t0, prk[:], []byte{0x1}) + HMAC(&t1, prk[:], append(t0[:], 0x2)) + return +} + +func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byte, t2 [blake2s.Size]byte) { + var prk [blake2s.Size]byte + HMAC(&prk, key, input) + HMAC(&t0, prk[:], []byte{0x1}) + HMAC(&t1, prk[:], append(t0[:], 0x2)) + HMAC(&t2, prk[:], append(t1[:], 0x3)) + return +} + +/* + * + */ + +func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte { + return KDF1(c[:], data) +} + +func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte { + return blake2s.Sum256(append(h[:], data...)) +} + +/* Curve25519 wrappers + * + * TODO: Rethink this + */ + +func newPrivateKey() (sk NoisePrivateKey, err error) { + // clamping: https://cr.yp.to/ecdh.html + _, err = rand.Read(sk[:]) + sk[0] &= 248 + sk[31] &= 127 + sk[31] |= 64 + return +} + +func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) { + apk := (*[NoisePublicKeySize]byte)(&pk) + ask := (*[NoisePrivateKeySize]byte)(sk) + curve25519.ScalarBaseMult(apk, ask) + return +} + +func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) { + apk := (*[NoisePublicKeySize]byte)(&pk) + ask := (*[NoisePrivateKeySize]byte)(sk) + curve25519.ScalarMult(&ss, apk, ask) + return ss +} diff --git a/src/noise_protocol.go b/src/noise_protocol.go new file mode 100644 index 0000000..e7c8774 --- /dev/null +++ b/src/noise_protocol.go @@ -0,0 +1,179 @@ +package main + +import ( + "errors" + "golang.org/x/crypto/blake2s" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/poly1305" + "sync" +) + +const ( + HandshakeInitialCreated = iota + HandshakeInitialConsumed + HandshakeResponseCreated +) + +const ( + NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" + WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" + WGLabelMAC1 = "mac1----" + WGLabelCookie = "cookie--" +) + +const ( + MessageInitalType = 1 + MessageResponseType = 2 + MessageCookieResponseType = 3 + MessageTransportType = 4 +) + +type MessageInital struct { + Type uint32 + Sender uint32 + Ephemeral NoisePublicKey + Static [NoisePublicKeySize + poly1305.TagSize]byte + Timestamp [TAI64NSize + poly1305.TagSize]byte + Mac1 [blake2s.Size128]byte + Mac2 [blake2s.Size128]byte +} + +type MessageResponse struct { + Type uint32 + Sender uint32 + Reciever uint32 + Ephemeral NoisePublicKey + Empty [poly1305.TagSize]byte + Mac1 [blake2s.Size128]byte + Mac2 [blake2s.Size128]byte +} + +type MessageTransport struct { + Type uint32 + Reciever uint32 + Counter uint64 + Content []byte +} + +type Handshake struct { + lock sync.Mutex + state int + chainKey [blake2s.Size]byte // chain key + hash [blake2s.Size]byte // hash value + staticStatic NoisePublicKey // precomputed DH(S_i, S_r) + ephemeral NoisePrivateKey // ephemeral secret key + remoteIndex uint32 // index for sending + device *Device + peer *Peer +} + +var ( + ZeroNonce [chacha20poly1305.NonceSize]byte + InitalChainKey [blake2s.Size]byte + InitalHash [blake2s.Size]byte +) + +func init() { + InitalChainKey = blake2s.Sum256([]byte(NoiseConstruction)) + InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...)) +} + +func (h *Handshake) Precompute() { + h.staticStatic = h.device.privateKey.sharedSecret(h.peer.publicKey) +} + +func (h *Handshake) ConsumeMessageResponse(msg *MessageResponse) { + +} + +func (h *Handshake) addHash(data []byte) { + h.hash = addToHash(h.hash, data) +} + +func (h *Handshake) addChain(data []byte) { + h.chainKey = addToChainKey(h.chainKey, data) +} + +func (h *Handshake) CreateMessageInital() (*MessageInital, error) { + h.lock.Lock() + defer h.lock.Unlock() + + // reset handshake + + var err error + h.ephemeral, err = newPrivateKey() + if err != nil { + return nil, err + } + h.chainKey = InitalChainKey + h.hash = addToHash(InitalHash, h.device.publicKey[:]) + + // create ephemeral key + + var msg MessageInital + msg.Type = MessageInitalType + msg.Sender = h.device.NewID(h) + msg.Ephemeral = h.ephemeral.publicKey() + h.chainKey = addToChainKey(h.chainKey, msg.Ephemeral[:]) + h.hash = addToHash(h.hash, msg.Ephemeral[:]) + + // encrypt long-term "identity key" + + func() { + var key [chacha20poly1305.KeySize]byte + ss := h.ephemeral.sharedSecret(h.peer.publicKey) + h.chainKey, key = KDF2(h.chainKey[:], ss[:]) + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Static[:0], ZeroNonce[:], h.device.publicKey[:], nil) + }() + h.addHash(msg.Static[:]) + + // encrypt timestamp + + timestamp := Timestamp() + func() { + var key [chacha20poly1305.KeySize]byte + h.chainKey, key = KDF2(h.chainKey[:], h.staticStatic[:]) + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], nil) + }() + h.addHash(msg.Timestamp[:]) + h.state = HandshakeInitialCreated + return &msg, nil +} + +func (h *Handshake) ConsumeMessageInitial(msg *MessageInital) error { + if msg.Type != MessageInitalType { + panic(errors.New("bug: invalid inital message type")) + } + + hash := addToHash(InitalHash, h.device.publicKey[:]) + chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:]) + hash = addToHash(hash, msg.Ephemeral[:]) + + // + + ephemeral, err := newPrivateKey() + if err != nil { + return err + } + + // update handshake state + + h.lock.Lock() + defer h.lock.Unlock() + + h.hash = hash + h.chainKey = chainKey + h.remoteIndex = msg.Sender + h.ephemeral = ephemeral + h.state = HandshakeInitialConsumed + + return nil + +} + +func (h *Handshake) CreateMessageResponse() []byte { + + return nil +} diff --git a/src/noise_test.go b/src/noise_test.go new file mode 100644 index 0000000..b3ea54f --- /dev/null +++ b/src/noise_test.go @@ -0,0 +1,38 @@ +package main + +import ( + "testing" +) + +func TestHandshake(t *testing.T) { + var dev1 Device + var dev2 Device + + var err error + + dev1.privateKey, err = newPrivateKey() + if err != nil { + t.Fatal(err) + } + + dev2.privateKey, err = newPrivateKey() + if err != nil { + t.Fatal(err) + } + + var peer1 Peer + var peer2 Peer + + peer1.publicKey = dev1.privateKey.publicKey() + peer2.publicKey = dev2.privateKey.publicKey() + + var handshake1 Handshake + var handshake2 Handshake + + handshake1.device = &dev1 + handshake2.device = &dev2 + + handshake1.peer = &peer2 + handshake2.peer = &peer1 + +} diff --git a/src/noise.go b/src/noise_types.go similarity index 83% rename from src/noise.go rename to src/noise_types.go index 5508f9a..6dae6b2 100644 --- a/src/noise.go +++ b/src/noise_types.go @@ -12,10 +12,8 @@ const ( ) type ( - NoisePublicKey [NoisePublicKeySize]byte - NoisePrivateKey [NoisePrivateKeySize]byte - NoiseSymmetricKey [NoiseSymmetricKeySize]byte - NoiseNonce uint64 // padded to 12-bytes + NoisePublicKey [NoisePublicKeySize]byte + NoisePrivateKey [NoisePrivateKeySize]byte ) func loadExactHex(dst []byte, src string) error { diff --git a/src/tai64.go b/src/tai64.go new file mode 100644 index 0000000..d0d1432 --- /dev/null +++ b/src/tai64.go @@ -0,0 +1,23 @@ +package main + +import ( + "encoding/binary" + "time" +) + +const ( + TAI64NBase = uint64(4611686018427387914) + TAI64NSize = 12 +) + +type TAI64N [TAI64NSize]byte + +func Timestamp() TAI64N { + var tai64n TAI64N + now := time.Now() + secs := TAI64NBase + uint64(now.Unix()) + nano := uint32(now.UnixNano()) + binary.BigEndian.PutUint64(tai64n[:], secs) + binary.BigEndian.PutUint32(tai64n[8:], nano) + return tai64n +}