Restructuring of noise impl.

This commit is contained in:
Mathias Hall-Andersen 2017-06-24 15:34:17 +02:00
parent 521e77fd54
commit 25190e4336
10 changed files with 420 additions and 128 deletions

View file

@ -99,11 +99,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
if ok { if ok {
peer = found peer = found
} else { } else {
newPeer := &Peer{ peer = device.NewPeer(pubKey)
publicKey: pubKey,
}
peer = newPeer
device.peers[pubKey] = newPeer
} }
case "replace_peers": case "replace_peers":
@ -125,14 +121,14 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "remove": case "remove":
peer.mutex.Lock() peer.mutex.Lock()
device.RemovePeer(peer.publicKey) // device.RemovePeer(peer.publicKey)
peer = nil peer = nil
case "preshared_key": case "preshared_key":
err := func() error { err := func() error {
peer.mutex.Lock() peer.mutex.Lock()
defer peer.mutex.Unlock() defer peer.mutex.Unlock()
return peer.presharedKey.FromHex(value) return peer.handshake.presharedKey.FromHex(value)
}() }()
if err != nil { if err != nil {
return &IPCError{Code: ipcErrorInvalidPublicKey} return &IPCError{Code: ipcErrorInvalidPublicKey}
@ -144,7 +140,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorInvalidIPAddress} return &IPCError{Code: ipcErrorInvalidIPAddress}
} }
peer.mutex.Lock() peer.mutex.Lock()
peer.endpoint = ip // peer.endpoint = ip FIX
peer.mutex.Unlock() peer.mutex.Unlock()
case "persistent_keepalive_interval": case "persistent_keepalive_interval":

View file

@ -1,17 +1,13 @@
package main package main
import ( import (
"math/rand"
"sync" "sync"
) )
/* TODO: Locking may be a little broad here
*/
type Device struct { type Device struct {
mutex sync.RWMutex mutex sync.RWMutex
peers map[NoisePublicKey]*Peer peers map[NoisePublicKey]*Peer
sessions map[uint32]*Handshake indices IndexTable
privateKey NoisePrivateKey privateKey NoisePrivateKey
publicKey NoisePublicKey publicKey NoisePublicKey
fwMark uint32 fwMark uint32
@ -19,43 +15,66 @@ type Device struct {
routingTable RoutingTable routingTable RoutingTable
} }
func (dev *Device) NewID(h *Handshake) uint32 { func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
dev.mutex.Lock() device.mutex.Lock()
defer dev.mutex.Unlock() defer device.mutex.Unlock()
for {
id := rand.Uint32() // update key material
_, ok := dev.sessions[id]
if !ok { device.privateKey = sk
dev.sessions[id] = h device.publicKey = sk.publicKey()
return id
} // do precomputations
for _, peer := range device.peers {
h := &peer.handshake
h.mutex.Lock()
h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
h.mutex.Unlock()
} }
} }
func (dev *Device) RemovePeer(key NoisePublicKey) { func (device *Device) Init() {
dev.mutex.Lock() device.mutex.Lock()
defer dev.mutex.Unlock() defer device.mutex.Unlock()
peer, ok := dev.peers[key]
device.peers = make(map[NoisePublicKey]*Peer)
device.indices.Init()
device.listenPort = 0
device.routingTable.Reset()
}
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
device.mutex.RLock()
defer device.mutex.RUnlock()
return device.peers[pk]
}
func (device *Device) RemovePeer(key NoisePublicKey) {
device.mutex.Lock()
defer device.mutex.Unlock()
peer, ok := device.peers[key]
if !ok { if !ok {
return return
} }
peer.mutex.Lock() peer.mutex.Lock()
dev.routingTable.RemovePeer(peer) device.routingTable.RemovePeer(peer)
delete(dev.peers, key) delete(device.peers, key)
} }
func (dev *Device) RemoveAllAllowedIps(peer *Peer) { func (device *Device) RemoveAllAllowedIps(peer *Peer) {
} }
func (dev *Device) RemoveAllPeers() { func (device *Device) RemoveAllPeers() {
dev.mutex.Lock() device.mutex.Lock()
defer dev.mutex.Unlock() defer device.mutex.Unlock()
for key, peer := range dev.peers { for key, peer := range device.peers {
peer.mutex.Lock() peer.mutex.Lock()
dev.routingTable.RemovePeer(peer) device.routingTable.RemovePeer(peer)
delete(dev.peers, key) delete(device.peers, key)
peer.mutex.Unlock() peer.mutex.Unlock()
} }
} }

