diff --git a/device.go b/device.go index e288ebe..9f93f21 100644 --- a/device.go +++ b/device.go @@ -1,6 +1,7 @@ package main import ( + "git.zx2c4.com/wireguard-go/internal/ratelimiter" "runtime" "sync" "sync/atomic" @@ -50,7 +51,7 @@ type Device struct { rate struct { underLoadUntil atomic.Value - limiter Ratelimiter + limiter ratelimiter.Ratelimiter } pool struct { @@ -300,7 +301,6 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device { go device.RoutineReadFromTUN() go device.RoutineTUNEventReader() - go device.rate.limiter.RoutineGarbageCollector(device.signal.stop) return device } @@ -355,6 +355,7 @@ func (device *Device) Close() { device.BindClose() device.isUp.Set(false) device.RemoveAllPeers() + device.rate.limiter.Close() device.log.Info.Println("Interface closed") } diff --git a/ratelimiter.go b/internal/ratelimiter/ratelimiter.go similarity index 79% rename from ratelimiter.go rename to internal/ratelimiter/ratelimiter.go index 6e5f005..f9fc673 100644 --- a/ratelimiter.go +++ b/internal/ratelimiter/ratelimiter.go @@ -1,4 +1,4 @@ -package main +package ratelimiter /* Copyright (C) 2015-2017 Jason A. Donenfeld . All Rights Reserved. */ @@ -26,21 +26,48 @@ type RatelimiterEntry struct { } type Ratelimiter struct { - mutex sync.RWMutex - lastGarbageCollect time.Time - tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry - tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry + mutex sync.RWMutex + stop chan struct{} + tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry + tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry +} + +func (rate *Ratelimiter) Close() { + rate.mutex.Lock() + defer rate.mutex.Unlock() + + if rate.stop != nil { + close(rate.stop) + } } func (rate *Ratelimiter) Init() { rate.mutex.Lock() defer rate.mutex.Unlock() + + if rate.stop != nil { + close(rate.stop) + } + + rate.stop = make(chan struct{}) rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry) - rate.lastGarbageCollect = time.Now() + + go func() { + timer := time.NewTimer(time.Second) + for { + select { + case <-rate.stop: + return + case <-timer.C: + rate.garbageCollectEntries() + timer.Reset(time.Second) + } + } + }() } -func (rate *Ratelimiter) GarbageCollectEntries() { +func (rate *Ratelimiter) garbageCollectEntries() { rate.mutex.Lock() // remove unused IPv4 entries @@ -66,19 +93,6 @@ func (rate *Ratelimiter) GarbageCollectEntries() { rate.mutex.Unlock() } -func (rate *Ratelimiter) RoutineGarbageCollector(stop Signal) { - timer := time.NewTimer(time.Second) - for { - select { - case <-stop.Wait(): - return - case <-timer.C: - rate.GarbageCollectEntries() - timer.Reset(time.Second) - } - } -} - func (rate *Ratelimiter) Allow(ip net.IP) bool { var entry *RatelimiterEntry var KeyIPv4 [net.IPv4len]byte diff --git a/ratelimiter_test.go b/internal/ratelimiter/ratelimiter_test.go similarity index 99% rename from ratelimiter_test.go rename to internal/ratelimiter/ratelimiter_test.go index 13b6a23..a6f618b 100644 --- a/ratelimiter_test.go +++ b/internal/ratelimiter/ratelimiter_test.go @@ -1,4 +1,4 @@ -package main +package ratelimiter import ( "net" diff --git a/internal/tai64n/tai64.go b/internal/tai64n/tai64n.go similarity index 100% rename from internal/tai64n/tai64.go rename to internal/tai64n/tai64n.go diff --git a/internal/tai64n/tai64n_test.go b/internal/tai64n/tai64n_test.go new file mode 100644 index 0000000..389b65c --- /dev/null +++ b/internal/tai64n/tai64n_test.go @@ -0,0 +1,21 @@ +package tai64n + +import ( + "testing" + "time" +) + +/* Testing the essential property of the timestamp + * as used by WireGuard. + */ +func TestMonotonic(t *testing.T) { + old := Now() + for i := 0; i < 10000; i++ { + time.Sleep(time.Nanosecond) + next := Now() + if !next.After(old) { + t.Error("TAI64N, not monotonically increasing on nano-second scale") + } + old = next + } +} diff --git a/timers.go b/timers.go index 70e907c..1240c21 100644 --- a/timers.go +++ b/timers.go @@ -120,7 +120,7 @@ func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() { */ func (peer *Peer) TimerHandshakeComplete() { peer.signal.handshakeCompleted.Send() - peer.device.log.Info.Println("Negotiated new handshake for", peer.String()) + peer.device.log.Info.Println(peer.String(), ": New handshake completed") } /* Event: