diff --git a/device/device.go b/device/device.go index 9ea7c24..c637e38 100644 --- a/device/device.go +++ b/device/device.go @@ -21,17 +21,26 @@ import ( ) type Device struct { - isUp AtomicBool // device is (going) up - isClosed AtomicBool // device is closed? (acting as guard) - log *Logger + log *Logger // synchronized resources (locks acquired in order) state struct { + // state holds the device's state. It is accessed atomically. + // Use the device.deviceState method to read it. + // If state.mu is (r)locked, state is the current state of the device. + // Without state.mu (r)locked, state is either the current state + // of the device or the intended future state of the device. + // For example, while executing a call to Up, state will be deviceStateUp. + // There is no guarantee that that intended future state of the device + // will become the actual state; Up can fail. + // The device can also change state multiple times between time of check and time of use. + // Unsynchronized uses of state must therefore be advisory/best-effort only. + state uint32 // actually a deviceState, but typed uint32 for conveniene + // stopping blocks until all inputs to Device have been closed. stopping sync.WaitGroup - sync.Mutex - changing AtomicBool - current bool + // mu protects state changes. + mu sync.Mutex } net struct { @@ -87,6 +96,43 @@ type Device struct { closed chan struct{} } +// deviceState represents the state of a Device. +// There are four states: new, down, up, closed. +// However, state new should never be observable. +// Transitions: +// +// new -> down -----+ +// ↑↓ ↓ +// up -> closed +// +type deviceState uint32 + +//go:generate stringer -type deviceState -trimprefix=deviceState +const ( + deviceStateNew deviceState = iota + deviceStateDown + deviceStateUp + deviceStateClosed +) + +// deviceState returns device.state.state as a deviceState +// See those docs for how to interpret this value. +func (device *Device) deviceState() deviceState { + return deviceState(atomic.LoadUint32(&device.state.state)) +} + +// isClosed reports whether the device is closed (or is closing). +// See device.state.state comments for how to interpret this value. +func (device *Device) isClosed() bool { + return device.deviceState() == deviceStateClosed +} + +// isUp reports whether the device is up (or is attempting to come up). +// See device.state.state comments for how to interpret this value. +func (device *Device) isUp() bool { + return device.deviceState() == deviceStateUp +} + // An outboundQueue is a channel of QueueOutboundElements awaiting encryption. // An outboundQueue is ref-counted using its wg field. // An outboundQueue created with newOutboundQueue has one reference. @@ -154,91 +200,82 @@ func newHandshakeQueue() *handshakeQueue { * Must hold device.peers.Mutex */ func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { - // stop routing and processing of packets - device.allowedips.RemoveByPeer(peer) peer.Stop() // remove from peer map - delete(device.peers.keyMap, key) device.peers.empty.Set(len(device.peers.keyMap) == 0) } -func deviceUpdateState(device *Device) { - - // check if state already being updated (guard) - - if device.state.changing.Swap(true) { +// changeState attempts to change the device state to match want. +func (device *Device) changeState(want deviceState) { + device.state.mu.Lock() + defer device.state.mu.Unlock() + old := device.deviceState() + if old == deviceStateClosed { + // once closed, always closed + device.log.Verbosef("Interface closed, ignored requested state %s", want) return } - - // compare to current state of device - - device.state.Lock() - - newIsUp := device.isUp.Get() - - if newIsUp == device.state.current { - device.state.changing.Set(false) - device.state.Unlock() + switch want { + case old: + device.log.Verbosef("Interface already in state %s", want) return - } - - // change state of device - - switch newIsUp { - case true: - if err := device.BindUpdate(); err != nil { - device.log.Errorf("Unable to update bind: %v", err) - device.isUp.Set(false) + case deviceStateUp: + atomic.StoreUint32(&device.state.state, uint32(deviceStateUp)) + if ok := device.upLocked(); ok { break } - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Start() - if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 { - peer.SendKeepalive() - } - } - device.peers.RUnlock() + fallthrough // up failed; bring the device all the way back down + case deviceStateDown: + atomic.StoreUint32(&device.state.state, uint32(deviceStateDown)) + device.downLocked() + } + device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState()) +} - case false: - device.BindClose() - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Stop() - } - device.peers.RUnlock() +// upLocked attempts to bring the device up and reports whether it succeeded. +// The caller must hold device.state.mu and is responsible for updating device.state.state. +func (device *Device) upLocked() bool { + if err := device.BindUpdate(); err != nil { + device.log.Errorf("Unable to update bind: %v", err) + return false } - // update state variables + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Start() + if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 { + peer.SendKeepalive() + } + } + device.peers.RUnlock() + return true +} - device.state.current = newIsUp - device.state.changing.Set(false) - device.state.Unlock() +// downLocked attempts to bring the device down. +// The caller must hold device.state.mu and is responsible for updating device.state.state. +func (device *Device) downLocked() { + err := device.BindClose() + if err != nil { + device.log.Errorf("Bind close failed: %v", err) + } - // check for state change in the mean time - - deviceUpdateState(device) + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Stop() + } + device.peers.RUnlock() } func (device *Device) Up() { - - // closed device cannot be brought up - - if device.isClosed.Get() { - return - } - - device.isUp.Set(true) - deviceUpdateState(device) + device.changeState(deviceStateUp) } func (device *Device) Down() { - device.isUp.Set(false) - deviceUpdateState(device) + device.changeState(deviceStateDown) } func (device *Device) IsUnderLoad() bool { @@ -310,6 +347,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { func NewDevice(tunDevice tun.Device, logger *Logger) *Device { device := new(Device) + device.state.state = uint32(deviceStateDown) device.closed = make(chan struct{}) device.log = logger device.tun.device = tunDevice @@ -382,19 +420,16 @@ func (device *Device) RemoveAllPeers() { } func (device *Device) Close() { - if device.isClosed.Swap(true) { + device.state.mu.Lock() + defer device.state.mu.Unlock() + if device.isClosed() { return } - + atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed)) device.log.Verbosef("Device closing") - device.state.changing.Set(true) - device.state.Lock() - defer device.state.Unlock() device.tun.device.Close() - device.BindClose() - - device.isUp.Set(false) + device.downLocked() // Remove peers before closing queues, // because peers assume that queues are active. @@ -410,8 +445,7 @@ func (device *Device) Close() { device.rate.limiter.Close() - device.state.changing.Set(false) - device.log.Verbosef("Interface closed") + device.log.Verbosef("Device closed") close(device.closed) } @@ -420,7 +454,7 @@ func (device *Device) Wait() chan struct{} { } func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { - if device.isClosed.Get() { + if !device.isUp() { return } @@ -457,27 +491,23 @@ func (device *Device) Bind() conn.Bind { } func (device *Device) BindSetMark(mark uint32) error { - device.net.Lock() defer device.net.Unlock() // check if modified - if device.net.fwmark == mark { return nil } // update fwmark on existing bind - device.net.fwmark = mark - if device.isUp.Get() && device.net.bind != nil { + if device.isUp() && device.net.bind != nil { if err := device.net.bind.SetMark(mark); err != nil { return err } } // clear cached source addresses - device.peers.RLock() for _, peer := range device.peers.keyMap { peer.Lock() @@ -492,70 +522,63 @@ func (device *Device) BindSetMark(mark uint32) error { } func (device *Device) BindUpdate() error { - device.net.Lock() defer device.net.Unlock() // close existing sockets - if err := unsafeCloseBind(device); err != nil { return err } // open new sockets - - if device.isUp.Get() { - - // bind to new port - - var err error - netc := &device.net - netc.bind, netc.port, err = conn.CreateBind(netc.port) - if err != nil { - netc.bind = nil - netc.port = 0 - return err - } - netc.netlinkCancel, err = device.startRouteListener(netc.bind) - if err != nil { - netc.bind.Close() - netc.bind = nil - netc.port = 0 - return err - } - - // set fwmark - - if netc.fwmark != 0 { - err = netc.bind.SetMark(netc.fwmark) - if err != nil { - return err - } - } - - // clear cached source addresses - - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - } - device.peers.RUnlock() - - // start receiving routines - - device.net.stopping.Add(2) - device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption - device.queue.handshake.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake - go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) - go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) - - device.log.Verbosef("UDP bind has been updated") + if !device.isUp() { + return nil } + // bind to new port + var err error + netc := &device.net + netc.bind, netc.port, err = conn.CreateBind(netc.port) + if err != nil { + netc.bind = nil + netc.port = 0 + return err + } + netc.netlinkCancel, err = device.startRouteListener(netc.bind) + if err != nil { + netc.bind.Close() + netc.bind = nil + netc.port = 0 + return err + } + + // set fwmark + if netc.fwmark != 0 { + err = netc.bind.SetMark(netc.fwmark) + if err != nil { + return err + } + } + + // clear cached source addresses + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Lock() + defer peer.Unlock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + } + device.peers.RUnlock() + + // start receiving routines + device.net.stopping.Add(2) + device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption + device.queue.handshake.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake + go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) + go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) + + device.log.Verbosef("UDP bind has been updated") return nil } diff --git a/device/device_test.go b/device/device_test.go index 50e3dbc..56ecd17 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -172,7 +172,7 @@ NextAttempt: // The device might still not be up, e.g. due to an error // in RoutineTUNEventReader's call to dev.Up that got swallowed. // Assume it's due to a transient error (port in use), and retry. - if !p.dev.isUp.Get() { + if !p.dev.isUp() { tb.Logf("device %d did not come up, trying again", i) p.dev.Close() continue NextAttempt diff --git a/device/devicestate_string.go b/device/devicestate_string.go new file mode 100644 index 0000000..e8f16b0 --- /dev/null +++ b/device/devicestate_string.go @@ -0,0 +1,26 @@ +// Code generated by "stringer -type deviceState -trimprefix=deviceState"; DO NOT EDIT. + +package device + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[deviceStateNew-0] + _ = x[deviceStateDown-1] + _ = x[deviceStateUp-2] + _ = x[deviceStateClosed-3] +} + +const _deviceState_name = "NewDownUpClosed" + +var _deviceState_index = [...]uint8{0, 3, 7, 9, 15} + +func (i deviceState) String() string { + if i >= deviceState(len(_deviceState_index)-1) { + return "deviceState(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _deviceState_name[_deviceState_index[i]:_deviceState_index[i+1]] +} diff --git a/device/peer.go b/device/peer.go index 0bf19fd..abe8a08 100644 --- a/device/peer.go +++ b/device/peer.go @@ -62,7 +62,7 @@ type Peer struct { } func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { - if device.isClosed.Get() { + if device.isClosed() { return nil, errors.New("device closed") } @@ -107,7 +107,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { device.peers.empty.Set(false) // start peer - if peer.device.isUp.Get() { + if peer.device.isUp() { peer.Start() } @@ -121,7 +121,7 @@ func (peer *Peer) SendBuffer(buffer []byte) error { if peer.device.net.bind == nil { // Packets can leak through to SendBuffer while the device is closing. // When that happens, drop them silently to avoid spurious errors. - if peer.device.isClosed.Get() { + if peer.device.isClosed() { return nil } return errors.New("no bind") @@ -152,7 +152,7 @@ func (peer *Peer) String() string { func (peer *Peer) Start() { // should never start a peer on a closed device - if peer.device.isClosed.Get() { + if peer.device.isClosed() { return } diff --git a/device/receive.go b/device/receive.go index 21d9dbc..c6a28f7 100644 --- a/device/receive.go +++ b/device/receive.go @@ -474,7 +474,7 @@ func (peer *Peer) RoutineSequentialReceiver() { } _, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent) - if err != nil && !device.isClosed.Get() { + if err != nil && !device.isClosed() { device.log.Errorf("Failed to write packet to TUN device: %v", err) } if len(peer.queue.inbound) == 0 { diff --git a/device/send.go b/device/send.go index b9bcb33..982fec0 100644 --- a/device/send.go +++ b/device/send.go @@ -225,7 +225,7 @@ func (device *Device) RoutineReadFromTUN() { size, err := device.tun.device.Read(elem.buffer[:], offset) if err != nil { - if !device.isClosed.Get() { + if !device.isClosed() { device.log.Errorf("Failed to read packet from TUN device: %v", err) device.Close() } @@ -291,7 +291,7 @@ func (peer *Peer) StagePacket(elem *QueueOutboundElement) { func (peer *Peer) SendStagedPackets() { top: - if len(peer.queue.staged) == 0 || !peer.device.isUp.Get() { + if len(peer.queue.staged) == 0 || !peer.device.isUp() { return } diff --git a/device/timers.go b/device/timers.go index 1ea91c7..f740cf0 100644 --- a/device/timers.go +++ b/device/timers.go @@ -73,7 +73,7 @@ func (timer *Timer) IsPending() bool { } func (peer *Peer) timersActive() bool { - return peer.isRunning.Get() && peer.device != nil && peer.device.isUp.Get() && !peer.device.peers.empty.Get() + return peer.isRunning.Get() && peer.device != nil && peer.device.isUp() && !peer.device.peers.empty.Get() } func expiredRetransmitHandshake(peer *Peer) { diff --git a/device/uapi.go b/device/uapi.go index 3af37e7..406880f 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -258,7 +258,7 @@ type ipcSetPeer struct { } func (peer *ipcSetPeer) handlePostConfig() { - if peer.Peer != nil && !peer.dummy && peer.Peer.device.isUp.Get() { + if peer.Peer != nil && !peer.dummy && peer.Peer.device.isUp() { peer.SendStagedPackets() } } @@ -354,7 +354,7 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error if err != nil { return ipcErrorf(ipc.IpcErrorIO, "failed to get tun device status: %w", err) } - if device.isUp.Get() && !peer.dummy { + if device.isUp() && !peer.dummy { peer.SendKeepalive() } }