82
src/index.go Normal file
View file

@ -0,0 +1,82 @@
package main
import (
"crypto/rand"
"sync"
)
/* Index=0 is reserved for unset indecies
*
*/
type IndexTable struct {
mutex sync.RWMutex
keypairs map[uint32]*KeyPair
handshakes map[uint32]*Handshake
}
func randUint32() (uint32, error) {
var buff [4]byte
_, err := rand.Read(buff[:])
id := uint32(buff[0])
id <<= 8
id |= uint32(buff[1])
id <<= 8
id |= uint32(buff[2])
id <<= 8
id |= uint32(buff[3])
return id, err
}
func (table *IndexTable) Init() {
table.mutex.Lock()
defer table.mutex.Unlock()
table.keypairs = make(map[uint32]*KeyPair)
table.handshakes = make(map[uint32]*Handshake)
}
func (table *IndexTable) NewIndex(handshake *Handshake) (uint32, error) {
table.mutex.Lock()
defer table.mutex.Unlock()
for {
// generate random index
id, err := randUint32()
if err != nil {
return id, err
}
if id == 0 {
continue
}
// check if index used
_, ok := table.keypairs[id]
if ok {
continue
}
_, ok = table.handshakes[id]
if ok {
continue
}
// update the index
delete(table.handshakes, handshake.localIndex)
handshake.localIndex = id
table.handshakes[id] = handshake
return id, nil
}
}
func (table *IndexTable) LookupKeyPair(id uint32) *KeyPair {
table.mutex.RLock()
defer table.mutex.RUnlock()
return table.keypairs[id]
}
func (table *IndexTable) LookupHandshake(id uint32) *Handshake {
table.mutex.RLock()
defer table.mutex.RUnlock()
return table.handshakes[id]
}

12
src/keypair.go Normal file
View file

@ -0,0 +1,12 @@
package main
import (
"crypto/cipher"
)
type KeyPair struct {
recieveKey cipher.AEAD
recieveNonce NoiseNonce
sendKey cipher.AEAD
sendNonce NoiseNonce
}

View file

@ -81,6 +81,6 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) { func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
apk := (*[NoisePublicKeySize]byte)(&pk) apk := (*[NoisePublicKeySize]byte)(&pk)
ask := (*[NoisePrivateKeySize]byte)(sk) ask := (*[NoisePrivateKeySize]byte)(sk)
curve25519.ScalarMult(&ss, apk, ask) curve25519.ScalarMult(&ss, ask, apk)
return ss return ss
} }

View file

