ratelimiter: use a fake clock in tests and style cleanups

The existing test would occasionally flake out with:

	--- FAIL: TestRatelimiter (0.12s)
	    ratelimiter_test.go:99: Test failed for 127.0.0.1 , on: 7 ( not having refilled enough ) expected: false got: true
	FAIL
	FAIL    golang.zx2c4.com/wireguard/ratelimiter  0.171s

The fake clock also means the tests run much faster, so
testing this package with -count=1000 now takes < 100ms.

While here, several style cleanups. The most significant one
is unembeding the sync.Mutex fields in the rate limiter objects.
Embedded as they were, the lock methods were accessible
outside the ratelimiter package. As they aren't needed externally,
keep them internal to make them easier to reason about.

Passes `go test -race -count=10000 ./ratelimiter`

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
This commit is contained in:
David Crawshaw 2019-12-08 18:22:31 -05:00 committed by David Crawshaw
parent ae88e2a2cd
commit 9cd8909df2
2 changed files with 88 additions and 65 deletions

View file

@ -20,21 +20,23 @@ const (
) )
type RatelimiterEntry struct { type RatelimiterEntry struct {
sync.Mutex mu sync.Mutex
lastTime time.Time lastTime time.Time
tokens int64 tokens int64
} }
type Ratelimiter struct { type Ratelimiter struct {
sync.RWMutex mu sync.RWMutex
stopReset chan struct{} timeNow func() time.Time
stopReset chan struct{} // send to reset, close to stop
tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
} }
func (rate *Ratelimiter) Close() { func (rate *Ratelimiter) Close() {
rate.Lock() rate.mu.Lock()
defer rate.Unlock() defer rate.mu.Unlock()
if rate.stopReset != nil { if rate.stopReset != nil {
close(rate.stopReset) close(rate.stopReset)
@ -42,11 +44,14 @@ func (rate *Ratelimiter) Close() {
} }
func (rate *Ratelimiter) Init() { func (rate *Ratelimiter) Init() {
rate.Lock() rate.mu.Lock()
defer rate.Unlock() defer rate.mu.Unlock()
if rate.timeNow == nil {
rate.timeNow = time.Now
}
// stop any ongoing garbage collection routine // stop any ongoing garbage collection routine
if rate.stopReset != nil { if rate.stopReset != nil {
close(rate.stopReset) close(rate.stopReset)
} }
@ -55,48 +60,50 @@ func (rate *Ratelimiter) Init() {
rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry) rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
// start garbage collection routine stopReset := rate.stopReset // store in case Init is called again.
// Start garbage collection routine.
go func() { go func() {
ticker := time.NewTicker(time.Second) ticker := time.NewTicker(time.Second)
ticker.Stop() ticker.Stop()
for { for {
select { select {
case _, ok := <-rate.stopReset: case _, ok := <-stopReset:
ticker.Stop() ticker.Stop()
if ok { if !ok {
ticker = time.NewTicker(time.Second)
} else {
return return
} }
ticker = time.NewTicker(time.Second)
case <-ticker.C: case <-ticker.C:
func() { if rate.cleanup() {
rate.Lock() ticker.Stop()
defer rate.Unlock() }
}
}
}()
}
func (rate *Ratelimiter) cleanup() (empty bool) {
rate.mu.Lock()
defer rate.mu.Unlock()
for key, entry := range rate.tableIPv4 { for key, entry := range rate.tableIPv4 {
entry.Lock() entry.mu.Lock()
if time.Since(entry.lastTime) > garbageCollectTime { if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
delete(rate.tableIPv4, key) delete(rate.tableIPv4, key)
} }
entry.Unlock() entry.mu.Unlock()
} }
for key, entry := range rate.tableIPv6 { for key, entry := range rate.tableIPv6 {
entry.Lock() entry.mu.Lock()
if time.Since(entry.lastTime) > garbageCollectTime { if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
delete(rate.tableIPv6, key) delete(rate.tableIPv6, key)
} }
entry.Unlock() entry.mu.Unlock()
} }
if len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 { return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
ticker.Stop()
}
}()
}
}
}()
} }
func (rate *Ratelimiter) Allow(ip net.IP) bool { func (rate *Ratelimiter) Allow(ip net.IP) bool {
@ -109,7 +116,7 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
IPv4 := ip.To4() IPv4 := ip.To4()
IPv6 := ip.To16() IPv6 := ip.To16()
rate.RLock() rate.mu.RLock()
if IPv4 != nil { if IPv4 != nil {
copy(keyIPv4[:], IPv4) copy(keyIPv4[:], IPv4)
@ -119,15 +126,15 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
entry = rate.tableIPv6[keyIPv6] entry = rate.tableIPv6[keyIPv6]
} }
rate.RUnlock() rate.mu.RUnlock()
// make new entry if not found // make new entry if not found
if entry == nil { if entry == nil {
entry = new(RatelimiterEntry) entry = new(RatelimiterEntry)
entry.tokens = maxTokens - packetCost entry.tokens = maxTokens - packetCost
entry.lastTime = time.Now() entry.lastTime = rate.timeNow()
rate.Lock() rate.mu.Lock()
if IPv4 != nil { if IPv4 != nil {
rate.tableIPv4[keyIPv4] = entry rate.tableIPv4[keyIPv4] = entry
if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 { if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
@ -139,14 +146,14 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
rate.stopReset <- struct{}{} rate.stopReset <- struct{}{}
} }
} }
rate.Unlock() rate.mu.Unlock()
return true return true
} }
// add tokens to entry // add tokens to entry
entry.Lock() entry.mu.Lock()
now := time.Now() now := rate.timeNow()
entry.tokens += now.Sub(entry.lastTime).Nanoseconds() entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
entry.lastTime = now entry.lastTime = now
if entry.tokens > maxTokens { if entry.tokens > maxTokens {
@ -157,9 +164,9 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
if entry.tokens > packetCost { if entry.tokens > packetCost {
entry.tokens -= packetCost entry.tokens -= packetCost
entry.Unlock() entry.mu.Unlock()
return true return true
} }
entry.Unlock() entry.mu.Unlock()
return false return false
} }

View file

@ -11,22 +11,21 @@ import (
"time" "time"
) )
type RatelimiterResult struct { type result struct {
allowed bool allowed bool
text string text string
wait time.Duration wait time.Duration
} }
func TestRatelimiter(t *testing.T) { func TestRatelimiter(t *testing.T) {
var rate Ratelimiter
var expectedResults []result
var ratelimiter Ratelimiter nano := func(nano int64) time.Duration {
var expectedResults []RatelimiterResult
Nano := func(nano int64) time.Duration {
return time.Nanosecond * time.Duration(nano) return time.Nanosecond * time.Duration(nano)
} }
Add := func(res RatelimiterResult) { add := func(res result) {
expectedResults = append( expectedResults = append(
expectedResults, expectedResults,
res, res,
@ -34,40 +33,40 @@ func TestRatelimiter(t *testing.T) {
} }
for i := 0; i < packetsBurstable; i++ { for i := 0; i < packetsBurstable; i++ {
Add(RatelimiterResult{ add(result{
allowed: true, allowed: true,
text: "initial burst", text: "initial burst",
}) })
} }
Add(RatelimiterResult{ add(result{
allowed: false, allowed: false,
text: "after burst", text: "after burst",
}) })
Add(RatelimiterResult{ add(result{
allowed: true, allowed: true,
wait: Nano(time.Second.Nanoseconds() / packetsPerSecond), wait: nano(time.Second.Nanoseconds() / packetsPerSecond),
text: "filling tokens for single packet", text: "filling tokens for single packet",
}) })
Add(RatelimiterResult{ add(result{
allowed: false, allowed: false,
text: "not having refilled enough", text: "not having refilled enough",
}) })
Add(RatelimiterResult{ add(result{
allowed: true, allowed: true,
wait: 2 * (Nano(time.Second.Nanoseconds() / packetsPerSecond)), wait: 2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)),
text: "filling tokens for two packet burst", text: "filling tokens for two packet burst",
}) })
Add(RatelimiterResult{ add(result{
allowed: true, allowed: true,
text: "second packet in 2 packet burst", text: "second packet in 2 packet burst",
}) })
Add(RatelimiterResult{ add(result{
allowed: false, allowed: false,
text: "packet following 2 packet burst", text: "packet following 2 packet burst",
}) })
@ -89,14 +88,31 @@ func TestRatelimiter(t *testing.T) {
net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
} }
ratelimiter.Init() now := time.Now()
rate.timeNow = func() time.Time {
return now
}
defer func() {
// Lock to avoid data race with cleanup goroutine from Init.
rate.mu.Lock()
defer rate.mu.Unlock()
rate.timeNow = time.Now
}()
timeSleep := func(d time.Duration) {
now = now.Add(d + 1)
rate.cleanup()
}
rate.Init()
defer rate.Close()
for i, res := range expectedResults { for i, res := range expectedResults {
time.Sleep(res.wait) timeSleep(res.wait)
for _, ip := range ips { for _, ip := range ips {
allowed := ratelimiter.Allow(ip) allowed := rate.Allow(ip)
if allowed != res.allowed { if allowed != res.allowed {
t.Fatal("Test failed for", ip.String(), ", on:", i, "(", res.text, ")", "expected:", res.allowed, "got:", allowed) t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed)
} }
} }
} }