Improved handling of key-material

This commit is contained in:
Mathias Hall-Andersen 2017-09-01 14:21:53 +02:00
parent 239d582cb2
commit 0294a5c0dd
7 changed files with 203 additions and 91 deletions

View file

@ -2,14 +2,39 @@ package main
import ( import (
"crypto/cipher" "crypto/cipher"
"golang.org/x/crypto/chacha20poly1305"
"reflect"
"sync" "sync"
"time" "time"
) )
type safeAEAD struct {
mutex sync.RWMutex
aead cipher.AEAD
}
func (con *safeAEAD) clear() {
// TODO: improve handling of key material
con.mutex.Lock()
if con.aead != nil {
val := reflect.ValueOf(con.aead)
elm := val.Elem()
typ := elm.Type()
elm.Set(reflect.Zero(typ))
con.aead = nil
}
con.mutex.Unlock()
}
func (con *safeAEAD) setKey(key *[chacha20poly1305.KeySize]byte) {
// TODO: improve handling of key material
con.aead, _ = chacha20poly1305.New(key[:])
}
type KeyPair struct { type KeyPair struct {
receive cipher.AEAD send safeAEAD
receive safeAEAD
replayFilter ReplayFilter replayFilter ReplayFilter
send cipher.AEAD
sendNonce uint64 sendNonce uint64
isInitiator bool isInitiator bool
created time.Time created time.Time
@ -31,7 +56,7 @@ func (kp *KeyPairs) Current() *KeyPair {
} }
func (device *Device) DeleteKeyPair(key *KeyPair) { func (device *Device) DeleteKeyPair(key *KeyPair) {
key.send = nil key.send.clear()
key.receive = nil key.receive.clear()
device.indices.Delete(key.localIndex) device.indices.Delete(key.localIndex)
} }

View file

@ -13,37 +13,47 @@ import (
* https://tools.ietf.org/html/rfc5869 * https://tools.ietf.org/html/rfc5869
*/ */
func HMAC(sum *[blake2s.Size]byte, key []byte, input []byte) { func HMAC1(sum *[blake2s.Size]byte, key, in0 []byte) {
mac := hmac.New(func() hash.Hash { mac := hmac.New(func() hash.Hash {
h, _ := blake2s.New256(nil) h, _ := blake2s.New256(nil)
return h return h
}, key) }, key)
mac.Write(input) mac.Write(in0)
mac.Sum(sum[:0]) mac.Sum(sum[:0])
} }
func KDF1(key []byte, input []byte) (t0 [blake2s.Size]byte) { func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) {
HMAC(&t0, key, input) mac := hmac.New(func() hash.Hash {
HMAC(&t0, t0[:], []byte{0x1}) h, _ := blake2s.New256(nil)
return h
}, key)
mac.Write(in0)
mac.Write(in1)
mac.Sum(sum[:0])
}
func KDF1(t0 *[blake2s.Size]byte, key, input []byte) {
HMAC1(t0, key, input)
HMAC1(t0, t0[:], []byte{0x1})
return return
} }
func KDF2(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byte) { func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) {
var prk [blake2s.Size]byte var prk [blake2s.Size]byte
HMAC(&prk, key, input) HMAC1(&prk, key, input)
HMAC(&t0, prk[:], []byte{0x1}) HMAC1(t0, prk[:], []byte{0x1})
HMAC(&t1, prk[:], append(t0[:], 0x2)) HMAC2(t1, prk[:], t0[:], []byte{0x2})
prk = [blake2s.Size]byte{} setZero(prk[:])
return return
} }
func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byte, t2 [blake2s.Size]byte) { func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) {
var prk [blake2s.Size]byte var prk [blake2s.Size]byte
HMAC(&prk, key, input) HMAC1(&prk, key, input)
HMAC(&t0, prk[:], []byte{0x1}) HMAC1(t0, prk[:], []byte{0x1})
HMAC(&t1, prk[:], append(t0[:], 0x2)) HMAC2(t1, prk[:], t0[:], []byte{0x2})
HMAC(&t2, prk[:], append(t1[:], 0x3)) HMAC2(t2, prk[:], t1[:], []byte{0x3})
prk = [blake2s.Size]byte{} setZero(prk[:])
return return
} }
@ -55,6 +65,12 @@ func isZero(val []byte) bool {
return acc == 0 return acc == 0
} }
func setZero(arr []byte) {
for i := range arr {
arr[i] = 0
}
}
/* curve25519 wrappers */ /* curve25519 wrappers */
func newPrivateKey() (sk NoisePrivateKey, err error) { func newPrivateKey() (sk NoisePrivateKey, err error) {

View file

@ -109,27 +109,31 @@ var (
ZeroNonce [chacha20poly1305.NonceSize]byte ZeroNonce [chacha20poly1305.NonceSize]byte
) )
func mixKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte { func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) {
return KDF1(c[:], data) KDF1(dst, c[:], data)
} }
func mixHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte { func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) {
return blake2s.Sum256(append(h[:], data...)) hsh, _ := blake2s.New256(nil)
hsh.Write(h[:])
hsh.Write(data)
hsh.Sum(dst[:0])
hsh.Reset()
} }
func (h *Handshake) mixHash(data []byte) { func (h *Handshake) mixHash(data []byte) {
h.hash = mixHash(h.hash, data) mixHash(&h.hash, &h.hash, data)
} }
func (h *Handshake) mixKey(data []byte) { func (h *Handshake) mixKey(data []byte) {
h.chainKey = mixKey(h.chainKey, data) mixKey(&h.chainKey, &h.chainKey, data)
} }
/* Do basic precomputations /* Do basic precomputations
*/ */
func init() { func init() {
InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction)) InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction))
InitialHash = mixHash(InitialChainKey, []byte(WGIdentifier)) mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier))
} }
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
@ -176,7 +180,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
func() { func() {
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
handshake.chainKey, key = KDF2(handshake.chainKey[:], ss[:]) KDF2(
&handshake.chainKey,
&key,
handshake.chainKey[:],
ss[:],
)
aead, _ := chacha20poly1305.New(key[:]) aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:]) aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:])
}() }()
@ -187,7 +196,9 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
timestamp := Timestamp() timestamp := Timestamp()
func() { func() {
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
handshake.chainKey, key = KDF2( KDF2(
&handshake.chainKey,
&key,
handshake.chainKey[:], handshake.chainKey[:],
handshake.precomputedStaticStatic[:], handshake.precomputedStaticStatic[:],
) )
@ -197,7 +208,6 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(msg.Timestamp[:]) handshake.mixHash(msg.Timestamp[:])
handshake.state = HandshakeInitiationCreated handshake.state = HandshakeInitiationCreated
return &msg, nil return &msg, nil
} }
@ -206,9 +216,14 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
return nil return nil
} }
hash := mixHash(InitialHash, device.publicKey[:]) var (
hash = mixHash(hash, msg.Ephemeral[:]) hash [blake2s.Size]byte
chainKey := mixKey(InitialChainKey, msg.Ephemeral[:]) chainKey [blake2s.Size]byte
)
mixHash(&hash, &InitialHash, device.publicKey[:])
mixHash(&hash, &hash, msg.Ephemeral[:])
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key // decrypt static key
@ -217,14 +232,14 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
func() { func() {
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
ss := device.privateKey.sharedSecret(msg.Ephemeral) ss := device.privateKey.sharedSecret(msg.Ephemeral)
chainKey, key = KDF2(chainKey[:], ss[:]) KDF2(&chainKey, &key, chainKey[:], ss[:])
aead, _ := chacha20poly1305.New(key[:]) aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
}() }()
if err != nil { if err != nil {
return nil return nil
} }
hash = mixHash(hash, msg.Static[:]) mixHash(&hash, &hash, msg.Static[:])
// lookup peer // lookup peer
@ -244,7 +259,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
handshake.mutex.RLock() handshake.mutex.RLock()
chainKey, key = KDF2( KDF2(
&chainKey,
&key,
chainKey[:], chainKey[:],
handshake.precomputedStaticStatic[:], handshake.precomputedStaticStatic[:],
) )
@ -254,7 +271,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
handshake.mutex.RUnlock() handshake.mutex.RUnlock()
return nil return nil
} }
hash = mixHash(hash, msg.Timestamp[:]) mixHash(&hash, &hash, msg.Timestamp[:])
// protect against replay & flood // protect against replay & flood
@ -327,7 +344,15 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
var tau [blake2s.Size]byte var tau [blake2s.Size]byte
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
handshake.chainKey, tau, key = KDF3(handshake.chainKey[:], handshake.presharedKey[:])
KDF3(
&handshake.chainKey,
&tau,
&key,
handshake.chainKey[:],
handshake.presharedKey[:],
)
handshake.mixHash(tau[:]) handshake.mixHash(tau[:])
func() { func() {
@ -337,6 +362,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}() }()
handshake.state = HandshakeResponseCreated handshake.state = HandshakeResponseCreated
return &msg, nil return &msg, nil
} }
@ -371,22 +397,33 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
// finish 3-way DH // finish 3-way DH
hash = mixHash(handshake.hash, msg.Ephemeral[:]) mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
chainKey = mixKey(handshake.chainKey, msg.Ephemeral[:]) mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
func() { func() {
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
chainKey = mixKey(chainKey, ss[:]) mixKey(&chainKey, &chainKey, ss[:])
ss = device.privateKey.sharedSecret(msg.Ephemeral) setZero(ss[:])
chainKey = mixKey(chainKey, ss[:]) }()
func() {
ss := device.privateKey.sharedSecret(msg.Ephemeral)
mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])
}() }()
// add preshared key (psk) // add preshared key (psk)
var tau [blake2s.Size]byte var tau [blake2s.Size]byte
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:]) KDF3(
hash = mixHash(hash, tau[:]) &chainKey,
&tau,
&key,
chainKey[:],
handshake.presharedKey[:],
)
mixHash(&hash, &hash, tau[:])
// authenticate // authenticate
@ -396,7 +433,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
device.log.Debug.Println("failed to open") device.log.Debug.Println("failed to open")
return false return false
} }
hash = mixHash(hash, msg.Empty[:]) mixHash(&hash, &hash, msg.Empty[:])
return true return true
}() }()
@ -415,6 +452,9 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
handshake.mutex.Unlock() handshake.mutex.Unlock()
setZero(hash[:])
setZero(chainKey[:])
return lookup.peer return lookup.peer
} }
@ -422,6 +462,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
* *
*/ */
func (peer *Peer) NewKeyPair() *KeyPair { func (peer *Peer) NewKeyPair() *KeyPair {
device := peer.device
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
defer handshake.mutex.Unlock() defer handshake.mutex.Unlock()
@ -433,10 +474,20 @@ func (peer *Peer) NewKeyPair() *KeyPair {
var recvKey [chacha20poly1305.KeySize]byte var recvKey [chacha20poly1305.KeySize]byte
if handshake.state == HandshakeResponseConsumed { if handshake.state == HandshakeResponseConsumed {
sendKey, recvKey = KDF2(handshake.chainKey[:], nil) KDF2(
&sendKey,
&recvKey,
handshake.chainKey[:],
nil,
)
isInitiator = true isInitiator = true
} else if handshake.state == HandshakeResponseCreated { } else if handshake.state == HandshakeResponseCreated {
recvKey, sendKey = KDF2(handshake.chainKey[:], nil) KDF2(
&recvKey,
&sendKey,
handshake.chainKey[:],
nil,
)
isInitiator = false isInitiator = false
} else { } else {
return nil return nil
@ -444,16 +495,20 @@ func (peer *Peer) NewKeyPair() *KeyPair {
// zero handshake // zero handshake
handshake.chainKey = [blake2s.Size]byte{} setZero(handshake.chainKey[:])
handshake.localEphemeral = NoisePrivateKey{} setZero(handshake.localEphemeral[:])
peer.handshake.state = HandshakeZeroed peer.handshake.state = HandshakeZeroed
// create AEAD instances // create AEAD instances
keyPair := new(KeyPair) keyPair := new(KeyPair)
keyPair.send.setKey(&sendKey)
keyPair.receive.setKey(&recvKey)
setZero(sendKey[:])
setZero(recvKey[:])
keyPair.created = time.Now() keyPair.created = time.Now()
keyPair.send, _ = chacha20poly1305.New(sendKey[:])
keyPair.receive, _ = chacha20poly1305.New(recvKey[:])
keyPair.sendNonce = 0 keyPair.sendNonce = 0
keyPair.replayFilter.Init() keyPair.replayFilter.Init()
keyPair.isInitiator = isInitiator keyPair.isInitiator = isInitiator
@ -462,12 +517,14 @@ func (peer *Peer) NewKeyPair() *KeyPair {
// remap index // remap index
indices := &peer.device.indices device.indices.Insert(
indices.Insert(handshake.localIndex, IndexTableEntry{ handshake.localIndex,
IndexTableEntry{
peer: peer, peer: peer,
keyPair: keyPair, keyPair: keyPair,
handshake: nil, handshake: nil,
}) },
)
handshake.localIndex = 0 handshake.localIndex = 0
// rotate key pairs // rotate key pairs
@ -479,7 +536,8 @@ func (peer *Peer) NewKeyPair() *KeyPair {
// TODO: Adapt kernel behavior noise.c:161 // TODO: Adapt kernel behavior noise.c:161
if isInitiator { if isInitiator {
if kp.previous != nil { if kp.previous != nil {
indices.Delete(kp.previous.localIndex) device.DeleteKeyPair(kp.previous)
kp.previous = nil
} }
if kp.next != nil { if kp.next != nil {

View file

@ -251,7 +251,12 @@ func (device *Device) RoutineDecryption() {
var err error var err error
copy(nonce[4:], counter) copy(nonce[4:], counter)
elem.counter = binary.LittleEndian.Uint64(counter) elem.counter = binary.LittleEndian.Uint64(counter)
elem.packet, err = elem.keyPair.receive.Open( elem.keyPair.receive.mutex.RLock()
if elem.keyPair.receive.aead == nil {
// very unlikely (the key was deleted during queuing)
elem.Drop()
} else {
elem.packet, err = elem.keyPair.receive.aead.Open(
elem.buffer[:0], elem.buffer[:0],
nonce[:], nonce[:],
content, content,
@ -260,6 +265,8 @@ func (device *Device) RoutineDecryption() {
if err != nil { if err != nil {
elem.Drop() elem.Drop()
} }
}
elem.keyPair.receive.mutex.RUnlock()
elem.mutex.Unlock() elem.mutex.Unlock()
} }
} }
@ -507,6 +514,9 @@ func (peer *Peer) RoutineSequentialReceiver() {
kp.mutex.Lock() kp.mutex.Lock()
if kp.next == elem.keyPair { if kp.next == elem.keyPair {
peer.TimerHandshakeComplete() peer.TimerHandshakeComplete()
if kp.previous != nil {
device.DeleteKeyPair(kp.previous)
}
kp.previous = kp.current kp.previous = kp.current
kp.current = kp.next kp.current = kp.next
kp.next = nil kp.next = nil

View file

@ -349,12 +349,19 @@ func (device *Device) RoutineEncryption() {
// encrypt content (append to header) // encrypt content (append to header)
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
elem.packet = elem.keyPair.send.Seal( elem.keyPair.send.mutex.RLock()
if elem.keyPair.send.aead == nil {
// very unlikely (the key was deleted during queuing)
elem.Drop()
} else {
elem.packet = elem.keyPair.send.aead.Seal(
header, header,
nonce[:], nonce[:],
elem.packet, elem.packet,
nil, nil,
) )
}
elem.keyPair.send.mutex.RUnlock()
elem.mutex.Unlock() elem.mutex.Unlock()
// refresh key if necessary // refresh key if necessary

View file

@ -3,7 +3,6 @@ package main
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"golang.org/x/crypto/blake2s"
"math/rand" "math/rand"
"sync/atomic" "sync/atomic"
"time" "time"
@ -134,7 +133,6 @@ func (peer *Peer) TimerEphemeralKeyCreated() {
func (peer *Peer) RoutineTimerHandler() { func (peer *Peer) RoutineTimerHandler() {
device := peer.device device := peer.device
indices := &device.indices
logDebug := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, timer handler, started for peer", peer.String()) logDebug.Println("Routine, timer handler, started for peer", peer.String())
@ -186,35 +184,31 @@ func (peer *Peer) RoutineTimerHandler() {
kp := &peer.keyPairs kp := &peer.keyPairs
kp.mutex.Lock() kp.mutex.Lock()
// unmap indecies // remove key-pairs
indices.mutex.Lock()
if kp.previous != nil { if kp.previous != nil {
delete(indices.table, kp.previous.localIndex) device.DeleteKeyPair(kp.previous)
kp.previous = nil
} }
if kp.current != nil { if kp.current != nil {
delete(indices.table, kp.current.localIndex) device.DeleteKeyPair(kp.current)
kp.current = nil
} }
if kp.next != nil { if kp.next != nil {
delete(indices.table, kp.next.localIndex) device.DeleteKeyPair(kp.next)
}
delete(indices.table, hs.localIndex)
indices.mutex.Unlock()
// zero out key pairs (TODO: better than wait for GC)
kp.current = nil
kp.previous = nil
kp.next = nil kp.next = nil
}
kp.mutex.Unlock() kp.mutex.Unlock()
// zero out handshake // zero out handshake
device.indices.Delete(hs.localIndex)
hs.localIndex = 0 hs.localIndex = 0
hs.localEphemeral = NoisePrivateKey{} setZero(hs.localEphemeral[:])
hs.remoteEphemeral = NoisePublicKey{} setZero(hs.remoteEphemeral[:])
hs.chainKey = [blake2s.Size]byte{} setZero(hs.chainKey[:])
hs.hash = [blake2s.Size]byte{} setZero(hs.hash[:])
hs.mutex.Unlock() hs.mutex.Unlock()
} }
} }

View file

@ -63,6 +63,8 @@ func (tun *NativeTun) RoutineNetlinkListener() {
return return
} }
tun.events <- TUNEventUp // TODO: Fix network namespace problem
for msg := make([]byte, 1<<16); ; { for msg := make([]byte, 1<<16); ; {
msgn, _, _, _, err := unix.Recvmsg(sock, msg[:], nil, 0) msgn, _, _, _, err := unix.Recvmsg(sock, msg[:], nil, 0)