@ -56,18 +56,22 @@ type MessageTransport struct {
} }
type Handshake struct { type Handshake struct {
lock sync.Mutex
state int state int
chainKey [blake2s.Size]byte // chain key mutex sync.Mutex
hash [blake2s.Size]byte // hash value hash [blake2s.Size]byte // hash value
staticStatic NoisePublicKey // precomputed DH(S_i, S_r) chainKey [blake2s.Size]byte // chain key
ephemeral NoisePrivateKey // ephemeral secret key presharedKey NoiseSymmetricKey // psk
localEphemeral NoisePrivateKey // ephemeral secret key
localIndex uint32 // used to clear hash-table
remoteIndex uint32 // index for sending remoteIndex uint32 // index for sending
device *Device remoteStatic NoisePublicKey // long term key
peer *Peer remoteEphemeral NoisePublicKey // ephemeral public key
precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret
lastTimestamp TAI64N
} }
var ( var (
EmptyMessage []byte
ZeroNonce [chacha20poly1305.NonceSize]byte ZeroNonce [chacha20poly1305.NonceSize]byte
InitalChainKey [blake2s.Size]byte InitalChainKey [blake2s.Size]byte
InitalHash [blake2s.Size]byte InitalHash [blake2s.Size]byte
@ -78,102 +82,196 @@ func init() {
InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...)) InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
} }
func (h *Handshake) Precompute() { func (h *Handshake) addToHash(data []byte) {
h.staticStatic = h.device.privateKey.sharedSecret(h.peer.publicKey)
}
func (h *Handshake) ConsumeMessageResponse(msg *MessageResponse) {
}
func (h *Handshake) addHash(data []byte) {
h.hash = addToHash(h.hash, data) h.hash = addToHash(h.hash, data)
} }
func (h *Handshake) addChain(data []byte) { func (h *Handshake) addToChainKey(data []byte) {
h.chainKey = addToChainKey(h.chainKey, data) h.chainKey = addToChainKey(h.chainKey, data)
} }
func (h *Handshake) CreateMessageInital() (*MessageInital, error) { func (device *Device) Precompute(peer *Peer) {
h.lock.Lock() h := &peer.handshake
defer h.lock.Unlock() h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
}
// reset handshake func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
handshake := &peer.handshake
var err error handshake.mutex.Lock()
h.ephemeral, err = newPrivateKey() defer handshake.mutex.Unlock()
if err != nil {
return nil, err
}
h.chainKey = InitalChainKey
h.hash = addToHash(InitalHash, h.device.publicKey[:])
// create ephemeral key // create ephemeral key
var err error
handshake.chainKey = InitalChainKey
handshake.hash = addToHash(InitalHash, handshake.remoteStatic[:])
handshake.localEphemeral, err = newPrivateKey()
if err != nil {
return nil, err
}
// assign index
var msg MessageInital var msg MessageInital
msg.Type = MessageInitalType msg.Type = MessageInitalType
msg.Sender = h.device.NewID(h) msg.Ephemeral = handshake.localEphemeral.publicKey()
msg.Ephemeral = h.ephemeral.publicKey() msg.Sender, err = device.indices.NewIndex(handshake)
h.chainKey = addToChainKey(h.chainKey, msg.Ephemeral[:])
h.hash = addToHash(h.hash, msg.Ephemeral[:]) if err != nil {
return nil, err
}
handshake.addToChainKey(msg.Ephemeral[:])
handshake.addToHash(msg.Ephemeral[:])
// encrypt long-term "identity key" // encrypt long-term "identity key"
func() { func() {
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
ss := h.ephemeral.sharedSecret(h.peer.publicKey) ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
h.chainKey, key = KDF2(h.chainKey[:], ss[:]) handshake.chainKey, key = KDF2(handshake.chainKey[:], ss[:])
aead, _ := chacha20poly1305.New(key[:]) aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Static[:0], ZeroNonce[:], h.device.publicKey[:], nil) aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:])
}() }()
h.addHash(msg.Static[:]) handshake.addToHash(msg.Static[:])
// encrypt timestamp // encrypt timestamp
timestamp := Timestamp() timestamp := Timestamp()
func() { func() {
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
h.chainKey, key = KDF2(h.chainKey[:], h.staticStatic[:]) handshake.chainKey, key = KDF2(
handshake.chainKey[:],
handshake.precomputedStaticStatic[:],
)
aead, _ := chacha20poly1305.New(key[:]) aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], nil) aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
}() }()
h.addHash(msg.Timestamp[:])
h.state = HandshakeInitialCreated handshake.addToHash(msg.Timestamp[:])
handshake.state = HandshakeInitialCreated
return &msg, nil return &msg, nil
} }
func (h *Handshake) ConsumeMessageInitial(msg *MessageInital) error { func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
if msg.Type != MessageInitalType { if msg.Type != MessageInitalType {
panic(errors.New("bug: invalid inital message type")) panic(errors.New("bug: invalid inital message type"))
} }
hash := addToHash(InitalHash, h.device.publicKey[:]) hash := addToHash(InitalHash, device.publicKey[:])
chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:])
hash = addToHash(hash, msg.Ephemeral[:]) hash = addToHash(hash, msg.Ephemeral[:])
chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:])
// // decrypt identity key
ephemeral, err := newPrivateKey() var err error
var peerPK NoisePublicKey
func() {
var key [chacha20poly1305.KeySize]byte
ss := device.privateKey.sharedSecret(msg.Ephemeral)
chainKey, key = KDF2(chainKey[:], ss[:])
aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
}()
if err != nil { if err != nil {
return err return nil
} }
hash = addToHash(hash, msg.Static[:])
// find peer
peer := device.LookupPeer(peerPK)
if peer == nil {
return nil
}
handshake := &peer.handshake
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
// decrypt timestamp
var timestamp TAI64N
func() {
var key [chacha20poly1305.KeySize]byte
chainKey, key = KDF2(
chainKey[:],
handshake.precomputedStaticStatic[:],
)
aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
}()
if err != nil {
return nil
}
hash = addToHash(hash, msg.Timestamp[:])
// check for replay attack
if !timestamp.After(handshake.lastTimestamp) {
return nil
}
// check for flood attack
// update handshake state // update handshake state
h.lock.Lock() handshake.hash = hash
defer h.lock.Unlock() handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender
h.hash = hash handshake.remoteEphemeral = msg.Ephemeral
h.chainKey = chainKey handshake.state = HandshakeInitialConsumed
h.remoteIndex = msg.Sender return peer
h.ephemeral = ephemeral
h.state = HandshakeInitialConsumed
return nil
} }
func (h *Handshake) CreateMessageResponse() []byte { func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) {
handshake := &peer.handshake
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
return nil if handshake.state != HandshakeInitialConsumed {
panic(errors.New("bug: handshake initation must be consumed first"))
}
// assign index
var err error
var msg MessageResponse
msg.Type = MessageResponseType
msg.Sender, err = device.indices.NewIndex(handshake)
msg.Reciever = handshake.remoteIndex
if err != nil {
return nil, err
}
// create ephemeral key
handshake.localEphemeral, err = newPrivateKey()
if err != nil {
return nil, err
}
msg.Ephemeral = handshake.localEphemeral.publicKey()
func() {
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
handshake.addToChainKey(ss[:])
ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
handshake.addToChainKey(ss[:])
}()
// add preshared key (psk)
var tau [blake2s.Size]byte
var key [chacha20poly1305.KeySize]byte
handshake.chainKey, tau, key = KDF3(handshake.chainKey[:], handshake.presharedKey[:])
handshake.addToHash(tau[:])
func() {
aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Empty[:0], ZeroNonce[:], EmptyMessage, handshake.hash[:])
handshake.addToHash(msg.Empty[:])
}()
return &msg, nil
} }

