replay: clean up internals and better documentation

Signed-off-by: Riobard Zhan <me@riobard.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Riobard Zhan 2020-09-10 01:55:24 +08:00 committed by Jason A. Donenfeld
parent c8fe925020
commit 22af3890f6
2 changed files with 50 additions and 71 deletions

View file

@ -3,81 +3,60 @@
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
package replay
/* Implementation of RFC6479
* https://tools.ietf.org/html/rfc6479
*
* The implementation is not safe for concurrent use!
*/
type block uint64
const (
// See: https://golang.org/src/math/big/arith.go
_Wordm = ^uintptr(0)
_WordLogSize = _Wordm>>8&1 + _Wordm>>16&1 + _Wordm>>32&1
_WordSize = 1 << _WordLogSize
blockBitLog = 6 // 1<<6 == 64 bits
blockBits = 1 << blockBitLog // must be power of 2
ringBlocks = 1 << 7 // must be power of 2
windowSize = (ringBlocks - 1) * blockBits
blockMask = ringBlocks - 1
bitMask = blockBits - 1
)
const (
CounterRedundantBitsLog = _WordLogSize + 3
CounterRedundantBits = _WordSize * 8
CounterBitsTotal = 8192
CounterWindowSize = uint64(CounterBitsTotal - CounterRedundantBits)
)
const (
BacktrackWords = CounterBitsTotal / 8 / _WordSize
)
func minUint64(a uint64, b uint64) uint64 {
if a > b {
return b
}
return a
}
// A ReplayFilter rejects replayed messages by checking if message counter value is
// within a sliding window of previously received messages.
// The zero value for ReplayFilter is an empty filter ready to use.
// Filters are unsafe for concurrent use.
type ReplayFilter struct {
counter uint64
backtrack [BacktrackWords]uintptr
last uint64
ring [ringBlocks]block
}
func (filter *ReplayFilter) Init() {
filter.counter = 0
filter.backtrack[0] = 0
// Init resets the filter to empty state.
func (f *ReplayFilter) Init() {
f.last = 0
f.ring[0] = 0
}
func (filter *ReplayFilter) ValidateCounter(counter uint64, limit uint64) bool {
// ValidateCounter checks if the counter should be accepted.
// Overlimit counters (>= limit) are always rejected.
func (f *ReplayFilter) ValidateCounter(counter uint64, limit uint64) bool {
if counter >= limit {
return false
}
indexWord := counter >> CounterRedundantBitsLog
if counter > filter.counter {
// move window forward
current := filter.counter >> CounterRedundantBitsLog
diff := minUint64(indexWord-current, BacktrackWords)
for i := uint64(1); i <= diff; i++ {
filter.backtrack[(current+i)%BacktrackWords] = 0
indexBlock := counter >> blockBitLog
if counter > f.last { // move window forward
current := f.last >> blockBitLog
diff := indexBlock - current
if diff > ringBlocks {
diff = ringBlocks // cap diff to clear the whole ring
}
filter.counter = counter
} else if filter.counter-counter > CounterWindowSize {
// behind current window
for i := current + 1; i <= current+diff; i++ {
f.ring[i&blockMask] = 0
}
f.last = counter
} else if f.last-counter > windowSize { // behind current window
return false
}
indexWord %= BacktrackWords
indexBit := counter & uint64(CounterRedundantBits-1)
// check and set bit
oldValue := filter.backtrack[indexWord]
newValue := oldValue | (1 << indexBit)
filter.backtrack[indexWord] = newValue
return oldValue != newValue
indexBlock &= blockMask
indexBit := counter & bitMask
old := f.ring[indexBlock]
new := old | 1<<indexBit
f.ring[indexBlock] = new
return old != new
}

View file

@ -19,13 +19,13 @@ const RejectAfterMessages = (1 << 64) - (1 << 4) - 1
func TestReplay(t *testing.T) {
var filter ReplayFilter
T_LIM := CounterWindowSize + 1
const T_LIM = windowSize + 1
testNumber := 0
T := func(n uint64, v bool) {
T := func(n uint64, expected bool) {
testNumber++
if filter.ValidateCounter(n, RejectAfterMessages) != v {
t.Fatal("Test", testNumber, "failed", n, v)
if filter.ValidateCounter(n, RejectAfterMessages) != expected {
t.Fatal("Test", testNumber, "failed", n, expected)
}
}
@ -69,7 +69,7 @@ func TestReplay(t *testing.T) {
t.Log("Bulk test 1")
filter.Init()
testNumber = 0
for i := uint64(1); i <= CounterWindowSize; i++ {
for i := uint64(1); i <= windowSize; i++ {
T(i, true)
}
T(0, true)
@ -78,7 +78,7 @@ func TestReplay(t *testing.T) {
t.Log("Bulk test 2")
filter.Init()
testNumber = 0
for i := uint64(2); i <= CounterWindowSize+1; i++ {
for i := uint64(2); i <= windowSize+1; i++ {
T(i, true)
}
T(1, true)
@ -87,14 +87,14 @@ func TestReplay(t *testing.T) {
t.Log("Bulk test 3")
filter.Init()
testNumber = 0
for i := CounterWindowSize + 1; i > 0; i-- {
for i := uint64(windowSize + 1); i > 0; i-- {
T(i, true)
}
t.Log("Bulk test 4")
filter.Init()
testNumber = 0
for i := CounterWindowSize + 2; i > 1; i-- {
for i := uint64(windowSize + 2); i > 1; i-- {
T(i, true)
}
T(0, false)
@ -102,18 +102,18 @@ func TestReplay(t *testing.T) {
t.Log("Bulk test 5")
filter.Init()
testNumber = 0
for i := CounterWindowSize; i > 0; i-- {
for i := uint64(windowSize); i > 0; i-- {
T(i, true)
}
T(CounterWindowSize+1, true)
T(windowSize+1, true)
T(0, false)
t.Log("Bulk test 6")
filter.Init()
testNumber = 0
for i := CounterWindowSize; i > 0; i-- {
for i := uint64(windowSize); i > 0; i-- {
T(i, true)
}
T(0, true)
T(CounterWindowSize+1, true)
T(windowSize+1, true)
}