diff --git a/device/device.go b/device/device.go index d9367e5..99f5e60 100644 --- a/device/device.go +++ b/device/device.go @@ -163,7 +163,7 @@ func deviceUpdateState(device *Device) { device.peers.RLock() for _, peer := range device.peers.keyMap { peer.Start() - if peer.persistentKeepaliveInterval > 0 { + if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 { peer.SendKeepalive() } } diff --git a/device/device_test.go b/device/device_test.go index 65942ec..e143914 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -215,7 +215,20 @@ func TestConcurrencySafety(t *testing.T) { }() warmup.Wait() - // coming soon: more things here... + // Change persistent_keepalive_interval concurrently with tunnel use. + t.Run("persistentKeepaliveInterval", func(t *testing.T) { + cfg := uapiCfg( + "public_key", "f70dbb6b1b92a1dde1c783b297016af3f572fef13b0abb16a2623d89a58e9725", + "persistent_keepalive_interval", "1", + ) + for i := 0; i < 1000; i++ { + cfg.Seek(0, io.SeekStart) + err := pair[0].dev.IpcSetOperation(cfg) + if err != nil { + t.Fatal(err) + } + } + }) close(done) } diff --git a/device/peer.go b/device/peer.go index c2397cc..31b75c7 100644 --- a/device/peer.go +++ b/device/peer.go @@ -27,7 +27,7 @@ type Peer struct { handshake Handshake device *Device endpoint conn.Endpoint - persistentKeepaliveInterval uint16 + persistentKeepaliveInterval uint32 // accessed atomically disableRoaming bool // These fields are accessed with atomic operations, which must be diff --git a/device/timers.go b/device/timers.go index 48cef94..e94da36 100644 --- a/device/timers.go +++ b/device/timers.go @@ -138,7 +138,7 @@ func expiredZeroKeyMaterial(peer *Peer) { } func expiredPersistentKeepalive(peer *Peer) { - if peer.persistentKeepaliveInterval > 0 { + if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 { peer.SendKeepalive() } } @@ -201,8 +201,9 @@ func (peer *Peer) timersSessionDerived() { /* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */ func (peer *Peer) timersAnyAuthenticatedPacketTraversal() { - if peer.persistentKeepaliveInterval > 0 && peer.timersActive() { - peer.timers.persistentKeepalive.Mod(time.Duration(peer.persistentKeepaliveInterval) * time.Second) + keepalive := atomic.LoadUint32(&peer.persistentKeepaliveInterval) + if keepalive > 0 && peer.timersActive() { + peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second) } } diff --git a/device/uapi.go b/device/uapi.go index c0e522b..3f26607 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -86,7 +86,7 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) error { send(fmt.Sprintf("last_handshake_time_nsec=%d", nano)) send(fmt.Sprintf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes))) send(fmt.Sprintf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))) - send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval)) + send(fmt.Sprintf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))) for _, ip := range device.allowedips.EntriesForPeer(peer) { send("allowed_ip=" + ip.String()) @@ -333,8 +333,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error { return &IPCError{ipc.IpcErrorInvalid} } - old := peer.persistentKeepaliveInterval - peer.persistentKeepaliveInterval = uint16(secs) + old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs)) // send immediate keepalive if we're turning it on and before it wasn't on