View file

@ -1,38 +1,93 @@
package main package main
import ( import (
"bytes"
"encoding/binary"
"testing" "testing"
) )
func TestHandshake(t *testing.T) { func assertNil(t *testing.T, err error) {
var dev1 Device
var dev2 Device
var err error
dev1.privateKey, err = newPrivateKey()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
}
dev2.privateKey, err = newPrivateKey() func assertEqual(t *testing.T, a []byte, b []byte) {
if bytes.Compare(a, b) != 0 {
t.Fatal(a, "!=", b)
}
}
func TestCurveWrappers(t *testing.T) {
sk1, err := newPrivateKey()
assertNil(t, err)
sk2, err := newPrivateKey()
assertNil(t, err)
pk1 := sk1.publicKey()
pk2 := sk2.publicKey()
ss1 := sk1.sharedSecret(pk2)
ss2 := sk2.sharedSecret(pk1)
if ss1 != ss2 {
t.Fatal("Failed to compute shared secet")
}
}
func newDevice(t *testing.T) *Device {
var device Device
sk, err := newPrivateKey()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
device.Init()
device.SetPrivateKey(sk)
return &device
}
var peer1 Peer func TestNoiseHandshake(t *testing.T) {
var peer2 Peer
peer1.publicKey = dev1.privateKey.publicKey() dev1 := newDevice(t)
peer2.publicKey = dev2.privateKey.publicKey() dev2 := newDevice(t)
var handshake1 Handshake peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
var handshake2 Handshake peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
handshake1.device = &dev1 assertEqual(
handshake2.device = &dev2 t,
peer1.handshake.precomputedStaticStatic[:],
peer2.handshake.precomputedStaticStatic[:],
)
handshake1.peer = &peer2 /* simulate handshake */
handshake2.peer = &peer1
// Initiation message
msg1, err := dev1.CreateMessageInitial(peer2)
assertNil(t, err)
packet := make([]byte, 0, 256)
writer := bytes.NewBuffer(packet)
err = binary.Write(writer, binary.LittleEndian, msg1)
peer := dev2.ConsumeMessageInitial(msg1)
if peer == nil {
t.Fatal("handshake failed at initiation message")
}
assertEqual(
t,
peer1.handshake.chainKey[:],
peer2.handshake.chainKey[:],
)
assertEqual(
t,
peer1.handshake.hash[:],
peer2.handshake.hash[:],
)
// Response message
} }

View file

@ -6,17 +6,35 @@ import (
"time" "time"
) )
type KeyPair struct {
recieveKey NoiseSymmetricKey
recieveNonce NoiseNonce
sendKey NoiseSymmetricKey
sendNonce NoiseNonce
}
type Peer struct { type Peer struct {
mutex sync.RWMutex mutex sync.RWMutex
publicKey NoisePublicKey endpointIP net.IP //
presharedKey NoiseSymmetricKey endpointPort uint16 //
endpoint net.IP persistentKeepaliveInterval time.Duration // 0 = disabled
persistentKeepaliveInterval time.Duration handshake Handshake
device *Device
}
func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
var peer Peer
// map public key
device.mutex.Lock()
device.peers[pk] = &peer
device.mutex.Unlock()
// precompute
peer.mutex.Lock()
peer.device = device
func(h *Handshake) {
h.mutex.Lock()
h.remoteStatic = pk
h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
h.mutex.Unlock()
}(&peer.handshake)
peer.mutex.Unlock()
return &peer
} }

View file

@ -13,6 +13,13 @@ type RoutingTable struct {
mutex sync.RWMutex mutex sync.RWMutex
} }
func (table *RoutingTable) Reset() {
table.mutex.Lock()
defer table.mutex.Unlock()
table.IPv4 = nil
table.IPv6 = nil
}
func (table *RoutingTable) RemovePeer(peer *Peer) { func (table *RoutingTable) RemovePeer(peer *Peer) {
table.mutex.Lock() table.mutex.Lock()
defer table.mutex.Unlock() defer table.mutex.Unlock()

View file

@ -1,6 +1,7 @@
package main package main
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"time" "time"
) )
@ -21,3 +22,7 @@ func Timestamp() TAI64N {
binary.BigEndian.PutUint32(tai64n[8:], nano) binary.BigEndian.PutUint32(tai64n[8:], nano)
return tai64n return tai64n
} }
func (t1 *TAI64N) After(t2 TAI64N) bool {
return bytes.Compare(t1[:], t2[:]) > 0
}