diff --git a/peer.go b/peer.go index 1170720..5580cf6 100644 --- a/peer.go +++ b/peer.go @@ -54,8 +54,8 @@ type Peer struct { handshakeDeadline Timer // complete handshake timeout handshakeTimeout Timer // current handshake message timeout - sendLastMinuteHandshake bool - needAnotherKeepalive bool + sendLastMinuteHandshake AtomicBool + needAnotherKeepalive AtomicBool } queue struct { @@ -170,15 +170,8 @@ func (peer *Peer) SendBuffer(buffer []byte) error { /* Returns a short string identifier for logging */ func (peer *Peer) String() string { - if peer.endpoint == nil { - return fmt.Sprintf( - "peer(unknown %s)", - base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), - ) - } return fmt.Sprintf( - "peer(%s %s)", - peer.endpoint.DstToString(), + "peer(%s)", base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), ) } diff --git a/timer.go b/timer.go index 6cac40d..74e3a4e 100644 --- a/timer.go +++ b/timer.go @@ -1,44 +1,52 @@ package main import ( + "sync" "time" ) type Timer struct { - pending AtomicBool + mutex sync.Mutex + pending bool timer *time.Timer } /* Starts the timer if not already pending */ func (t *Timer) Start(dur time.Duration) bool { - if !t.pending.Swap(true) { + t.mutex.Lock() + defer t.mutex.Unlock() + + started := !t.pending + if started { t.timer.Reset(dur) - return true } - return false + return started } -/* Stops the timer - */ func (t *Timer) Stop() { - if t.pending.Swap(true) { - t.timer.Stop() - select { - case <-t.timer.C: - default: - } + t.mutex.Lock() + defer t.mutex.Unlock() + + t.timer.Stop() + select { + case <-t.timer.C: + default: } - t.pending.Set(false) + t.pending = false } func (t *Timer) Pending() bool { - return t.pending.Get() + t.mutex.Lock() + defer t.mutex.Unlock() + + return t.pending } func (t *Timer) Reset(dur time.Duration) { - t.pending.Set(false) - t.Start(dur) + t.mutex.Lock() + defer t.mutex.Unlock() + t.timer.Reset(dur) } func (t *Timer) Wait() <-chan time.Time { @@ -46,8 +54,8 @@ func (t *Timer) Wait() <-chan time.Time { } func NewTimer() (t Timer) { - t.pending.Set(false) - t.timer = time.NewTimer(0) + t.pending = false + t.timer = time.NewTimer(time.Hour) t.timer.Stop() select { case <-t.timer.C: diff --git a/timers.go b/timers.go index 1240c21..76dffb9 100644 --- a/timers.go +++ b/timers.go @@ -36,7 +36,7 @@ func (peer *Peer) KeepKeyFreshSending() { * NOTE: Not thread safe, but called by sequential receiver! */ func (peer *Peer) KeepKeyFreshReceiving() { - if peer.timer.sendLastMinuteHandshake { + if peer.timer.sendLastMinuteHandshake.Get() { return } kp := peer.keyPairs.Current() @@ -50,7 +50,7 @@ func (peer *Peer) KeepKeyFreshReceiving() { send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving if send { // do a last minute attempt at initiating a new handshake - peer.timer.sendLastMinuteHandshake = true + peer.timer.sendLastMinuteHandshake.Set(true) peer.signal.handshakeBegin.Send() } } @@ -87,7 +87,7 @@ func (peer *Peer) TimerDataSent() { */ func (peer *Peer) TimerDataReceived() { if !peer.timer.keepalivePassive.Start(KeepaliveTimeout) { - peer.timer.needAnotherKeepalive = true + peer.timer.needAnotherKeepalive.Set(true) } } @@ -238,8 +238,7 @@ func (peer *Peer) RoutineTimerHandler() { peer.SendKeepAlive() - if peer.timer.needAnotherKeepalive { - peer.timer.needAnotherKeepalive = false + if peer.timer.needAnotherKeepalive.Swap(false) { peer.timer.keepalivePassive.Reset(KeepaliveTimeout) } @@ -342,7 +341,7 @@ func (peer *Peer) RoutineTimerHandler() { peer.timer.handshakeDeadline.Stop() peer.signal.handshakeBegin.Enable() - peer.timer.sendLastMinuteHandshake = false + peer.timer.sendLastMinuteHandshake.Set(false) } } } diff --git a/tun.go b/tun.go index 6259f33..7b044ad 100644 --- a/tun.go +++ b/tun.go @@ -26,6 +26,7 @@ type TUNDevice interface { } func (device *Device) RoutineTUNEventReader() { + setUp := false logInfo := device.log.Info logError := device.log.Error @@ -45,13 +46,15 @@ func (device *Device) RoutineTUNEventReader() { } } - if event&TUNEventUp != 0 && !device.isUp.Get() { + if event&TUNEventUp != 0 && !setUp { logInfo.Println("Interface set up") + setUp = true device.Up() } - if event&TUNEventDown != 0 && device.isUp.Get() { + if event&TUNEventDown != 0 && setUp { logInfo.Println("Interface set down") + setUp = false device.Down() } }