Added replay protection

This commit is contained in:
Mathias Hall-Andersen 2017-07-10 12:09:19 +02:00
parent 4ad62aaa6a
commit 44c9896883
7 changed files with 227 additions and 42 deletions

View file

@ -8,6 +8,7 @@ import (
type KeyPair struct { type KeyPair struct {
receive cipher.AEAD receive cipher.AEAD
replayFilter ReplayFilter
send cipher.AEAD send cipher.AEAD
sendNonce uint64 sendNonce uint64
isInitiator bool isInitiator bool

View file

@ -19,6 +19,13 @@ func min(a uint, b uint) uint {
return a return a
} }
func minUint64(a uint64, b uint64) uint64 {
if a > b {
return b
}
return a
}
func signalSend(c chan struct{}) { func signalSend(c chan struct{}) {
select { select {
case c <- struct{}{}: case c <- struct{}{}:

View file

@ -415,6 +415,9 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
return lookup.peer return lookup.peer
} }
/* Derives a new key-pair from the current handshake state
*
*/
func (peer *Peer) NewKeyPair() *KeyPair { func (peer *Peer) NewKeyPair() *KeyPair {
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
@ -445,10 +448,11 @@ func (peer *Peer) NewKeyPair() *KeyPair {
// create AEAD instances // create AEAD instances
keyPair := new(KeyPair) keyPair := new(KeyPair)
keyPair.created = time.Now()
keyPair.send, _ = chacha20poly1305.New(sendKey[:]) keyPair.send, _ = chacha20poly1305.New(sendKey[:])
keyPair.receive, _ = chacha20poly1305.New(recvKey[:]) keyPair.receive, _ = chacha20poly1305.New(recvKey[:])
keyPair.sendNonce = 0 keyPair.sendNonce = 0
keyPair.created = time.Now() keyPair.replayFilter.Init()
keyPair.isInitiator = isInitiator keyPair.isInitiator = isInitiator
keyPair.localIndex = peer.handshake.localIndex keyPair.localIndex = peer.handshake.localIndex
keyPair.remoteIndex = peer.handshake.remoteIndex keyPair.remoteIndex = peer.handshake.remoteIndex
@ -462,8 +466,6 @@ func (peer *Peer) NewKeyPair() *KeyPair {
}) })
handshake.localIndex = 0 handshake.localIndex = 0
// TODO: start timer for keypair (clearing)
// rotate key pairs // rotate key pairs
kp := &peer.keyPairs kp := &peer.keyPairs

View file

@ -432,6 +432,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
// check for replay // check for replay
if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) {
return
}
// time (passive) keep-alive // time (passive) keep-alive
peer.TimerStartKeepalive() peer.TimerStartKeepalive()

71
src/replay.go Normal file
View file

@ -0,0 +1,71 @@
package main
/* Implementation of RFC6479
* https://tools.ietf.org/html/rfc6479
*
* The implementation is not safe for concurrent use!
*/
const (
// See: https://golang.org/src/math/big/arith.go
_Wordm = ^uintptr(0)
_WordLogSize = _Wordm>>8&1 + _Wordm>>16&1 + _Wordm>>32&1
_WordSize = 1 << _WordLogSize
)
const (
CounterRedundantBitsLog = _WordLogSize + 3
CounterRedundantBits = _WordSize * 8
CounterBitsTotal = 2048
CounterWindowSize = uint64(CounterBitsTotal - CounterRedundantBits)
)
const (
BacktrackWords = CounterBitsTotal / _WordSize
)
type ReplayFilter struct {
counter uint64
backtrack [BacktrackWords]uintptr
}
func (filter *ReplayFilter) Init() {
filter.counter = 0
filter.backtrack[0] = 0
}
func (filter *ReplayFilter) ValidateCounter(counter uint64) bool {
if counter >= RejectAfterMessages {
return false
}
indexWord := counter >> CounterRedundantBitsLog
if counter > filter.counter {
// move window forward
current := filter.counter >> CounterRedundantBitsLog
diff := minUint64(indexWord-current, BacktrackWords)
for i := uint64(1); i <= diff; i++ {
filter.backtrack[(current+i)%BacktrackWords] = 0
}
filter.counter = counter
} else if filter.counter-counter > CounterWindowSize {
// behind current window
return false
}
indexWord %= BacktrackWords
indexBit := counter & uint64(CounterRedundantBits-1)
// check and set bit
oldValue := filter.backtrack[indexWord]
newValue := oldValue | (1 << indexBit)
filter.backtrack[indexWord] = newValue
return oldValue != newValue
}

114
src/replay_test.go Normal file
View file

@ -0,0 +1,114 @@
package main
import (
"testing"
)
/* Ported from the linux kernel implementation
*
*
*/
/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
func TestReplay(t *testing.T) {
var filter ReplayFilter
T_LIM := CounterWindowSize + 1
testNumber := 0
T := func(n uint64, v bool) {
testNumber++
if filter.ValidateCounter(n) != v {
t.Fatal("Test", testNumber, "failed", n, v)
}
}
filter.Init()
/* 1 */ T(0, true)
/* 2 */ T(1, true)
/* 3 */ T(1, false)
/* 4 */ T(9, true)
/* 5 */ T(8, true)
/* 6 */ T(7, true)
/* 7 */ T(7, false)
/* 8 */ T(T_LIM, true)
/* 9 */ T(T_LIM-1, true)
/* 10 */ T(T_LIM-1, false)
/* 11 */ T(T_LIM-2, true)
/* 12 */ T(2, true)
/* 13 */ T(2, false)
/* 14 */ T(T_LIM+16, true)
/* 15 */ T(3, false)
/* 16 */ T(T_LIM+16, false)
/* 17 */ T(T_LIM*4, true)
/* 18 */ T(T_LIM*4-(T_LIM-1), true)
/* 19 */ T(10, false)
/* 20 */ T(T_LIM*4-T_LIM, false)
/* 21 */ T(T_LIM*4-(T_LIM+1), false)
/* 22 */ T(T_LIM*4-(T_LIM-2), true)
/* 23 */ T(T_LIM*4+1-T_LIM, false)
/* 24 */ T(0, false)
/* 25 */ T(RejectAfterMessages, false)
/* 26 */ T(RejectAfterMessages-1, true)
/* 27 */ T(RejectAfterMessages, false)
/* 28 */ T(RejectAfterMessages-1, false)
/* 29 */ T(RejectAfterMessages-2, true)
/* 30 */ T(RejectAfterMessages+1, false)
/* 31 */ T(RejectAfterMessages+2, false)
/* 32 */ T(RejectAfterMessages-2, false)
/* 33 */ T(RejectAfterMessages-3, true)
/* 34 */ T(0, false)
t.Log("Bulk test 1")
filter.Init()
testNumber = 0
for i := uint64(1); i <= CounterWindowSize; i++ {
T(i, true)
}
T(0, true)
T(0, false)
t.Log("Bulk test 2")
filter.Init()
testNumber = 0
for i := uint64(2); i <= CounterWindowSize+1; i++ {
T(i, true)
}
T(1, true)
T(0, false)
t.Log("Bulk test 3")
filter.Init()
testNumber = 0
for i := CounterWindowSize + 1; i > 0; i-- {
T(i, true)
}
t.Log("Bulk test 4")
filter.Init()
testNumber = 0
for i := CounterWindowSize + 2; i > 1; i-- {
T(i, true)
}
T(0, false)
t.Log("Bulk test 5")
filter.Init()
testNumber = 0
for i := CounterWindowSize; i > 0; i-- {
T(i, true)
}
T(CounterWindowSize+1, true)
T(0, false)
t.Log("Bulk test 6")
filter.Init()
testNumber = 0
for i := CounterWindowSize; i > 0; i-- {
T(i, true)
}
T(0, true)
T(CounterWindowSize+1, true)
}

View file

@ -12,22 +12,15 @@ import (
* *
*/ */
func (peer *Peer) KeepKeyFreshSending() { func (peer *Peer) KeepKeyFreshSending() {
send := func() bool { kp := peer.keyPairs.Current()
peer.keyPairs.mutex.RLock()
defer peer.keyPairs.mutex.RUnlock()
kp := peer.keyPairs.current
if kp == nil { if kp == nil {
return false return
} }
if !kp.isInitiator { if !kp.isInitiator {
return false return
} }
nonce := atomic.LoadUint64(&kp.sendNonce) nonce := atomic.LoadUint64(&kp.sendNonce)
return nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTime send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTime
}()
if send { if send {
signalSend(peer.signal.handshakeBegin) signalSend(peer.signal.handshakeBegin)
} }
@ -37,22 +30,15 @@ func (peer *Peer) KeepKeyFreshSending() {
* *
*/ */
func (peer *Peer) KeepKeyFreshReceiving() { func (peer *Peer) KeepKeyFreshReceiving() {
send := func() bool { kp := peer.keyPairs.Current()
peer.keyPairs.mutex.RLock()
defer peer.keyPairs.mutex.RUnlock()
kp := peer.keyPairs.current
if kp == nil { if kp == nil {
return false return
} }
if !kp.isInitiator { if !kp.isInitiator {
return false return
} }
nonce := atomic.LoadUint64(&kp.sendNonce) nonce := atomic.LoadUint64(&kp.sendNonce)
return nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
}()
if send { if send {
signalSend(peer.signal.handshakeBegin) signalSend(peer.signal.handshakeBegin)
} }