replay: clean up internals and better documentation
Signed-off-by: Riobard Zhan <me@riobard.com> Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
parent
c8fe925020
commit
22af3890f6
|
@ -3,81 +3,60 @@
|
||||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
|
||||||
package replay
|
package replay
|
||||||
|
|
||||||
/* Implementation of RFC6479
|
type block uint64
|
||||||
* https://tools.ietf.org/html/rfc6479
|
|
||||||
*
|
|
||||||
* The implementation is not safe for concurrent use!
|
|
||||||
*/
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// See: https://golang.org/src/math/big/arith.go
|
blockBitLog = 6 // 1<<6 == 64 bits
|
||||||
_Wordm = ^uintptr(0)
|
blockBits = 1 << blockBitLog // must be power of 2
|
||||||
_WordLogSize = _Wordm>>8&1 + _Wordm>>16&1 + _Wordm>>32&1
|
ringBlocks = 1 << 7 // must be power of 2
|
||||||
_WordSize = 1 << _WordLogSize
|
windowSize = (ringBlocks - 1) * blockBits
|
||||||
|
blockMask = ringBlocks - 1
|
||||||
|
bitMask = blockBits - 1
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
// A ReplayFilter rejects replayed messages by checking if message counter value is
|
||||||
CounterRedundantBitsLog = _WordLogSize + 3
|
// within a sliding window of previously received messages.
|
||||||
CounterRedundantBits = _WordSize * 8
|
// The zero value for ReplayFilter is an empty filter ready to use.
|
||||||
CounterBitsTotal = 8192
|
// Filters are unsafe for concurrent use.
|
||||||
CounterWindowSize = uint64(CounterBitsTotal - CounterRedundantBits)
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
BacktrackWords = CounterBitsTotal / 8 / _WordSize
|
|
||||||
)
|
|
||||||
|
|
||||||
func minUint64(a uint64, b uint64) uint64 {
|
|
||||||
if a > b {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
|
|
||||||
type ReplayFilter struct {
|
type ReplayFilter struct {
|
||||||
counter uint64
|
last uint64
|
||||||
backtrack [BacktrackWords]uintptr
|
ring [ringBlocks]block
|
||||||
}
|
}
|
||||||
|
|
||||||
func (filter *ReplayFilter) Init() {
|
// Init resets the filter to empty state.
|
||||||
filter.counter = 0
|
func (f *ReplayFilter) Init() {
|
||||||
filter.backtrack[0] = 0
|
f.last = 0
|
||||||
|
f.ring[0] = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (filter *ReplayFilter) ValidateCounter(counter uint64, limit uint64) bool {
|
// ValidateCounter checks if the counter should be accepted.
|
||||||
|
// Overlimit counters (>= limit) are always rejected.
|
||||||
|
func (f *ReplayFilter) ValidateCounter(counter uint64, limit uint64) bool {
|
||||||
if counter >= limit {
|
if counter >= limit {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
indexBlock := counter >> blockBitLog
|
||||||
indexWord := counter >> CounterRedundantBitsLog
|
if counter > f.last { // move window forward
|
||||||
|
current := f.last >> blockBitLog
|
||||||
if counter > filter.counter {
|
diff := indexBlock - current
|
||||||
|
if diff > ringBlocks {
|
||||||
// move window forward
|
diff = ringBlocks // cap diff to clear the whole ring
|
||||||
|
|
||||||
current := filter.counter >> CounterRedundantBitsLog
|
|
||||||
diff := minUint64(indexWord-current, BacktrackWords)
|
|
||||||
for i := uint64(1); i <= diff; i++ {
|
|
||||||
filter.backtrack[(current+i)%BacktrackWords] = 0
|
|
||||||
}
|
}
|
||||||
filter.counter = counter
|
for i := current + 1; i <= current+diff; i++ {
|
||||||
|
f.ring[i&blockMask] = 0
|
||||||
} else if filter.counter-counter > CounterWindowSize {
|
}
|
||||||
|
f.last = counter
|
||||||
// behind current window
|
} else if f.last-counter > windowSize { // behind current window
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
indexWord %= BacktrackWords
|
|
||||||
indexBit := counter & uint64(CounterRedundantBits-1)
|
|
||||||
|
|
||||||
// check and set bit
|
// check and set bit
|
||||||
|
indexBlock &= blockMask
|
||||||
oldValue := filter.backtrack[indexWord]
|
indexBit := counter & bitMask
|
||||||
newValue := oldValue | (1 << indexBit)
|
old := f.ring[indexBlock]
|
||||||
filter.backtrack[indexWord] = newValue
|
new := old | 1<<indexBit
|
||||||
return oldValue != newValue
|
f.ring[indexBlock] = new
|
||||||
|
return old != new
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,13 +19,13 @@ const RejectAfterMessages = (1 << 64) - (1 << 4) - 1
|
||||||
func TestReplay(t *testing.T) {
|
func TestReplay(t *testing.T) {
|
||||||
var filter ReplayFilter
|
var filter ReplayFilter
|
||||||
|
|
||||||
T_LIM := CounterWindowSize + 1
|
const T_LIM = windowSize + 1
|
||||||
|
|
||||||
testNumber := 0
|
testNumber := 0
|
||||||
T := func(n uint64, v bool) {
|
T := func(n uint64, expected bool) {
|
||||||
testNumber++
|
testNumber++
|
||||||
if filter.ValidateCounter(n, RejectAfterMessages) != v {
|
if filter.ValidateCounter(n, RejectAfterMessages) != expected {
|
||||||
t.Fatal("Test", testNumber, "failed", n, v)
|
t.Fatal("Test", testNumber, "failed", n, expected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,7 +69,7 @@ func TestReplay(t *testing.T) {
|
||||||
t.Log("Bulk test 1")
|
t.Log("Bulk test 1")
|
||||||
filter.Init()
|
filter.Init()
|
||||||
testNumber = 0
|
testNumber = 0
|
||||||
for i := uint64(1); i <= CounterWindowSize; i++ {
|
for i := uint64(1); i <= windowSize; i++ {
|
||||||
T(i, true)
|
T(i, true)
|
||||||
}
|
}
|
||||||
T(0, true)
|
T(0, true)
|
||||||
|
@ -78,7 +78,7 @@ func TestReplay(t *testing.T) {
|
||||||
t.Log("Bulk test 2")
|
t.Log("Bulk test 2")
|
||||||
filter.Init()
|
filter.Init()
|
||||||
testNumber = 0
|
testNumber = 0
|
||||||
for i := uint64(2); i <= CounterWindowSize+1; i++ {
|
for i := uint64(2); i <= windowSize+1; i++ {
|
||||||
T(i, true)
|
T(i, true)
|
||||||
}
|
}
|
||||||
T(1, true)
|
T(1, true)
|
||||||
|
@ -87,14 +87,14 @@ func TestReplay(t *testing.T) {
|
||||||
t.Log("Bulk test 3")
|
t.Log("Bulk test 3")
|
||||||
filter.Init()
|
filter.Init()
|
||||||
testNumber = 0
|
testNumber = 0
|
||||||
for i := CounterWindowSize + 1; i > 0; i-- {
|
for i := uint64(windowSize + 1); i > 0; i-- {
|
||||||
T(i, true)
|
T(i, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Log("Bulk test 4")
|
t.Log("Bulk test 4")
|
||||||
filter.Init()
|
filter.Init()
|
||||||
testNumber = 0
|
testNumber = 0
|
||||||
for i := CounterWindowSize + 2; i > 1; i-- {
|
for i := uint64(windowSize + 2); i > 1; i-- {
|
||||||
T(i, true)
|
T(i, true)
|
||||||
}
|
}
|
||||||
T(0, false)
|
T(0, false)
|
||||||
|
@ -102,18 +102,18 @@ func TestReplay(t *testing.T) {
|
||||||
t.Log("Bulk test 5")
|
t.Log("Bulk test 5")
|
||||||
filter.Init()
|
filter.Init()
|
||||||
testNumber = 0
|
testNumber = 0
|
||||||
for i := CounterWindowSize; i > 0; i-- {
|
for i := uint64(windowSize); i > 0; i-- {
|
||||||
T(i, true)
|
T(i, true)
|
||||||
}
|
}
|
||||||
T(CounterWindowSize+1, true)
|
T(windowSize+1, true)
|
||||||
T(0, false)
|
T(0, false)
|
||||||
|
|
||||||
t.Log("Bulk test 6")
|
t.Log("Bulk test 6")
|
||||||
filter.Init()
|
filter.Init()
|
||||||
testNumber = 0
|
testNumber = 0
|
||||||
for i := CounterWindowSize; i > 0; i-- {
|
for i := uint64(windowSize); i > 0; i-- {
|
||||||
T(i, true)
|
T(i, true)
|
||||||
}
|
}
|
||||||
T(0, true)
|
T(0, true)
|
||||||
T(CounterWindowSize+1, true)
|
T(windowSize+1, true)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue