From 25190e43369a79dc77a740dc8cd28b8a9fcb235e Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sat, 24 Jun 2017 15:34:17 +0200 Subject: [PATCH] Restructuring of noise impl. --- src/config.go | 12 +-- src/device.go | 75 ++++++++------ src/index.go | 82 ++++++++++++++++ src/keypair.go | 12 +++ src/noise_helpers.go | 2 +- src/noise_protocol.go | 222 ++++++++++++++++++++++++++++++------------ src/noise_test.go | 91 +++++++++++++---- src/peer.go | 40 +++++--- src/routing.go | 7 ++ src/tai64.go | 5 + 10 files changed, 420 insertions(+), 128 deletions(-) create mode 100644 src/index.go create mode 100644 src/keypair.go diff --git a/src/config.go b/src/config.go index a61b940..8865194 100644 --- a/src/config.go +++ b/src/config.go @@ -99,11 +99,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { if ok { peer = found } else { - newPeer := &Peer{ - publicKey: pubKey, - } - peer = newPeer - device.peers[pubKey] = newPeer + peer = device.NewPeer(pubKey) } case "replace_peers": @@ -125,14 +121,14 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { case "remove": peer.mutex.Lock() - device.RemovePeer(peer.publicKey) + // device.RemovePeer(peer.publicKey) peer = nil case "preshared_key": err := func() error { peer.mutex.Lock() defer peer.mutex.Unlock() - return peer.presharedKey.FromHex(value) + return peer.handshake.presharedKey.FromHex(value) }() if err != nil { return &IPCError{Code: ipcErrorInvalidPublicKey} @@ -144,7 +140,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { return &IPCError{Code: ipcErrorInvalidIPAddress} } peer.mutex.Lock() - peer.endpoint = ip + // peer.endpoint = ip FIX peer.mutex.Unlock() case "persistent_keepalive_interval": diff --git a/src/device.go b/src/device.go index 9f1daa6..9969034 100644 --- a/src/device.go +++ b/src/device.go @@ -1,17 +1,13 @@ package main import ( - "math/rand" "sync" ) -/* TODO: Locking may be a little broad here - */ - type Device struct { mutex sync.RWMutex peers map[NoisePublicKey]*Peer - sessions map[uint32]*Handshake + indices IndexTable privateKey NoisePrivateKey publicKey NoisePublicKey fwMark uint32 @@ -19,43 +15,66 @@ type Device struct { routingTable RoutingTable } -func (dev *Device) NewID(h *Handshake) uint32 { - dev.mutex.Lock() - defer dev.mutex.Unlock() - for { - id := rand.Uint32() - _, ok := dev.sessions[id] - if !ok { - dev.sessions[id] = h - return id - } +func (device *Device) SetPrivateKey(sk NoisePrivateKey) { + device.mutex.Lock() + defer device.mutex.Unlock() + + // update key material + + device.privateKey = sk + device.publicKey = sk.publicKey() + + // do precomputations + + for _, peer := range device.peers { + h := &peer.handshake + h.mutex.Lock() + h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic) + h.mutex.Unlock() } } -func (dev *Device) RemovePeer(key NoisePublicKey) { - dev.mutex.Lock() - defer dev.mutex.Unlock() - peer, ok := dev.peers[key] +func (device *Device) Init() { + device.mutex.Lock() + defer device.mutex.Unlock() + + device.peers = make(map[NoisePublicKey]*Peer) + device.indices.Init() + device.listenPort = 0 + device.routingTable.Reset() +} + +func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { + device.mutex.RLock() + defer device.mutex.RUnlock() + return device.peers[pk] +} + +func (device *Device) RemovePeer(key NoisePublicKey) { + device.mutex.Lock() + defer device.mutex.Unlock() + + peer, ok := device.peers[key] if !ok { return } peer.mutex.Lock() - dev.routingTable.RemovePeer(peer) - delete(dev.peers, key) + device.routingTable.RemovePeer(peer) + delete(device.peers, key) } -func (dev *Device) RemoveAllAllowedIps(peer *Peer) { +func (device *Device) RemoveAllAllowedIps(peer *Peer) { } -func (dev *Device) RemoveAllPeers() { - dev.mutex.Lock() - defer dev.mutex.Unlock() +func (device *Device) RemoveAllPeers() { + device.mutex.Lock() + defer device.mutex.Unlock() - for key, peer := range dev.peers { + for key, peer := range device.peers { peer.mutex.Lock() - dev.routingTable.RemovePeer(peer) - delete(dev.peers, key) + device.routingTable.RemovePeer(peer) + delete(device.peers, key) peer.mutex.Unlock() } } diff --git a/src/index.go b/src/index.go new file mode 100644 index 0000000..83a7e29 --- /dev/null +++ b/src/index.go @@ -0,0 +1,82 @@ +package main + +import ( + "crypto/rand" + "sync" +) + +/* Index=0 is reserved for unset indecies + * + */ + +type IndexTable struct { + mutex sync.RWMutex + keypairs map[uint32]*KeyPair + handshakes map[uint32]*Handshake +} + +func randUint32() (uint32, error) { + var buff [4]byte + _, err := rand.Read(buff[:]) + id := uint32(buff[0]) + id <<= 8 + id |= uint32(buff[1]) + id <<= 8 + id |= uint32(buff[2]) + id <<= 8 + id |= uint32(buff[3]) + return id, err +} + +func (table *IndexTable) Init() { + table.mutex.Lock() + defer table.mutex.Unlock() + table.keypairs = make(map[uint32]*KeyPair) + table.handshakes = make(map[uint32]*Handshake) +} + +func (table *IndexTable) NewIndex(handshake *Handshake) (uint32, error) { + table.mutex.Lock() + defer table.mutex.Unlock() + for { + // generate random index + + id, err := randUint32() + if err != nil { + return id, err + } + if id == 0 { + continue + } + + // check if index used + + _, ok := table.keypairs[id] + if ok { + continue + } + _, ok = table.handshakes[id] + if ok { + continue + } + + // update the index + + delete(table.handshakes, handshake.localIndex) + handshake.localIndex = id + table.handshakes[id] = handshake + return id, nil + } +} + +func (table *IndexTable) LookupKeyPair(id uint32) *KeyPair { + table.mutex.RLock() + defer table.mutex.RUnlock() + return table.keypairs[id] +} + +func (table *IndexTable) LookupHandshake(id uint32) *Handshake { + table.mutex.RLock() + defer table.mutex.RUnlock() + return table.handshakes[id] +} diff --git a/src/keypair.go b/src/keypair.go new file mode 100644 index 0000000..22a8244 --- /dev/null +++ b/src/keypair.go @@ -0,0 +1,12 @@ +package main + +import ( + "crypto/cipher" +) + +type KeyPair struct { + recieveKey cipher.AEAD + recieveNonce NoiseNonce + sendKey cipher.AEAD + sendNonce NoiseNonce +} diff --git a/src/noise_helpers.go b/src/noise_helpers.go index df25011..eadbc07 100644 --- a/src/noise_helpers.go +++ b/src/noise_helpers.go @@ -81,6 +81,6 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) { func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) { apk := (*[NoisePublicKeySize]byte)(&pk) ask := (*[NoisePrivateKeySize]byte)(sk) - curve25519.ScalarMult(&ss, apk, ask) + curve25519.ScalarMult(&ss, ask, apk) return ss } diff --git a/src/noise_protocol.go b/src/noise_protocol.go index e7c8774..b9c8981 100644 --- a/src/noise_protocol.go +++ b/src/noise_protocol.go @@ -56,18 +56,22 @@ type MessageTransport struct { } type Handshake struct { - lock sync.Mutex - state int - chainKey [blake2s.Size]byte // chain key - hash [blake2s.Size]byte // hash value - staticStatic NoisePublicKey // precomputed DH(S_i, S_r) - ephemeral NoisePrivateKey // ephemeral secret key - remoteIndex uint32 // index for sending - device *Device - peer *Peer + state int + mutex sync.Mutex + hash [blake2s.Size]byte // hash value + chainKey [blake2s.Size]byte // chain key + presharedKey NoiseSymmetricKey // psk + localEphemeral NoisePrivateKey // ephemeral secret key + localIndex uint32 // used to clear hash-table + remoteIndex uint32 // index for sending + remoteStatic NoisePublicKey // long term key + remoteEphemeral NoisePublicKey // ephemeral public key + precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret + lastTimestamp TAI64N } var ( + EmptyMessage []byte ZeroNonce [chacha20poly1305.NonceSize]byte InitalChainKey [blake2s.Size]byte InitalHash [blake2s.Size]byte @@ -78,102 +82,196 @@ func init() { InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...)) } -func (h *Handshake) Precompute() { - h.staticStatic = h.device.privateKey.sharedSecret(h.peer.publicKey) -} - -func (h *Handshake) ConsumeMessageResponse(msg *MessageResponse) { - -} - -func (h *Handshake) addHash(data []byte) { +func (h *Handshake) addToHash(data []byte) { h.hash = addToHash(h.hash, data) } -func (h *Handshake) addChain(data []byte) { +func (h *Handshake) addToChainKey(data []byte) { h.chainKey = addToChainKey(h.chainKey, data) } -func (h *Handshake) CreateMessageInital() (*MessageInital, error) { - h.lock.Lock() - defer h.lock.Unlock() +func (device *Device) Precompute(peer *Peer) { + h := &peer.handshake + h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic) +} - // reset handshake - - var err error - h.ephemeral, err = newPrivateKey() - if err != nil { - return nil, err - } - h.chainKey = InitalChainKey - h.hash = addToHash(InitalHash, h.device.publicKey[:]) +func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) { + handshake := &peer.handshake + handshake.mutex.Lock() + defer handshake.mutex.Unlock() // create ephemeral key + var err error + handshake.chainKey = InitalChainKey + handshake.hash = addToHash(InitalHash, handshake.remoteStatic[:]) + handshake.localEphemeral, err = newPrivateKey() + if err != nil { + return nil, err + } + + // assign index + var msg MessageInital + msg.Type = MessageInitalType - msg.Sender = h.device.NewID(h) - msg.Ephemeral = h.ephemeral.publicKey() - h.chainKey = addToChainKey(h.chainKey, msg.Ephemeral[:]) - h.hash = addToHash(h.hash, msg.Ephemeral[:]) + msg.Ephemeral = handshake.localEphemeral.publicKey() + msg.Sender, err = device.indices.NewIndex(handshake) + + if err != nil { + return nil, err + } + + handshake.addToChainKey(msg.Ephemeral[:]) + handshake.addToHash(msg.Ephemeral[:]) // encrypt long-term "identity key" func() { var key [chacha20poly1305.KeySize]byte - ss := h.ephemeral.sharedSecret(h.peer.publicKey) - h.chainKey, key = KDF2(h.chainKey[:], ss[:]) + ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + handshake.chainKey, key = KDF2(handshake.chainKey[:], ss[:]) aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Static[:0], ZeroNonce[:], h.device.publicKey[:], nil) + aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:]) }() - h.addHash(msg.Static[:]) + handshake.addToHash(msg.Static[:]) // encrypt timestamp timestamp := Timestamp() func() { var key [chacha20poly1305.KeySize]byte - h.chainKey, key = KDF2(h.chainKey[:], h.staticStatic[:]) + handshake.chainKey, key = KDF2( + handshake.chainKey[:], + handshake.precomputedStaticStatic[:], + ) aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], nil) + aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) }() - h.addHash(msg.Timestamp[:]) - h.state = HandshakeInitialCreated + + handshake.addToHash(msg.Timestamp[:]) + handshake.state = HandshakeInitialCreated + return &msg, nil } -func (h *Handshake) ConsumeMessageInitial(msg *MessageInital) error { +func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer { if msg.Type != MessageInitalType { panic(errors.New("bug: invalid inital message type")) } - hash := addToHash(InitalHash, h.device.publicKey[:]) - chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:]) + hash := addToHash(InitalHash, device.publicKey[:]) hash = addToHash(hash, msg.Ephemeral[:]) + chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:]) - // + // decrypt identity key - ephemeral, err := newPrivateKey() + var err error + var peerPK NoisePublicKey + func() { + var key [chacha20poly1305.KeySize]byte + ss := device.privateKey.sharedSecret(msg.Ephemeral) + chainKey, key = KDF2(chainKey[:], ss[:]) + aead, _ := chacha20poly1305.New(key[:]) + _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) + }() if err != nil { - return err + return nil } + hash = addToHash(hash, msg.Static[:]) + + // find peer + + peer := device.LookupPeer(peerPK) + if peer == nil { + return nil + } + handshake := &peer.handshake + handshake.mutex.Lock() + defer handshake.mutex.Unlock() + + // decrypt timestamp + + var timestamp TAI64N + func() { + var key [chacha20poly1305.KeySize]byte + chainKey, key = KDF2( + chainKey[:], + handshake.precomputedStaticStatic[:], + ) + aead, _ := chacha20poly1305.New(key[:]) + _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) + }() + if err != nil { + return nil + } + hash = addToHash(hash, msg.Timestamp[:]) + + // check for replay attack + + if !timestamp.After(handshake.lastTimestamp) { + return nil + } + + // check for flood attack // update handshake state - h.lock.Lock() - defer h.lock.Unlock() - - h.hash = hash - h.chainKey = chainKey - h.remoteIndex = msg.Sender - h.ephemeral = ephemeral - h.state = HandshakeInitialConsumed - - return nil - + handshake.hash = hash + handshake.chainKey = chainKey + handshake.remoteIndex = msg.Sender + handshake.remoteEphemeral = msg.Ephemeral + handshake.state = HandshakeInitialConsumed + return peer } -func (h *Handshake) CreateMessageResponse() []byte { +func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) { + handshake := &peer.handshake + handshake.mutex.Lock() + defer handshake.mutex.Unlock() - return nil + if handshake.state != HandshakeInitialConsumed { + panic(errors.New("bug: handshake initation must be consumed first")) + } + + // assign index + + var err error + var msg MessageResponse + msg.Type = MessageResponseType + msg.Sender, err = device.indices.NewIndex(handshake) + msg.Reciever = handshake.remoteIndex + if err != nil { + return nil, err + } + + // create ephemeral key + + handshake.localEphemeral, err = newPrivateKey() + if err != nil { + return nil, err + } + msg.Ephemeral = handshake.localEphemeral.publicKey() + + func() { + ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) + handshake.addToChainKey(ss[:]) + ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + handshake.addToChainKey(ss[:]) + }() + + // add preshared key (psk) + + var tau [blake2s.Size]byte + var key [chacha20poly1305.KeySize]byte + handshake.chainKey, tau, key = KDF3(handshake.chainKey[:], handshake.presharedKey[:]) + handshake.addToHash(tau[:]) + + func() { + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Empty[:0], ZeroNonce[:], EmptyMessage, handshake.hash[:]) + handshake.addToHash(msg.Empty[:]) + }() + + return &msg, nil } diff --git a/src/noise_test.go b/src/noise_test.go index b3ea54f..8d6a0fa 100644 --- a/src/noise_test.go +++ b/src/noise_test.go @@ -1,38 +1,93 @@ package main import ( + "bytes" + "encoding/binary" "testing" ) -func TestHandshake(t *testing.T) { - var dev1 Device - var dev2 Device - - var err error - - dev1.privateKey, err = newPrivateKey() +func assertNil(t *testing.T, err error) { if err != nil { t.Fatal(err) } +} - dev2.privateKey, err = newPrivateKey() +func assertEqual(t *testing.T, a []byte, b []byte) { + if bytes.Compare(a, b) != 0 { + t.Fatal(a, "!=", b) + } +} + +func TestCurveWrappers(t *testing.T) { + sk1, err := newPrivateKey() + assertNil(t, err) + + sk2, err := newPrivateKey() + assertNil(t, err) + + pk1 := sk1.publicKey() + pk2 := sk2.publicKey() + + ss1 := sk1.sharedSecret(pk2) + ss2 := sk2.sharedSecret(pk1) + + if ss1 != ss2 { + t.Fatal("Failed to compute shared secet") + } +} + +func newDevice(t *testing.T) *Device { + var device Device + sk, err := newPrivateKey() if err != nil { t.Fatal(err) } + device.Init() + device.SetPrivateKey(sk) + return &device +} - var peer1 Peer - var peer2 Peer +func TestNoiseHandshake(t *testing.T) { - peer1.publicKey = dev1.privateKey.publicKey() - peer2.publicKey = dev2.privateKey.publicKey() + dev1 := newDevice(t) + dev2 := newDevice(t) - var handshake1 Handshake - var handshake2 Handshake + peer1 := dev2.NewPeer(dev1.privateKey.publicKey()) + peer2 := dev1.NewPeer(dev2.privateKey.publicKey()) - handshake1.device = &dev1 - handshake2.device = &dev2 + assertEqual( + t, + peer1.handshake.precomputedStaticStatic[:], + peer2.handshake.precomputedStaticStatic[:], + ) - handshake1.peer = &peer2 - handshake2.peer = &peer1 + /* simulate handshake */ + + // Initiation message + + msg1, err := dev1.CreateMessageInitial(peer2) + assertNil(t, err) + + packet := make([]byte, 0, 256) + writer := bytes.NewBuffer(packet) + err = binary.Write(writer, binary.LittleEndian, msg1) + peer := dev2.ConsumeMessageInitial(msg1) + if peer == nil { + t.Fatal("handshake failed at initiation message") + } + + assertEqual( + t, + peer1.handshake.chainKey[:], + peer2.handshake.chainKey[:], + ) + + assertEqual( + t, + peer1.handshake.hash[:], + peer2.handshake.hash[:], + ) + + // Response message } diff --git a/src/peer.go b/src/peer.go index db5e99f..f6eb555 100644 --- a/src/peer.go +++ b/src/peer.go @@ -6,17 +6,35 @@ import ( "time" ) -type KeyPair struct { - recieveKey NoiseSymmetricKey - recieveNonce NoiseNonce - sendKey NoiseSymmetricKey - sendNonce NoiseNonce -} - type Peer struct { mutex sync.RWMutex - publicKey NoisePublicKey - presharedKey NoiseSymmetricKey - endpoint net.IP - persistentKeepaliveInterval time.Duration + endpointIP net.IP // + endpointPort uint16 // + persistentKeepaliveInterval time.Duration // 0 = disabled + handshake Handshake + device *Device +} + +func (device *Device) NewPeer(pk NoisePublicKey) *Peer { + var peer Peer + + // map public key + + device.mutex.Lock() + device.peers[pk] = &peer + device.mutex.Unlock() + + // precompute + + peer.mutex.Lock() + peer.device = device + func(h *Handshake) { + h.mutex.Lock() + h.remoteStatic = pk + h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic) + h.mutex.Unlock() + }(&peer.handshake) + peer.mutex.Unlock() + + return &peer } diff --git a/src/routing.go b/src/routing.go index 0aa111c..553df11 100644 --- a/src/routing.go +++ b/src/routing.go @@ -13,6 +13,13 @@ type RoutingTable struct { mutex sync.RWMutex } +func (table *RoutingTable) Reset() { + table.mutex.Lock() + defer table.mutex.Unlock() + table.IPv4 = nil + table.IPv6 = nil +} + func (table *RoutingTable) RemovePeer(peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() diff --git a/src/tai64.go b/src/tai64.go index d0d1432..2299a37 100644 --- a/src/tai64.go +++ b/src/tai64.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "encoding/binary" "time" ) @@ -21,3 +22,7 @@ func Timestamp() TAI64N { binary.BigEndian.PutUint32(tai64n[8:], nano) return tai64n } + +func (t1 *TAI64N) After(t2 TAI64N) bool { + return bytes.Compare(t1[:], t2[:]) > 0 +}