Restructuring of noise impl.
This commit is contained in:
parent
521e77fd54
commit
25190e4336
|
@ -99,11 +99,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||||
if ok {
|
if ok {
|
||||||
peer = found
|
peer = found
|
||||||
} else {
|
} else {
|
||||||
newPeer := &Peer{
|
peer = device.NewPeer(pubKey)
|
||||||
publicKey: pubKey,
|
|
||||||
}
|
|
||||||
peer = newPeer
|
|
||||||
device.peers[pubKey] = newPeer
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case "replace_peers":
|
case "replace_peers":
|
||||||
|
@ -125,14 +121,14 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||||
|
|
||||||
case "remove":
|
case "remove":
|
||||||
peer.mutex.Lock()
|
peer.mutex.Lock()
|
||||||
device.RemovePeer(peer.publicKey)
|
// device.RemovePeer(peer.publicKey)
|
||||||
peer = nil
|
peer = nil
|
||||||
|
|
||||||
case "preshared_key":
|
case "preshared_key":
|
||||||
err := func() error {
|
err := func() error {
|
||||||
peer.mutex.Lock()
|
peer.mutex.Lock()
|
||||||
defer peer.mutex.Unlock()
|
defer peer.mutex.Unlock()
|
||||||
return peer.presharedKey.FromHex(value)
|
return peer.handshake.presharedKey.FromHex(value)
|
||||||
}()
|
}()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &IPCError{Code: ipcErrorInvalidPublicKey}
|
return &IPCError{Code: ipcErrorInvalidPublicKey}
|
||||||
|
@ -144,7 +140,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||||
return &IPCError{Code: ipcErrorInvalidIPAddress}
|
return &IPCError{Code: ipcErrorInvalidIPAddress}
|
||||||
}
|
}
|
||||||
peer.mutex.Lock()
|
peer.mutex.Lock()
|
||||||
peer.endpoint = ip
|
// peer.endpoint = ip FIX
|
||||||
peer.mutex.Unlock()
|
peer.mutex.Unlock()
|
||||||
|
|
||||||
case "persistent_keepalive_interval":
|
case "persistent_keepalive_interval":
|
||||||
|
|
|
@ -1,17 +1,13 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math/rand"
|
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
/* TODO: Locking may be a little broad here
|
|
||||||
*/
|
|
||||||
|
|
||||||
type Device struct {
|
type Device struct {
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
peers map[NoisePublicKey]*Peer
|
peers map[NoisePublicKey]*Peer
|
||||||
sessions map[uint32]*Handshake
|
indices IndexTable
|
||||||
privateKey NoisePrivateKey
|
privateKey NoisePrivateKey
|
||||||
publicKey NoisePublicKey
|
publicKey NoisePublicKey
|
||||||
fwMark uint32
|
fwMark uint32
|
||||||
|
@ -19,43 +15,66 @@ type Device struct {
|
||||||
routingTable RoutingTable
|
routingTable RoutingTable
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dev *Device) NewID(h *Handshake) uint32 {
|
func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
|
||||||
dev.mutex.Lock()
|
device.mutex.Lock()
|
||||||
defer dev.mutex.Unlock()
|
defer device.mutex.Unlock()
|
||||||
for {
|
|
||||||
id := rand.Uint32()
|
// update key material
|
||||||
_, ok := dev.sessions[id]
|
|
||||||
if !ok {
|
device.privateKey = sk
|
||||||
dev.sessions[id] = h
|
device.publicKey = sk.publicKey()
|
||||||
return id
|
|
||||||
}
|
// do precomputations
|
||||||
|
|
||||||
|
for _, peer := range device.peers {
|
||||||
|
h := &peer.handshake
|
||||||
|
h.mutex.Lock()
|
||||||
|
h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
|
||||||
|
h.mutex.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dev *Device) RemovePeer(key NoisePublicKey) {
|
func (device *Device) Init() {
|
||||||
dev.mutex.Lock()
|
device.mutex.Lock()
|
||||||
defer dev.mutex.Unlock()
|
defer device.mutex.Unlock()
|
||||||
peer, ok := dev.peers[key]
|
|
||||||
|
device.peers = make(map[NoisePublicKey]*Peer)
|
||||||
|
device.indices.Init()
|
||||||
|
device.listenPort = 0
|
||||||
|
device.routingTable.Reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
|
||||||
|
device.mutex.RLock()
|
||||||
|
defer device.mutex.RUnlock()
|
||||||
|
return device.peers[pk]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) RemovePeer(key NoisePublicKey) {
|
||||||
|
device.mutex.Lock()
|
||||||
|
defer device.mutex.Unlock()
|
||||||
|
|
||||||
|
peer, ok := device.peers[key]
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
peer.mutex.Lock()
|
peer.mutex.Lock()
|
||||||
dev.routingTable.RemovePeer(peer)
|
device.routingTable.RemovePeer(peer)
|
||||||
delete(dev.peers, key)
|
delete(device.peers, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dev *Device) RemoveAllAllowedIps(peer *Peer) {
|
func (device *Device) RemoveAllAllowedIps(peer *Peer) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dev *Device) RemoveAllPeers() {
|
func (device *Device) RemoveAllPeers() {
|
||||||
dev.mutex.Lock()
|
device.mutex.Lock()
|
||||||
defer dev.mutex.Unlock()
|
defer device.mutex.Unlock()
|
||||||
|
|
||||||
for key, peer := range dev.peers {
|
for key, peer := range device.peers {
|
||||||
peer.mutex.Lock()
|
peer.mutex.Lock()
|
||||||
dev.routingTable.RemovePeer(peer)
|
device.routingTable.RemovePeer(peer)
|
||||||
delete(dev.peers, key)
|
delete(device.peers, key)
|
||||||
peer.mutex.Unlock()
|
peer.mutex.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
82
src/index.go
Normal file
82
src/index.go
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* Index=0 is reserved for unset indecies
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
type IndexTable struct {
|
||||||
|
mutex sync.RWMutex
|
||||||
|
keypairs map[uint32]*KeyPair
|
||||||
|
handshakes map[uint32]*Handshake
|
||||||
|
}
|
||||||
|
|
||||||
|
func randUint32() (uint32, error) {
|
||||||
|
var buff [4]byte
|
||||||
|
_, err := rand.Read(buff[:])
|
||||||
|
id := uint32(buff[0])
|
||||||
|
id <<= 8
|
||||||
|
id |= uint32(buff[1])
|
||||||
|
id <<= 8
|
||||||
|
id |= uint32(buff[2])
|
||||||
|
id <<= 8
|
||||||
|
id |= uint32(buff[3])
|
||||||
|
return id, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (table *IndexTable) Init() {
|
||||||
|
table.mutex.Lock()
|
||||||
|
defer table.mutex.Unlock()
|
||||||
|
table.keypairs = make(map[uint32]*KeyPair)
|
||||||
|
table.handshakes = make(map[uint32]*Handshake)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (table *IndexTable) NewIndex(handshake *Handshake) (uint32, error) {
|
||||||
|
table.mutex.Lock()
|
||||||
|
defer table.mutex.Unlock()
|
||||||
|
for {
|
||||||
|
// generate random index
|
||||||
|
|
||||||
|
id, err := randUint32()
|
||||||
|
if err != nil {
|
||||||
|
return id, err
|
||||||
|
}
|
||||||
|
if id == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if index used
|
||||||
|
|
||||||
|
_, ok := table.keypairs[id]
|
||||||
|
if ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, ok = table.handshakes[id]
|
||||||
|
if ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// update the index
|
||||||
|
|
||||||
|
delete(table.handshakes, handshake.localIndex)
|
||||||
|
handshake.localIndex = id
|
||||||
|
table.handshakes[id] = handshake
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (table *IndexTable) LookupKeyPair(id uint32) *KeyPair {
|
||||||
|
table.mutex.RLock()
|
||||||
|
defer table.mutex.RUnlock()
|
||||||
|
return table.keypairs[id]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (table *IndexTable) LookupHandshake(id uint32) *Handshake {
|
||||||
|
table.mutex.RLock()
|
||||||
|
defer table.mutex.RUnlock()
|
||||||
|
return table.handshakes[id]
|
||||||
|
}
|
12
src/keypair.go
Normal file
12
src/keypair.go
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/cipher"
|
||||||
|
)
|
||||||
|
|
||||||
|
type KeyPair struct {
|
||||||
|
recieveKey cipher.AEAD
|
||||||
|
recieveNonce NoiseNonce
|
||||||
|
sendKey cipher.AEAD
|
||||||
|
sendNonce NoiseNonce
|
||||||
|
}
|
|
@ -81,6 +81,6 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
|
||||||
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
|
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
|
||||||
apk := (*[NoisePublicKeySize]byte)(&pk)
|
apk := (*[NoisePublicKeySize]byte)(&pk)
|
||||||
ask := (*[NoisePrivateKeySize]byte)(sk)
|
ask := (*[NoisePrivateKeySize]byte)(sk)
|
||||||
curve25519.ScalarMult(&ss, apk, ask)
|
curve25519.ScalarMult(&ss, ask, apk)
|
||||||
return ss
|
return ss
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,18 +56,22 @@ type MessageTransport struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Handshake struct {
|
type Handshake struct {
|
||||||
lock sync.Mutex
|
state int
|
||||||
state int
|
mutex sync.Mutex
|
||||||
chainKey [blake2s.Size]byte // chain key
|
hash [blake2s.Size]byte // hash value
|
||||||
hash [blake2s.Size]byte // hash value
|
chainKey [blake2s.Size]byte // chain key
|
||||||
staticStatic NoisePublicKey // precomputed DH(S_i, S_r)
|
presharedKey NoiseSymmetricKey // psk
|
||||||
ephemeral NoisePrivateKey // ephemeral secret key
|
localEphemeral NoisePrivateKey // ephemeral secret key
|
||||||
remoteIndex uint32 // index for sending
|
localIndex uint32 // used to clear hash-table
|
||||||
device *Device
|
remoteIndex uint32 // index for sending
|
||||||
peer *Peer
|
remoteStatic NoisePublicKey // long term key
|
||||||
|
remoteEphemeral NoisePublicKey // ephemeral public key
|
||||||
|
precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret
|
||||||
|
lastTimestamp TAI64N
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
EmptyMessage []byte
|
||||||
ZeroNonce [chacha20poly1305.NonceSize]byte
|
ZeroNonce [chacha20poly1305.NonceSize]byte
|
||||||
InitalChainKey [blake2s.Size]byte
|
InitalChainKey [blake2s.Size]byte
|
||||||
InitalHash [blake2s.Size]byte
|
InitalHash [blake2s.Size]byte
|
||||||
|
@ -78,102 +82,196 @@ func init() {
|
||||||
InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
|
InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshake) Precompute() {
|
func (h *Handshake) addToHash(data []byte) {
|
||||||
h.staticStatic = h.device.privateKey.sharedSecret(h.peer.publicKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handshake) ConsumeMessageResponse(msg *MessageResponse) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handshake) addHash(data []byte) {
|
|
||||||
h.hash = addToHash(h.hash, data)
|
h.hash = addToHash(h.hash, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshake) addChain(data []byte) {
|
func (h *Handshake) addToChainKey(data []byte) {
|
||||||
h.chainKey = addToChainKey(h.chainKey, data)
|
h.chainKey = addToChainKey(h.chainKey, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshake) CreateMessageInital() (*MessageInital, error) {
|
func (device *Device) Precompute(peer *Peer) {
|
||||||
h.lock.Lock()
|
h := &peer.handshake
|
||||||
defer h.lock.Unlock()
|
h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
|
||||||
|
}
|
||||||
|
|
||||||
// reset handshake
|
func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
|
||||||
|
handshake := &peer.handshake
|
||||||
var err error
|
handshake.mutex.Lock()
|
||||||
h.ephemeral, err = newPrivateKey()
|
defer handshake.mutex.Unlock()
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
h.chainKey = InitalChainKey
|
|
||||||
h.hash = addToHash(InitalHash, h.device.publicKey[:])
|
|
||||||
|
|
||||||
// create ephemeral key
|
// create ephemeral key
|
||||||
|
|
||||||
|
var err error
|
||||||
|
handshake.chainKey = InitalChainKey
|
||||||
|
handshake.hash = addToHash(InitalHash, handshake.remoteStatic[:])
|
||||||
|
handshake.localEphemeral, err = newPrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// assign index
|
||||||
|
|
||||||
var msg MessageInital
|
var msg MessageInital
|
||||||
|
|
||||||
msg.Type = MessageInitalType
|
msg.Type = MessageInitalType
|
||||||
msg.Sender = h.device.NewID(h)
|
msg.Ephemeral = handshake.localEphemeral.publicKey()
|
||||||
msg.Ephemeral = h.ephemeral.publicKey()
|
msg.Sender, err = device.indices.NewIndex(handshake)
|
||||||
h.chainKey = addToChainKey(h.chainKey, msg.Ephemeral[:])
|
|
||||||
h.hash = addToHash(h.hash, msg.Ephemeral[:])
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
handshake.addToChainKey(msg.Ephemeral[:])
|
||||||
|
handshake.addToHash(msg.Ephemeral[:])
|
||||||
|
|
||||||
// encrypt long-term "identity key"
|
// encrypt long-term "identity key"
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
var key [chacha20poly1305.KeySize]byte
|
var key [chacha20poly1305.KeySize]byte
|
||||||
ss := h.ephemeral.sharedSecret(h.peer.publicKey)
|
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
||||||
h.chainKey, key = KDF2(h.chainKey[:], ss[:])
|
handshake.chainKey, key = KDF2(handshake.chainKey[:], ss[:])
|
||||||
aead, _ := chacha20poly1305.New(key[:])
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
aead.Seal(msg.Static[:0], ZeroNonce[:], h.device.publicKey[:], nil)
|
aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:])
|
||||||
}()
|
}()
|
||||||
h.addHash(msg.Static[:])
|
handshake.addToHash(msg.Static[:])
|
||||||
|
|
||||||
// encrypt timestamp
|
// encrypt timestamp
|
||||||
|
|
||||||
timestamp := Timestamp()
|
timestamp := Timestamp()
|
||||||
func() {
|
func() {
|
||||||
var key [chacha20poly1305.KeySize]byte
|
var key [chacha20poly1305.KeySize]byte
|
||||||
h.chainKey, key = KDF2(h.chainKey[:], h.staticStatic[:])
|
handshake.chainKey, key = KDF2(
|
||||||
|
handshake.chainKey[:],
|
||||||
|
handshake.precomputedStaticStatic[:],
|
||||||
|
)
|
||||||
aead, _ := chacha20poly1305.New(key[:])
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], nil)
|
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
|
||||||
}()
|
}()
|
||||||
h.addHash(msg.Timestamp[:])
|
|
||||||
h.state = HandshakeInitialCreated
|
handshake.addToHash(msg.Timestamp[:])
|
||||||
|
handshake.state = HandshakeInitialCreated
|
||||||
|
|
||||||
return &msg, nil
|
return &msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshake) ConsumeMessageInitial(msg *MessageInital) error {
|
func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
|
||||||
if msg.Type != MessageInitalType {
|
if msg.Type != MessageInitalType {
|
||||||
panic(errors.New("bug: invalid inital message type"))
|
panic(errors.New("bug: invalid inital message type"))
|
||||||
}
|
}
|
||||||
|
|
||||||
hash := addToHash(InitalHash, h.device.publicKey[:])
|
hash := addToHash(InitalHash, device.publicKey[:])
|
||||||
chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:])
|
|
||||||
hash = addToHash(hash, msg.Ephemeral[:])
|
hash = addToHash(hash, msg.Ephemeral[:])
|
||||||
|
chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:])
|
||||||
|
|
||||||
//
|
// decrypt identity key
|
||||||
|
|
||||||
ephemeral, err := newPrivateKey()
|
var err error
|
||||||
|
var peerPK NoisePublicKey
|
||||||
|
func() {
|
||||||
|
var key [chacha20poly1305.KeySize]byte
|
||||||
|
ss := device.privateKey.sharedSecret(msg.Ephemeral)
|
||||||
|
chainKey, key = KDF2(chainKey[:], ss[:])
|
||||||
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
|
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
|
||||||
|
}()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil
|
||||||
}
|
}
|
||||||
|
hash = addToHash(hash, msg.Static[:])
|
||||||
|
|
||||||
|
// find peer
|
||||||
|
|
||||||
|
peer := device.LookupPeer(peerPK)
|
||||||
|
if peer == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
handshake := &peer.handshake
|
||||||
|
handshake.mutex.Lock()
|
||||||
|
defer handshake.mutex.Unlock()
|
||||||
|
|
||||||
|
// decrypt timestamp
|
||||||
|
|
||||||
|
var timestamp TAI64N
|
||||||
|
func() {
|
||||||
|
var key [chacha20poly1305.KeySize]byte
|
||||||
|
chainKey, key = KDF2(
|
||||||
|
chainKey[:],
|
||||||
|
handshake.precomputedStaticStatic[:],
|
||||||
|
)
|
||||||
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
|
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
hash = addToHash(hash, msg.Timestamp[:])
|
||||||
|
|
||||||
|
// check for replay attack
|
||||||
|
|
||||||
|
if !timestamp.After(handshake.lastTimestamp) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// check for flood attack
|
||||||
|
|
||||||
// update handshake state
|
// update handshake state
|
||||||
|
|
||||||
h.lock.Lock()
|
handshake.hash = hash
|
||||||
defer h.lock.Unlock()
|
handshake.chainKey = chainKey
|
||||||
|
handshake.remoteIndex = msg.Sender
|
||||||
h.hash = hash
|
handshake.remoteEphemeral = msg.Ephemeral
|
||||||
h.chainKey = chainKey
|
handshake.state = HandshakeInitialConsumed
|
||||||
h.remoteIndex = msg.Sender
|
return peer
|
||||||
h.ephemeral = ephemeral
|
|
||||||
h.state = HandshakeInitialConsumed
|
|
||||||
|
|
||||||
return nil
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshake) CreateMessageResponse() []byte {
|
func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) {
|
||||||
|
handshake := &peer.handshake
|
||||||
|
handshake.mutex.Lock()
|
||||||
|
defer handshake.mutex.Unlock()
|
||||||
|
|
||||||
return nil
|
if handshake.state != HandshakeInitialConsumed {
|
||||||
|
panic(errors.New("bug: handshake initation must be consumed first"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// assign index
|
||||||
|
|
||||||
|
var err error
|
||||||
|
var msg MessageResponse
|
||||||
|
msg.Type = MessageResponseType
|
||||||
|
msg.Sender, err = device.indices.NewIndex(handshake)
|
||||||
|
msg.Reciever = handshake.remoteIndex
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// create ephemeral key
|
||||||
|
|
||||||
|
handshake.localEphemeral, err = newPrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
msg.Ephemeral = handshake.localEphemeral.publicKey()
|
||||||
|
|
||||||
|
func() {
|
||||||
|
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
|
||||||
|
handshake.addToChainKey(ss[:])
|
||||||
|
ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
||||||
|
handshake.addToChainKey(ss[:])
|
||||||
|
}()
|
||||||
|
|
||||||
|
// add preshared key (psk)
|
||||||
|
|
||||||
|
var tau [blake2s.Size]byte
|
||||||
|
var key [chacha20poly1305.KeySize]byte
|
||||||
|
handshake.chainKey, tau, key = KDF3(handshake.chainKey[:], handshake.presharedKey[:])
|
||||||
|
handshake.addToHash(tau[:])
|
||||||
|
|
||||||
|
func() {
|
||||||
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
|
aead.Seal(msg.Empty[:0], ZeroNonce[:], EmptyMessage, handshake.hash[:])
|
||||||
|
handshake.addToHash(msg.Empty[:])
|
||||||
|
}()
|
||||||
|
|
||||||
|
return &msg, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,38 +1,93 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHandshake(t *testing.T) {
|
func assertNil(t *testing.T, err error) {
|
||||||
var dev1 Device
|
|
||||||
var dev2 Device
|
|
||||||
|
|
||||||
var err error
|
|
||||||
|
|
||||||
dev1.privateKey, err = newPrivateKey()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
dev2.privateKey, err = newPrivateKey()
|
func assertEqual(t *testing.T, a []byte, b []byte) {
|
||||||
|
if bytes.Compare(a, b) != 0 {
|
||||||
|
t.Fatal(a, "!=", b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCurveWrappers(t *testing.T) {
|
||||||
|
sk1, err := newPrivateKey()
|
||||||
|
assertNil(t, err)
|
||||||
|
|
||||||
|
sk2, err := newPrivateKey()
|
||||||
|
assertNil(t, err)
|
||||||
|
|
||||||
|
pk1 := sk1.publicKey()
|
||||||
|
pk2 := sk2.publicKey()
|
||||||
|
|
||||||
|
ss1 := sk1.sharedSecret(pk2)
|
||||||
|
ss2 := sk2.sharedSecret(pk1)
|
||||||
|
|
||||||
|
if ss1 != ss2 {
|
||||||
|
t.Fatal("Failed to compute shared secet")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDevice(t *testing.T) *Device {
|
||||||
|
var device Device
|
||||||
|
sk, err := newPrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
device.Init()
|
||||||
|
device.SetPrivateKey(sk)
|
||||||
|
return &device
|
||||||
|
}
|
||||||
|
|
||||||
var peer1 Peer
|
func TestNoiseHandshake(t *testing.T) {
|
||||||
var peer2 Peer
|
|
||||||
|
|
||||||
peer1.publicKey = dev1.privateKey.publicKey()
|
dev1 := newDevice(t)
|
||||||
peer2.publicKey = dev2.privateKey.publicKey()
|
dev2 := newDevice(t)
|
||||||
|
|
||||||
var handshake1 Handshake
|
peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
|
||||||
var handshake2 Handshake
|
peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
|
||||||
|
|
||||||
handshake1.device = &dev1
|
assertEqual(
|
||||||
handshake2.device = &dev2
|
t,
|
||||||
|
peer1.handshake.precomputedStaticStatic[:],
|
||||||
|
peer2.handshake.precomputedStaticStatic[:],
|
||||||
|
)
|
||||||
|
|
||||||
handshake1.peer = &peer2
|
/* simulate handshake */
|
||||||
handshake2.peer = &peer1
|
|
||||||
|
// Initiation message
|
||||||
|
|
||||||
|
msg1, err := dev1.CreateMessageInitial(peer2)
|
||||||
|
assertNil(t, err)
|
||||||
|
|
||||||
|
packet := make([]byte, 0, 256)
|
||||||
|
writer := bytes.NewBuffer(packet)
|
||||||
|
err = binary.Write(writer, binary.LittleEndian, msg1)
|
||||||
|
peer := dev2.ConsumeMessageInitial(msg1)
|
||||||
|
if peer == nil {
|
||||||
|
t.Fatal("handshake failed at initiation message")
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEqual(
|
||||||
|
t,
|
||||||
|
peer1.handshake.chainKey[:],
|
||||||
|
peer2.handshake.chainKey[:],
|
||||||
|
)
|
||||||
|
|
||||||
|
assertEqual(
|
||||||
|
t,
|
||||||
|
peer1.handshake.hash[:],
|
||||||
|
peer2.handshake.hash[:],
|
||||||
|
)
|
||||||
|
|
||||||
|
// Response message
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
40
src/peer.go
40
src/peer.go
|
@ -6,17 +6,35 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type KeyPair struct {
|
|
||||||
recieveKey NoiseSymmetricKey
|
|
||||||
recieveNonce NoiseNonce
|
|
||||||
sendKey NoiseSymmetricKey
|
|
||||||
sendNonce NoiseNonce
|
|
||||||
}
|
|
||||||
|
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
publicKey NoisePublicKey
|
endpointIP net.IP //
|
||||||
presharedKey NoiseSymmetricKey
|
endpointPort uint16 //
|
||||||
endpoint net.IP
|
persistentKeepaliveInterval time.Duration // 0 = disabled
|
||||||
persistentKeepaliveInterval time.Duration
|
handshake Handshake
|
||||||
|
device *Device
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
|
||||||
|
var peer Peer
|
||||||
|
|
||||||
|
// map public key
|
||||||
|
|
||||||
|
device.mutex.Lock()
|
||||||
|
device.peers[pk] = &peer
|
||||||
|
device.mutex.Unlock()
|
||||||
|
|
||||||
|
// precompute
|
||||||
|
|
||||||
|
peer.mutex.Lock()
|
||||||
|
peer.device = device
|
||||||
|
func(h *Handshake) {
|
||||||
|
h.mutex.Lock()
|
||||||
|
h.remoteStatic = pk
|
||||||
|
h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
|
||||||
|
h.mutex.Unlock()
|
||||||
|
}(&peer.handshake)
|
||||||
|
peer.mutex.Unlock()
|
||||||
|
|
||||||
|
return &peer
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,6 +13,13 @@ type RoutingTable struct {
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (table *RoutingTable) Reset() {
|
||||||
|
table.mutex.Lock()
|
||||||
|
defer table.mutex.Unlock()
|
||||||
|
table.IPv4 = nil
|
||||||
|
table.IPv6 = nil
|
||||||
|
}
|
||||||
|
|
||||||
func (table *RoutingTable) RemovePeer(peer *Peer) {
|
func (table *RoutingTable) RemovePeer(peer *Peer) {
|
||||||
table.mutex.Lock()
|
table.mutex.Lock()
|
||||||
defer table.mutex.Unlock()
|
defer table.mutex.Unlock()
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -21,3 +22,7 @@ func Timestamp() TAI64N {
|
||||||
binary.BigEndian.PutUint32(tai64n[8:], nano)
|
binary.BigEndian.PutUint32(tai64n[8:], nano)
|
||||||
return tai64n
|
return tai64n
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t1 *TAI64N) After(t2 TAI64N) bool {
|
||||||
|
return bytes.Compare(t1[:], t2[:]) > 0
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue