// wireguard-go/ratelimiter/ratelimiter.go (166 lines, 3.1 KiB, Go)

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
 */
package ratelimiter
import (
"net"
"sync"
"time"
)
// Token-bucket parameters. Buckets are refilled with one token per elapsed
// nanosecond, so every quantity derived below is expressed in nanoseconds.
const (
	// Sustained number of packets admitted per second for one address.
	packetsPerSecond = 20
	// Number of packets a completely full bucket can pay for.
	packetsBurstable = 5

	// packetCost is the number of tokens charged for a single packet.
	packetCost = 1000000000 / packetsPerSecond
	// maxTokens is the ceiling a bucket is clamped to when refilled.
	maxTokens = packetsBurstable * packetCost

	// garbageCollectTime is how long an entry may stay untouched before the
	// garbage collection routine deletes it.
	garbageCollectTime = time.Second
)
// RatelimiterEntry is the token bucket kept for a single source address.
// The embedded mutex guards lastTime and tokens.
type RatelimiterEntry struct {
	sync.Mutex
	lastTime time.Time // when the entry was created or last refilled
	tokens   int64     // remaining budget; refilled by elapsed nanoseconds
}

// Ratelimiter is a per-source-address token-bucket rate limiter.
// The embedded RWMutex guards the two tables and the stopReset field.
type Ratelimiter struct {
	sync.RWMutex

	// stopReset controls the garbage collection goroutine: a value sent on
	// it (re)starts the goroutine's ticker, closing it ends the goroutine.
	stopReset chan struct{}

	// Buckets keyed by the raw bytes of the source address.
	tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
	tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
}
func (rate *Ratelimiter) Close() {
rate.Lock()
defer rate.Unlock()
2018-02-11 21:53:39 +00:00
if rate.stopReset != nil {
close(rate.stopReset)
2018-02-11 21:53:39 +00:00
}
}
func (rate *Ratelimiter) Init() {
rate.Lock()
defer rate.Unlock()
2018-02-11 21:53:39 +00:00
// stop any ongoing garbage collection routine
if rate.stopReset != nil {
close(rate.stopReset)
2018-02-11 21:53:39 +00:00
}
rate.stopReset = make(chan struct{})
rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
2018-02-11 21:53:39 +00:00
// start garbage collection routine
2018-02-11 21:53:39 +00:00
go func() {
2018-05-13 16:42:06 +00:00
ticker := time.NewTicker(time.Second)
ticker.Stop()
2018-02-11 21:53:39 +00:00
for {
select {
case _, ok := <-rate.stopReset:
2018-05-13 16:42:06 +00:00
ticker.Stop()
if ok {
ticker = time.NewTicker(time.Second)
} else {
return
}
2018-05-13 16:42:06 +00:00
case <-ticker.C:
func() {
rate.Lock()
defer rate.Unlock()
for key, entry := range rate.tableIPv4 {
entry.Lock()
if time.Since(entry.lastTime) > garbageCollectTime {
delete(rate.tableIPv4, key)
}
entry.Unlock()
}
for key, entry := range rate.tableIPv6 {
entry.Lock()
if time.Since(entry.lastTime) > garbageCollectTime {
delete(rate.tableIPv6, key)
}
entry.Unlock()
}
if len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 {
ticker.Stop()
}
}()
2018-02-11 21:53:39 +00:00
}
}
}()
}
func (rate *Ratelimiter) Allow(ip net.IP) bool {
var entry *RatelimiterEntry
2018-05-13 16:42:06 +00:00
var keyIPv4 [net.IPv4len]byte
var keyIPv6 [net.IPv6len]byte
// lookup entry
IPv4 := ip.To4()
IPv6 := ip.To16()
rate.RLock()
if IPv4 != nil {
2018-05-13 16:42:06 +00:00
copy(keyIPv4[:], IPv4)
entry = rate.tableIPv4[keyIPv4]
} else {
2018-05-13 16:42:06 +00:00
copy(keyIPv6[:], IPv6)
entry = rate.tableIPv6[keyIPv6]
}
rate.RUnlock()
// make new entry if not found
if entry == nil {
entry = new(RatelimiterEntry)
entry.tokens = maxTokens - packetCost
entry.lastTime = time.Now()
rate.Lock()
if IPv4 != nil {
2018-05-13 16:42:06 +00:00
rate.tableIPv4[keyIPv4] = entry
if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
rate.stopReset <- struct{}{}
}
} else {
2018-05-13 16:42:06 +00:00
rate.tableIPv6[keyIPv6] = entry
if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
rate.stopReset <- struct{}{}
}
}
rate.Unlock()
return true
}
// add tokens to entry
entry.Lock()
now := time.Now()
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
entry.lastTime = now
if entry.tokens > maxTokens {
entry.tokens = maxTokens
}
// subtract cost of packet
if entry.tokens > packetCost {
entry.tokens -= packetCost
entry.Unlock()
return true
}
entry.Unlock()
return false
}