diff --git a/src/keypair.go b/src/keypair.go index ba9c437..644d040 100644 --- a/src/keypair.go +++ b/src/keypair.go @@ -2,14 +2,39 @@ package main import ( "crypto/cipher" + "golang.org/x/crypto/chacha20poly1305" + "reflect" "sync" "time" ) +type safeAEAD struct { + mutex sync.RWMutex + aead cipher.AEAD +} + +func (con *safeAEAD) clear() { + // TODO: improve handling of key material + con.mutex.Lock() + if con.aead != nil { + val := reflect.ValueOf(con.aead) + elm := val.Elem() + typ := elm.Type() + elm.Set(reflect.Zero(typ)) + con.aead = nil + } + con.mutex.Unlock() +} + +func (con *safeAEAD) setKey(key *[chacha20poly1305.KeySize]byte) { + // TODO: improve handling of key material + con.aead, _ = chacha20poly1305.New(key[:]) +} + type KeyPair struct { - receive cipher.AEAD + send safeAEAD + receive safeAEAD replayFilter ReplayFilter - send cipher.AEAD sendNonce uint64 isInitiator bool created time.Time @@ -31,7 +56,7 @@ func (kp *KeyPairs) Current() *KeyPair { } func (device *Device) DeleteKeyPair(key *KeyPair) { - key.send = nil - key.receive = nil + key.send.clear() + key.receive.clear() device.indices.Delete(key.localIndex) } diff --git a/src/noise_helpers.go b/src/noise_helpers.go index 105f78f..24302c0 100644 --- a/src/noise_helpers.go +++ b/src/noise_helpers.go @@ -13,37 +13,47 @@ import ( * https://tools.ietf.org/html/rfc5869 */ -func HMAC(sum *[blake2s.Size]byte, key []byte, input []byte) { +func HMAC1(sum *[blake2s.Size]byte, key, in0 []byte) { mac := hmac.New(func() hash.Hash { h, _ := blake2s.New256(nil) return h }, key) - mac.Write(input) + mac.Write(in0) mac.Sum(sum[:0]) } -func KDF1(key []byte, input []byte) (t0 [blake2s.Size]byte) { - HMAC(&t0, key, input) - HMAC(&t0, t0[:], []byte{0x1}) +func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) { + mac := hmac.New(func() hash.Hash { + h, _ := blake2s.New256(nil) + return h + }, key) + mac.Write(in0) + mac.Write(in1) + mac.Sum(sum[:0]) +} + +func KDF1(t0 *[blake2s.Size]byte, key, input []byte) { + HMAC1(t0, key, input) + HMAC1(t0, t0[:], []byte{0x1}) return } -func KDF2(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byte) { +func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) { var prk [blake2s.Size]byte - HMAC(&prk, key, input) - HMAC(&t0, prk[:], []byte{0x1}) - HMAC(&t1, prk[:], append(t0[:], 0x2)) - prk = [blake2s.Size]byte{} + HMAC1(&prk, key, input) + HMAC1(t0, prk[:], []byte{0x1}) + HMAC2(t1, prk[:], t0[:], []byte{0x2}) + setZero(prk[:]) return } -func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byte, t2 [blake2s.Size]byte) { +func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) { var prk [blake2s.Size]byte - HMAC(&prk, key, input) - HMAC(&t0, prk[:], []byte{0x1}) - HMAC(&t1, prk[:], append(t0[:], 0x2)) - HMAC(&t2, prk[:], append(t1[:], 0x3)) - prk = [blake2s.Size]byte{} + HMAC1(&prk, key, input) + HMAC1(t0, prk[:], []byte{0x1}) + HMAC2(t1, prk[:], t0[:], []byte{0x2}) + HMAC2(t2, prk[:], t1[:], []byte{0x3}) + setZero(prk[:]) return } @@ -55,6 +65,12 @@ func isZero(val []byte) bool { return acc == 0 } +func setZero(arr []byte) { + for i := range arr { + arr[i] = 0 + } +} + /* curve25519 wrappers */ func newPrivateKey() (sk NoisePrivateKey, err error) { diff --git a/src/noise_protocol.go b/src/noise_protocol.go index 1f1301e..a50e3dc 100644 --- a/src/noise_protocol.go +++ b/src/noise_protocol.go @@ -109,27 +109,31 @@ var ( ZeroNonce [chacha20poly1305.NonceSize]byte ) -func mixKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte { - return KDF1(c[:], data) +func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) { + KDF1(dst, c[:], data) } -func mixHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte { - return blake2s.Sum256(append(h[:], data...)) +func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) { + hsh, _ := blake2s.New256(nil) + hsh.Write(h[:]) + hsh.Write(data) + hsh.Sum(dst[:0]) + hsh.Reset() } func (h *Handshake) mixHash(data []byte) { - h.hash = mixHash(h.hash, data) + mixHash(&h.hash, &h.hash, data) } func (h *Handshake) mixKey(data []byte) { - h.chainKey = mixKey(h.chainKey, data) + mixKey(&h.chainKey, &h.chainKey, data) } /* Do basic precomputations */ func init() { InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction)) - InitialHash = mixHash(InitialChainKey, []byte(WGIdentifier)) + mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier)) } func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { @@ -176,7 +180,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e func() { var key [chacha20poly1305.KeySize]byte ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) - handshake.chainKey, key = KDF2(handshake.chainKey[:], ss[:]) + KDF2( + &handshake.chainKey, + &key, + handshake.chainKey[:], + ss[:], + ) aead, _ := chacha20poly1305.New(key[:]) aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:]) }() @@ -187,7 +196,9 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e timestamp := Timestamp() func() { var key [chacha20poly1305.KeySize]byte - handshake.chainKey, key = KDF2( + KDF2( + &handshake.chainKey, + &key, handshake.chainKey[:], handshake.precomputedStaticStatic[:], ) @@ -197,7 +208,6 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mixHash(msg.Timestamp[:]) handshake.state = HandshakeInitiationCreated - return &msg, nil } @@ -206,9 +216,14 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { return nil } - hash := mixHash(InitialHash, device.publicKey[:]) - hash = mixHash(hash, msg.Ephemeral[:]) - chainKey := mixKey(InitialChainKey, msg.Ephemeral[:]) + var ( + hash [blake2s.Size]byte + chainKey [blake2s.Size]byte + ) + + mixHash(&hash, &InitialHash, device.publicKey[:]) + mixHash(&hash, &hash, msg.Ephemeral[:]) + mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) // decrypt static key @@ -217,14 +232,14 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { func() { var key [chacha20poly1305.KeySize]byte ss := device.privateKey.sharedSecret(msg.Ephemeral) - chainKey, key = KDF2(chainKey[:], ss[:]) + KDF2(&chainKey, &key, chainKey[:], ss[:]) aead, _ := chacha20poly1305.New(key[:]) _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) }() if err != nil { return nil } - hash = mixHash(hash, msg.Static[:]) + mixHash(&hash, &hash, msg.Static[:]) // lookup peer @@ -244,7 +259,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { var key [chacha20poly1305.KeySize]byte handshake.mutex.RLock() - chainKey, key = KDF2( + KDF2( + &chainKey, + &key, chainKey[:], handshake.precomputedStaticStatic[:], ) @@ -254,7 +271,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { handshake.mutex.RUnlock() return nil } - hash = mixHash(hash, msg.Timestamp[:]) + mixHash(&hash, &hash, msg.Timestamp[:]) // protect against replay & flood @@ -327,7 +344,15 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error var tau [blake2s.Size]byte var key [chacha20poly1305.KeySize]byte - handshake.chainKey, tau, key = KDF3(handshake.chainKey[:], handshake.presharedKey[:]) + + KDF3( + &handshake.chainKey, + &tau, + &key, + handshake.chainKey[:], + handshake.presharedKey[:], + ) + handshake.mixHash(tau[:]) func() { @@ -337,6 +362,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error }() handshake.state = HandshakeResponseCreated + return &msg, nil } @@ -371,22 +397,33 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { // finish 3-way DH - hash = mixHash(handshake.hash, msg.Ephemeral[:]) - chainKey = mixKey(handshake.chainKey, msg.Ephemeral[:]) + mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) + mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) func() { ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) - chainKey = mixKey(chainKey, ss[:]) - ss = device.privateKey.sharedSecret(msg.Ephemeral) - chainKey = mixKey(chainKey, ss[:]) + mixKey(&chainKey, &chainKey, ss[:]) + setZero(ss[:]) + }() + + func() { + ss := device.privateKey.sharedSecret(msg.Ephemeral) + mixKey(&chainKey, &chainKey, ss[:]) + setZero(ss[:]) }() // add preshared key (psk) var tau [blake2s.Size]byte var key [chacha20poly1305.KeySize]byte - chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:]) - hash = mixHash(hash, tau[:]) + KDF3( + &chainKey, + &tau, + &key, + chainKey[:], + handshake.presharedKey[:], + ) + mixHash(&hash, &hash, tau[:]) // authenticate @@ -396,7 +433,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { device.log.Debug.Println("failed to open") return false } - hash = mixHash(hash, msg.Empty[:]) + mixHash(&hash, &hash, msg.Empty[:]) return true }() @@ -415,6 +452,9 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { handshake.mutex.Unlock() + setZero(hash[:]) + setZero(chainKey[:]) + return lookup.peer } @@ -422,6 +462,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { * */ func (peer *Peer) NewKeyPair() *KeyPair { + device := peer.device handshake := &peer.handshake handshake.mutex.Lock() defer handshake.mutex.Unlock() @@ -433,10 +474,20 @@ func (peer *Peer) NewKeyPair() *KeyPair { var recvKey [chacha20poly1305.KeySize]byte if handshake.state == HandshakeResponseConsumed { - sendKey, recvKey = KDF2(handshake.chainKey[:], nil) + KDF2( + &sendKey, + &recvKey, + handshake.chainKey[:], + nil, + ) isInitiator = true } else if handshake.state == HandshakeResponseCreated { - recvKey, sendKey = KDF2(handshake.chainKey[:], nil) + KDF2( + &recvKey, + &sendKey, + handshake.chainKey[:], + nil, + ) isInitiator = false } else { return nil @@ -444,16 +495,20 @@ func (peer *Peer) NewKeyPair() *KeyPair { // zero handshake - handshake.chainKey = [blake2s.Size]byte{} - handshake.localEphemeral = NoisePrivateKey{} + setZero(handshake.chainKey[:]) + setZero(handshake.localEphemeral[:]) peer.handshake.state = HandshakeZeroed // create AEAD instances keyPair := new(KeyPair) + keyPair.send.setKey(&sendKey) + keyPair.receive.setKey(&recvKey) + + setZero(sendKey[:]) + setZero(recvKey[:]) + keyPair.created = time.Now() - keyPair.send, _ = chacha20poly1305.New(sendKey[:]) - keyPair.receive, _ = chacha20poly1305.New(recvKey[:]) keyPair.sendNonce = 0 keyPair.replayFilter.Init() keyPair.isInitiator = isInitiator @@ -462,12 +517,14 @@ func (peer *Peer) NewKeyPair() *KeyPair { // remap index - indices := &peer.device.indices - indices.Insert(handshake.localIndex, IndexTableEntry{ - peer: peer, - keyPair: keyPair, - handshake: nil, - }) + device.indices.Insert( + handshake.localIndex, + IndexTableEntry{ + peer: peer, + keyPair: keyPair, + handshake: nil, + }, + ) handshake.localIndex = 0 // rotate key pairs @@ -479,7 +536,8 @@ func (peer *Peer) NewKeyPair() *KeyPair { // TODO: Adapt kernel behavior noise.c:161 if isInitiator { if kp.previous != nil { - indices.Delete(kp.previous.localIndex) + device.DeleteKeyPair(kp.previous) + kp.previous = nil } if kp.next != nil { diff --git a/src/receive.go b/src/receive.go index ca7bb6e..97646d8 100644 --- a/src/receive.go +++ b/src/receive.go @@ -251,15 +251,22 @@ func (device *Device) RoutineDecryption() { var err error copy(nonce[4:], counter) elem.counter = binary.LittleEndian.Uint64(counter) - elem.packet, err = elem.keyPair.receive.Open( - elem.buffer[:0], - nonce[:], - content, - nil, - ) - if err != nil { + elem.keyPair.receive.mutex.RLock() + if elem.keyPair.receive.aead == nil { + // very unlikely (the key was deleted during queuing) elem.Drop() + } else { + elem.packet, err = elem.keyPair.receive.aead.Open( + elem.buffer[:0], + nonce[:], + content, + nil, + ) + if err != nil { + elem.Drop() + } } + elem.keyPair.receive.mutex.RUnlock() elem.mutex.Unlock() } } @@ -507,6 +514,9 @@ func (peer *Peer) RoutineSequentialReceiver() { kp.mutex.Lock() if kp.next == elem.keyPair { peer.TimerHandshakeComplete() + if kp.previous != nil { + device.DeleteKeyPair(kp.previous) + } kp.previous = kp.current kp.current = kp.next kp.next = nil diff --git a/src/send.go b/src/send.go index 7d4014a..c598ad4 100644 --- a/src/send.go +++ b/src/send.go @@ -349,12 +349,19 @@ func (device *Device) RoutineEncryption() { // encrypt content (append to header) binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) - elem.packet = elem.keyPair.send.Seal( - header, - nonce[:], - elem.packet, - nil, - ) + elem.keyPair.send.mutex.RLock() + if elem.keyPair.send.aead == nil { + // very unlikely (the key was deleted during queuing) + elem.Drop() + } else { + elem.packet = elem.keyPair.send.aead.Seal( + header, + nonce[:], + elem.packet, + nil, + ) + } + elem.keyPair.send.mutex.RUnlock() elem.mutex.Unlock() // refresh key if necessary diff --git a/src/timers.go b/src/timers.go index de54a96..ad8866f 100644 --- a/src/timers.go +++ b/src/timers.go @@ -3,7 +3,6 @@ package main import ( "bytes" "encoding/binary" - "golang.org/x/crypto/blake2s" "math/rand" "sync/atomic" "time" @@ -134,7 +133,6 @@ func (peer *Peer) TimerEphemeralKeyCreated() { func (peer *Peer) RoutineTimerHandler() { device := peer.device - indices := &device.indices logDebug := device.log.Debug logDebug.Println("Routine, timer handler, started for peer", peer.String()) @@ -186,35 +184,31 @@ func (peer *Peer) RoutineTimerHandler() { kp := &peer.keyPairs kp.mutex.Lock() - // unmap indecies + // remove key-pairs - indices.mutex.Lock() if kp.previous != nil { - delete(indices.table, kp.previous.localIndex) + device.DeleteKeyPair(kp.previous) + kp.previous = nil } if kp.current != nil { - delete(indices.table, kp.current.localIndex) + device.DeleteKeyPair(kp.current) + kp.current = nil } if kp.next != nil { - delete(indices.table, kp.next.localIndex) + device.DeleteKeyPair(kp.next) + kp.next = nil } - delete(indices.table, hs.localIndex) - indices.mutex.Unlock() - - // zero out key pairs (TODO: better than wait for GC) - - kp.current = nil - kp.previous = nil - kp.next = nil kp.mutex.Unlock() // zero out handshake + device.indices.Delete(hs.localIndex) + hs.localIndex = 0 - hs.localEphemeral = NoisePrivateKey{} - hs.remoteEphemeral = NoisePublicKey{} - hs.chainKey = [blake2s.Size]byte{} - hs.hash = [blake2s.Size]byte{} + setZero(hs.localEphemeral[:]) + setZero(hs.remoteEphemeral[:]) + setZero(hs.chainKey[:]) + setZero(hs.hash[:]) hs.mutex.Unlock() } } diff --git a/src/tun_linux.go b/src/tun_linux.go index b9541c9..58a762a 100644 --- a/src/tun_linux.go +++ b/src/tun_linux.go @@ -63,6 +63,8 @@ func (tun *NativeTun) RoutineNetlinkListener() { return } + tun.events <- TUNEventUp // TODO: Fix network namespace problem + for msg := make([]byte, 1<<16); ; { msgn, _, _, _, err := unix.Recvmsg(sock, msg[:], nil, 0)