device: timers: use pre-seeded per-thread unlocked fastrandn for jitter

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2021-10-28 13:47:50 +02:00
parent 60683d7361
commit eb6302c7eb
1 changed files with 5 additions and 10 deletions

View File

@ -8,19 +8,14 @@
package device package device
import ( import (
"crypto/rand"
unsafeRand "math/rand"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"unsafe" _ "unsafe"
) )
func init() { //go:linkname fastrandn runtime.fastrandn
var seed int64 func fastrandn(n uint32) uint32
rand.Read(unsafe.Slice((*byte)(unsafe.Pointer(&seed)), unsafe.Sizeof(seed)))
unsafeRand.Seed(seed)
}
// A Timer manages time-based aspects of the WireGuard protocol. // A Timer manages time-based aspects of the WireGuard protocol.
// Timer roughly copies the interface of the Linux kernel's struct timer_list. // Timer roughly copies the interface of the Linux kernel's struct timer_list.
@ -152,7 +147,7 @@ func expiredPersistentKeepalive(peer *Peer) {
/* Should be called after an authenticated data packet is sent. */ /* Should be called after an authenticated data packet is sent. */
func (peer *Peer) timersDataSent() { func (peer *Peer) timersDataSent() {
if peer.timersActive() && !peer.timers.newHandshake.IsPending() { if peer.timersActive() && !peer.timers.newHandshake.IsPending() {
peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(unsafeRand.Int63n(RekeyTimeoutJitterMaxMs))) peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
} }
} }
@ -184,7 +179,7 @@ func (peer *Peer) timersAnyAuthenticatedPacketReceived() {
/* Should be called after a handshake initiation message is sent. */ /* Should be called after a handshake initiation message is sent. */
func (peer *Peer) timersHandshakeInitiated() { func (peer *Peer) timersHandshakeInitiated() {
if peer.timersActive() { if peer.timersActive() {
peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(unsafeRand.Int63n(RekeyTimeoutJitterMaxMs))) peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
} }
} }