Finer-grained start-stop synchronization

This commit is contained in:
Jason A. Donenfeld 2018-05-16 22:20:15 +02:00
parent 23eca94508
commit 846d721dfd
6 changed files with 33 additions and 5 deletions

View file

@ -12,6 +12,10 @@ import (
"net" "net"
) )
const (
ConnRoutineNumber = 2
)
/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic /* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
*/ */
type Bind interface { type Bind interface {
@ -153,6 +157,8 @@ func (device *Device) BindUpdate() error {
// start receiving routines // start receiving routines
device.state.starting.Add(ConnRoutineNumber)
device.state.stopping.Add(ConnRoutineNumber)
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)

View file

@ -15,6 +15,7 @@ import (
const ( const (
DeviceRoutineNumberPerCPU = 3 DeviceRoutineNumberPerCPU = 3
DeviceRoutineNumberAdditional = 2
) )
type Device struct { type Device struct {
@ -25,6 +26,7 @@ type Device struct {
// synchronized resources (locks acquired in order) // synchronized resources (locks acquired in order)
state struct { state struct {
starting sync.WaitGroup
stopping sync.WaitGroup stopping sync.WaitGroup
mutex sync.Mutex mutex sync.Mutex
changing AtomicBool changing AtomicBool
@ -297,7 +299,10 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
// start workers // start workers
cpus := runtime.NumCPU() cpus := runtime.NumCPU()
device.state.stopping.Add(DeviceRoutineNumberPerCPU * cpus) device.state.starting.Wait()
device.state.stopping.Wait()
device.state.stopping.Add(DeviceRoutineNumberPerCPU * cpus + DeviceRoutineNumberAdditional)
device.state.starting.Add(DeviceRoutineNumberPerCPU * cpus + DeviceRoutineNumberAdditional)
for i := 0; i < cpus; i += 1 { for i := 0; i < cpus; i += 1 {
go device.RoutineEncryption() go device.RoutineEncryption()
go device.RoutineDecryption() go device.RoutineDecryption()
@ -307,6 +312,8 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
go device.RoutineReadFromTUN() go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader() go device.RoutineTUNEventReader()
device.state.starting.Wait()
return device return device
} }
@ -363,6 +370,9 @@ func (device *Device) Close() {
if device.isClosed.Swap(true) { if device.isClosed.Swap(true) {
return return
} }
device.state.starting.Wait()
device.log.Info.Println("Device closing") device.log.Info.Println("Device closing")
device.state.changing.Set(true) device.state.changing.Set(true)
device.state.mutex.Lock() device.state.mutex.Lock()

View file

@ -231,20 +231,21 @@ func (peer *Peer) Stop() {
// prevent simultaneous start/stop operations // prevent simultaneous start/stop operations
peer.routines.mutex.Lock()
defer peer.routines.mutex.Unlock()
if !peer.isRunning.Swap(false) { if !peer.isRunning.Swap(false) {
return return
} }
peer.routines.starting.Wait()
peer.routines.mutex.Lock()
defer peer.routines.mutex.Unlock()
peer.device.log.Debug.Println(peer, ": Stopping...") peer.device.log.Debug.Println(peer, ": Stopping...")
peer.timersStop() peer.timersStop()
// stop & wait for ongoing peer routines // stop & wait for ongoing peer routines
peer.routines.starting.Wait()
close(peer.routines.stop) close(peer.routines.stop)
peer.routines.stopping.Wait() peer.routines.stopping.Wait()

View file

@ -124,9 +124,11 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
logDebug := device.log.Debug logDebug := device.log.Debug
defer func() { defer func() {
logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped") logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped")
device.state.stopping.Done()
}() }()
logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - starting") logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - starting")
device.state.starting.Done()
// receive datagrams until conn is closed // receive datagrams until conn is closed
@ -257,6 +259,7 @@ func (device *Device) RoutineDecryption() {
device.state.stopping.Done() device.state.stopping.Done()
}() }()
logDebug.Println("Routine: decryption worker - started") logDebug.Println("Routine: decryption worker - started")
device.state.starting.Done()
for { for {
select { select {
@ -324,6 +327,7 @@ func (device *Device) RoutineHandshake() {
}() }()
logDebug.Println("Routine: handshake worker - started") logDebug.Println("Routine: handshake worker - started")
device.state.starting.Done()
var elem QueueHandshakeElement var elem QueueHandshakeElement
var ok bool var ok bool

View file

@ -247,9 +247,11 @@ func (device *Device) RoutineReadFromTUN() {
defer func() { defer func() {
logDebug.Println("Routine: TUN reader - stopped") logDebug.Println("Routine: TUN reader - stopped")
device.state.stopping.Done()
}() }()
logDebug.Println("Routine: TUN reader - started") logDebug.Println("Routine: TUN reader - started")
device.state.starting.Done()
for { for {
@ -424,6 +426,7 @@ func (device *Device) RoutineEncryption() {
}() }()
logDebug.Println("Routine: encryption worker - started") logDebug.Println("Routine: encryption worker - started")
device.state.starting.Done()
for { for {

4
tun.go
View file

@ -35,6 +35,8 @@ func (device *Device) RoutineTUNEventReader() {
logInfo := device.log.Info logInfo := device.log.Info
logError := device.log.Error logError := device.log.Error
device.state.starting.Done()
for event := range device.tun.device.Events() { for event := range device.tun.device.Events() {
if event&TUNEventMTUUpdate != 0 { if event&TUNEventMTUUpdate != 0 {
mtu, err := device.tun.device.MTU() mtu, err := device.tun.device.MTU()
@ -63,4 +65,6 @@ func (device *Device) RoutineTUNEventReader() {
device.Down() device.Down()
} }
} }
device.state.stopping.Done()
} }