Added replay protection
This commit is contained in:
parent
4ad62aaa6a
commit
44c9896883
|
@ -8,6 +8,7 @@ import (
|
||||||
|
|
||||||
type KeyPair struct {
|
type KeyPair struct {
|
||||||
receive cipher.AEAD
|
receive cipher.AEAD
|
||||||
|
replayFilter ReplayFilter
|
||||||
send cipher.AEAD
|
send cipher.AEAD
|
||||||
sendNonce uint64
|
sendNonce uint64
|
||||||
isInitiator bool
|
isInitiator bool
|
||||||
|
|
|
@ -19,6 +19,13 @@ func min(a uint, b uint) uint {
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func minUint64(a uint64, b uint64) uint64 {
|
||||||
|
if a > b {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
func signalSend(c chan struct{}) {
|
func signalSend(c chan struct{}) {
|
||||||
select {
|
select {
|
||||||
case c <- struct{}{}:
|
case c <- struct{}{}:
|
||||||
|
|
|
@ -415,6 +415,9 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||||
return lookup.peer
|
return lookup.peer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Derives a new key-pair from the current handshake state
|
||||||
|
*
|
||||||
|
*/
|
||||||
func (peer *Peer) NewKeyPair() *KeyPair {
|
func (peer *Peer) NewKeyPair() *KeyPair {
|
||||||
handshake := &peer.handshake
|
handshake := &peer.handshake
|
||||||
handshake.mutex.Lock()
|
handshake.mutex.Lock()
|
||||||
|
@ -445,10 +448,11 @@ func (peer *Peer) NewKeyPair() *KeyPair {
|
||||||
// create AEAD instances
|
// create AEAD instances
|
||||||
|
|
||||||
keyPair := new(KeyPair)
|
keyPair := new(KeyPair)
|
||||||
|
keyPair.created = time.Now()
|
||||||
keyPair.send, _ = chacha20poly1305.New(sendKey[:])
|
keyPair.send, _ = chacha20poly1305.New(sendKey[:])
|
||||||
keyPair.receive, _ = chacha20poly1305.New(recvKey[:])
|
keyPair.receive, _ = chacha20poly1305.New(recvKey[:])
|
||||||
keyPair.sendNonce = 0
|
keyPair.sendNonce = 0
|
||||||
keyPair.created = time.Now()
|
keyPair.replayFilter.Init()
|
||||||
keyPair.isInitiator = isInitiator
|
keyPair.isInitiator = isInitiator
|
||||||
keyPair.localIndex = peer.handshake.localIndex
|
keyPair.localIndex = peer.handshake.localIndex
|
||||||
keyPair.remoteIndex = peer.handshake.remoteIndex
|
keyPair.remoteIndex = peer.handshake.remoteIndex
|
||||||
|
@ -462,8 +466,6 @@ func (peer *Peer) NewKeyPair() *KeyPair {
|
||||||
})
|
})
|
||||||
handshake.localIndex = 0
|
handshake.localIndex = 0
|
||||||
|
|
||||||
// TODO: start timer for keypair (clearing)
|
|
||||||
|
|
||||||
// rotate key pairs
|
// rotate key pairs
|
||||||
|
|
||||||
kp := &peer.keyPairs
|
kp := &peer.keyPairs
|
||||||
|
|
|
@ -432,6 +432,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||||
|
|
||||||
// check for replay
|
// check for replay
|
||||||
|
|
||||||
|
if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// time (passive) keep-alive
|
// time (passive) keep-alive
|
||||||
|
|
||||||
peer.TimerStartKeepalive()
|
peer.TimerStartKeepalive()
|
||||||
|
|
71
src/replay.go
Normal file
71
src/replay.go
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
/* Implementation of RFC6479
|
||||||
|
* https://tools.ietf.org/html/rfc6479
|
||||||
|
*
|
||||||
|
* The implementation is not safe for concurrent use!
|
||||||
|
*/
|
||||||
|
|
||||||
|
const (
|
||||||
|
// See: https://golang.org/src/math/big/arith.go
|
||||||
|
_Wordm = ^uintptr(0)
|
||||||
|
_WordLogSize = _Wordm>>8&1 + _Wordm>>16&1 + _Wordm>>32&1
|
||||||
|
_WordSize = 1 << _WordLogSize
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
CounterRedundantBitsLog = _WordLogSize + 3
|
||||||
|
CounterRedundantBits = _WordSize * 8
|
||||||
|
CounterBitsTotal = 2048
|
||||||
|
CounterWindowSize = uint64(CounterBitsTotal - CounterRedundantBits)
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
BacktrackWords = CounterBitsTotal / _WordSize
|
||||||
|
)
|
||||||
|
|
||||||
|
type ReplayFilter struct {
|
||||||
|
counter uint64
|
||||||
|
backtrack [BacktrackWords]uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (filter *ReplayFilter) Init() {
|
||||||
|
filter.counter = 0
|
||||||
|
filter.backtrack[0] = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (filter *ReplayFilter) ValidateCounter(counter uint64) bool {
|
||||||
|
if counter >= RejectAfterMessages {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
indexWord := counter >> CounterRedundantBitsLog
|
||||||
|
|
||||||
|
if counter > filter.counter {
|
||||||
|
|
||||||
|
// move window forward
|
||||||
|
|
||||||
|
current := filter.counter >> CounterRedundantBitsLog
|
||||||
|
diff := minUint64(indexWord-current, BacktrackWords)
|
||||||
|
for i := uint64(1); i <= diff; i++ {
|
||||||
|
filter.backtrack[(current+i)%BacktrackWords] = 0
|
||||||
|
}
|
||||||
|
filter.counter = counter
|
||||||
|
|
||||||
|
} else if filter.counter-counter > CounterWindowSize {
|
||||||
|
|
||||||
|
// behind current window
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
indexWord %= BacktrackWords
|
||||||
|
indexBit := counter & uint64(CounterRedundantBits-1)
|
||||||
|
|
||||||
|
// check and set bit
|
||||||
|
|
||||||
|
oldValue := filter.backtrack[indexWord]
|
||||||
|
newValue := oldValue | (1 << indexBit)
|
||||||
|
filter.backtrack[indexWord] = newValue
|
||||||
|
return oldValue != newValue
|
||||||
|
}
|
114
src/replay_test.go
Normal file
114
src/replay_test.go
Normal file
|
@ -0,0 +1,114 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* Ported from the linux kernel implementation
|
||||||
|
*
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
|
||||||
|
|
||||||
|
func TestReplay(t *testing.T) {
|
||||||
|
var filter ReplayFilter
|
||||||
|
|
||||||
|
T_LIM := CounterWindowSize + 1
|
||||||
|
|
||||||
|
testNumber := 0
|
||||||
|
T := func(n uint64, v bool) {
|
||||||
|
testNumber++
|
||||||
|
if filter.ValidateCounter(n) != v {
|
||||||
|
t.Fatal("Test", testNumber, "failed", n, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
filter.Init()
|
||||||
|
|
||||||
|
/* 1 */ T(0, true)
|
||||||
|
/* 2 */ T(1, true)
|
||||||
|
/* 3 */ T(1, false)
|
||||||
|
/* 4 */ T(9, true)
|
||||||
|
/* 5 */ T(8, true)
|
||||||
|
/* 6 */ T(7, true)
|
||||||
|
/* 7 */ T(7, false)
|
||||||
|
/* 8 */ T(T_LIM, true)
|
||||||
|
/* 9 */ T(T_LIM-1, true)
|
||||||
|
/* 10 */ T(T_LIM-1, false)
|
||||||
|
/* 11 */ T(T_LIM-2, true)
|
||||||
|
/* 12 */ T(2, true)
|
||||||
|
/* 13 */ T(2, false)
|
||||||
|
/* 14 */ T(T_LIM+16, true)
|
||||||
|
/* 15 */ T(3, false)
|
||||||
|
/* 16 */ T(T_LIM+16, false)
|
||||||
|
/* 17 */ T(T_LIM*4, true)
|
||||||
|
/* 18 */ T(T_LIM*4-(T_LIM-1), true)
|
||||||
|
/* 19 */ T(10, false)
|
||||||
|
/* 20 */ T(T_LIM*4-T_LIM, false)
|
||||||
|
/* 21 */ T(T_LIM*4-(T_LIM+1), false)
|
||||||
|
/* 22 */ T(T_LIM*4-(T_LIM-2), true)
|
||||||
|
/* 23 */ T(T_LIM*4+1-T_LIM, false)
|
||||||
|
/* 24 */ T(0, false)
|
||||||
|
/* 25 */ T(RejectAfterMessages, false)
|
||||||
|
/* 26 */ T(RejectAfterMessages-1, true)
|
||||||
|
/* 27 */ T(RejectAfterMessages, false)
|
||||||
|
/* 28 */ T(RejectAfterMessages-1, false)
|
||||||
|
/* 29 */ T(RejectAfterMessages-2, true)
|
||||||
|
/* 30 */ T(RejectAfterMessages+1, false)
|
||||||
|
/* 31 */ T(RejectAfterMessages+2, false)
|
||||||
|
/* 32 */ T(RejectAfterMessages-2, false)
|
||||||
|
/* 33 */ T(RejectAfterMessages-3, true)
|
||||||
|
/* 34 */ T(0, false)
|
||||||
|
|
||||||
|
t.Log("Bulk test 1")
|
||||||
|
filter.Init()
|
||||||
|
testNumber = 0
|
||||||
|
for i := uint64(1); i <= CounterWindowSize; i++ {
|
||||||
|
T(i, true)
|
||||||
|
}
|
||||||
|
T(0, true)
|
||||||
|
T(0, false)
|
||||||
|
|
||||||
|
t.Log("Bulk test 2")
|
||||||
|
filter.Init()
|
||||||
|
testNumber = 0
|
||||||
|
for i := uint64(2); i <= CounterWindowSize+1; i++ {
|
||||||
|
T(i, true)
|
||||||
|
}
|
||||||
|
T(1, true)
|
||||||
|
T(0, false)
|
||||||
|
|
||||||
|
t.Log("Bulk test 3")
|
||||||
|
filter.Init()
|
||||||
|
testNumber = 0
|
||||||
|
for i := CounterWindowSize + 1; i > 0; i-- {
|
||||||
|
T(i, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("Bulk test 4")
|
||||||
|
filter.Init()
|
||||||
|
testNumber = 0
|
||||||
|
for i := CounterWindowSize + 2; i > 1; i-- {
|
||||||
|
T(i, true)
|
||||||
|
}
|
||||||
|
T(0, false)
|
||||||
|
|
||||||
|
t.Log("Bulk test 5")
|
||||||
|
filter.Init()
|
||||||
|
testNumber = 0
|
||||||
|
for i := CounterWindowSize; i > 0; i-- {
|
||||||
|
T(i, true)
|
||||||
|
}
|
||||||
|
T(CounterWindowSize+1, true)
|
||||||
|
T(0, false)
|
||||||
|
|
||||||
|
t.Log("Bulk test 6")
|
||||||
|
filter.Init()
|
||||||
|
testNumber = 0
|
||||||
|
for i := CounterWindowSize; i > 0; i-- {
|
||||||
|
T(i, true)
|
||||||
|
}
|
||||||
|
T(0, true)
|
||||||
|
T(CounterWindowSize+1, true)
|
||||||
|
}
|
|
@ -12,22 +12,15 @@ import (
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
func (peer *Peer) KeepKeyFreshSending() {
|
func (peer *Peer) KeepKeyFreshSending() {
|
||||||
send := func() bool {
|
kp := peer.keyPairs.Current()
|
||||||
peer.keyPairs.mutex.RLock()
|
|
||||||
defer peer.keyPairs.mutex.RUnlock()
|
|
||||||
|
|
||||||
kp := peer.keyPairs.current
|
|
||||||
if kp == nil {
|
if kp == nil {
|
||||||
return false
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !kp.isInitiator {
|
if !kp.isInitiator {
|
||||||
return false
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nonce := atomic.LoadUint64(&kp.sendNonce)
|
nonce := atomic.LoadUint64(&kp.sendNonce)
|
||||||
return nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTime
|
send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTime
|
||||||
}()
|
|
||||||
if send {
|
if send {
|
||||||
signalSend(peer.signal.handshakeBegin)
|
signalSend(peer.signal.handshakeBegin)
|
||||||
}
|
}
|
||||||
|
@ -37,22 +30,15 @@ func (peer *Peer) KeepKeyFreshSending() {
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
func (peer *Peer) KeepKeyFreshReceiving() {
|
func (peer *Peer) KeepKeyFreshReceiving() {
|
||||||
send := func() bool {
|
kp := peer.keyPairs.Current()
|
||||||
peer.keyPairs.mutex.RLock()
|
|
||||||
defer peer.keyPairs.mutex.RUnlock()
|
|
||||||
|
|
||||||
kp := peer.keyPairs.current
|
|
||||||
if kp == nil {
|
if kp == nil {
|
||||||
return false
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !kp.isInitiator {
|
if !kp.isInitiator {
|
||||||
return false
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nonce := atomic.LoadUint64(&kp.sendNonce)
|
nonce := atomic.LoadUint64(&kp.sendNonce)
|
||||||
return nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
|
send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
|
||||||
}()
|
|
||||||
if send {
|
if send {
|
||||||
signalSend(peer.signal.handshakeBegin)
|
signalSend(peer.signal.handshakeBegin)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue