Work on timer teardown + bug fixes
Added waitgroups to peer struct for routine start / stop synchronisation
This commit is contained in:
		
							parent
							
								
									d73f960aab
								
							
						
					
					
						commit
						1dd590b91b
					
				
					 8 changed files with 102 additions and 47 deletions
				
			
		
							
								
								
									
										11
									
								
								src/conn.go
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								src/conn.go
									
									
									
									
									
								
							|  | @ -64,13 +64,9 @@ func unsafeCloseBind(device *Device) error { | |||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func updateBind(device *Device) error { | ||||
| 	device.mutex.Lock() | ||||
| 	defer device.mutex.Unlock() | ||||
| 
 | ||||
| 	netc := &device.net | ||||
| 	netc.mutex.Lock() | ||||
| 	defer netc.mutex.Unlock() | ||||
| /* Must hold device and net lock | ||||
|  */ | ||||
| func unsafeUpdateBind(device *Device) error { | ||||
| 
 | ||||
| 	// close existing sockets
 | ||||
| 
 | ||||
|  | @ -89,6 +85,7 @@ func updateBind(device *Device) error { | |||
| 		// bind to new port
 | ||||
| 
 | ||||
| 		var err error | ||||
| 		netc := &device.net | ||||
| 		netc.bind, netc.port, err = CreateBind(netc.port) | ||||
| 		if err != nil { | ||||
| 			netc.bind = nil | ||||
|  |  | |||
|  | @ -1,6 +1,7 @@ | |||
| package main | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/sasha-s/go-deadlock" | ||||
| 	"runtime" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
|  | @ -21,12 +22,12 @@ type Device struct { | |||
| 		messageBuffers sync.Pool | ||||
| 	} | ||||
| 	net struct { | ||||
| 		mutex  sync.RWMutex | ||||
| 		mutex  deadlock.RWMutex | ||||
| 		bind   Bind   // bind interface
 | ||||
| 		port   uint16 // listening port
 | ||||
| 		fwmark uint32 // mark value (0 = disabled)
 | ||||
| 	} | ||||
| 	mutex        sync.RWMutex | ||||
| 	mutex        deadlock.RWMutex | ||||
| 	privateKey   NoisePrivateKey | ||||
| 	publicKey    NoisePublicKey | ||||
| 	routingTable RoutingTable | ||||
|  | @ -49,8 +50,15 @@ func (device *Device) Up() { | |||
| 	device.mutex.Lock() | ||||
| 	defer device.mutex.Unlock() | ||||
| 
 | ||||
| 	device.isUp.Set(true) | ||||
| 	updateBind(device) | ||||
| 	device.net.mutex.Lock() | ||||
| 	defer device.net.mutex.Unlock() | ||||
| 
 | ||||
| 	if device.isUp.Swap(true) { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	unsafeUpdateBind(device) | ||||
| 
 | ||||
| 	for _, peer := range device.peers { | ||||
| 		peer.Start() | ||||
| 	} | ||||
|  | @ -60,8 +68,12 @@ func (device *Device) Down() { | |||
| 	device.mutex.Lock() | ||||
| 	defer device.mutex.Unlock() | ||||
| 
 | ||||
| 	device.isUp.Set(false) | ||||
| 	if !device.isUp.Swap(false) { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	closeBind(device) | ||||
| 
 | ||||
| 	for _, peer := range device.peers { | ||||
| 		peer.Stop() | ||||
| 	} | ||||
|  | @ -75,7 +87,6 @@ func removePeerUnsafe(device *Device, key NoisePublicKey) { | |||
| 	if !ok { | ||||
| 		return | ||||
| 	} | ||||
| 	peer.mutex.Lock() | ||||
| 	peer.Stop() | ||||
| 	device.routingTable.RemovePeer(peer) | ||||
| 	delete(device.peers, key) | ||||
|  |  | |||
							
								
								
									
										62
									
								
								src/peer.go
									
									
									
									
									
								
							
							
						
						
									
										62
									
								
								src/peer.go
									
									
									
									
									
								
							|  | @ -8,6 +8,10 @@ import ( | |||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	PeerRoutineNumber = 4 | ||||
| ) | ||||
| 
 | ||||
| type Peer struct { | ||||
| 	id                          uint | ||||
| 	mutex                       sync.RWMutex | ||||
|  | @ -34,7 +38,6 @@ type Peer struct { | |||
| 		flushNonceQueue    Signal // size 1, empty queued packets
 | ||||
| 		messageSend        Signal // size 1, message was send to peer
 | ||||
| 		messageReceived    Signal // size 1, authenticated message recv
 | ||||
| 		stop               Signal // size 0, stop all goroutines in peer
 | ||||
| 	} | ||||
| 	timer struct { | ||||
| 		// state related to WireGuard timers
 | ||||
|  | @ -54,6 +57,12 @@ type Peer struct { | |||
| 		outbound chan *QueueOutboundElement // sequential ordering of work
 | ||||
| 		inbound  chan *QueueInboundElement  // sequential ordering of work
 | ||||
| 	} | ||||
| 	routines struct { | ||||
| 		mutex    sync.Mutex     // held when stopping / starting routines
 | ||||
| 		starting sync.WaitGroup // routines pending start
 | ||||
| 		stopping sync.WaitGroup // routines pending stop
 | ||||
| 		stop     Signal         // size 0, stop all goroutines in peer
 | ||||
| 	} | ||||
| 	mac CookieGenerator | ||||
| } | ||||
| 
 | ||||
|  | @ -121,6 +130,10 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { | |||
| 	peer.signal.handshakeCompleted = NewSignal() | ||||
| 	peer.signal.flushNonceQueue = NewSignal() | ||||
| 
 | ||||
| 	peer.routines.mutex.Lock() | ||||
| 	peer.routines.stop = NewSignal() | ||||
| 	peer.routines.mutex.Unlock() | ||||
| 
 | ||||
| 	return peer, nil | ||||
| } | ||||
| 
 | ||||
|  | @ -156,32 +169,43 @@ func (peer *Peer) String() string { | |||
| 	) | ||||
| } | ||||
| 
 | ||||
| /* Starts all routines for a given peer | ||||
|  * | ||||
|  * Requires that the caller holds the exclusive peer lock! | ||||
|  */ | ||||
| func unsafePeerStart(peer *Peer) { | ||||
| 	peer.signal.stop.Broadcast() | ||||
| 	peer.signal.stop = NewSignal() | ||||
| func (peer *Peer) Start() { | ||||
| 
 | ||||
| 	var wait sync.WaitGroup | ||||
| 	peer.routines.mutex.Lock() | ||||
| 	defer peer.routines.mutex.Lock() | ||||
| 
 | ||||
| 	wait.Add(1) | ||||
| 	// stop & wait for ungoing routines (if any)
 | ||||
| 
 | ||||
| 	peer.routines.stop.Broadcast() | ||||
| 	peer.routines.starting.Wait() | ||||
| 	peer.routines.stopping.Wait() | ||||
| 
 | ||||
| 	// reset signal and start (new) routines
 | ||||
| 
 | ||||
| 	peer.routines.stop = NewSignal() | ||||
| 	peer.routines.starting.Add(PeerRoutineNumber) | ||||
| 	peer.routines.stopping.Add(PeerRoutineNumber) | ||||
| 
 | ||||
| 	go peer.RoutineNonce() | ||||
| 	go peer.RoutineTimerHandler(&wait) | ||||
| 	go peer.RoutineTimerHandler() | ||||
| 	go peer.RoutineSequentialSender() | ||||
| 	go peer.RoutineSequentialReceiver() | ||||
| 
 | ||||
| 	wait.Wait() | ||||
| } | ||||
| 
 | ||||
| func (peer *Peer) Start() { | ||||
| 	peer.mutex.Lock() | ||||
| 	unsafePeerStart(peer) | ||||
| 	peer.mutex.Unlock() | ||||
| 	peer.routines.starting.Wait() | ||||
| } | ||||
| 
 | ||||
| func (peer *Peer) Stop() { | ||||
| 	peer.signal.stop.Broadcast() | ||||
| 
 | ||||
| 	peer.routines.mutex.Lock() | ||||
| 	defer peer.routines.mutex.Lock() | ||||
| 
 | ||||
| 	// stop & wait for ungoing routines (if any)
 | ||||
| 
 | ||||
| 	peer.routines.stop.Broadcast() | ||||
| 	peer.routines.starting.Wait() | ||||
| 	peer.routines.stopping.Wait() | ||||
| 
 | ||||
| 	// reset signal (to handle repeated stopping)
 | ||||
| 
 | ||||
| 	peer.routines.stop = NewSignal() | ||||
| } | ||||
|  |  | |||
|  | @ -497,7 +497,7 @@ func (peer *Peer) RoutineSequentialReceiver() { | |||
| 
 | ||||
| 		select { | ||||
| 
 | ||||
| 		case <-peer.signal.stop.Wait(): | ||||
| 		case <-peer.routines.stop.Wait(): | ||||
| 			logDebug.Println("Routine, sequential receiver, stopped for peer", peer.id) | ||||
| 			return | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										11
									
								
								src/send.go
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								src/send.go
									
									
									
									
									
								
							|  | @ -192,7 +192,7 @@ func (peer *Peer) RoutineNonce() { | |||
| 	for { | ||||
| 	NextPacket: | ||||
| 		select { | ||||
| 		case <-peer.signal.stop.Wait(): | ||||
| 		case <-peer.routines.stop.Wait(): | ||||
| 			return | ||||
| 
 | ||||
| 		case elem := <-peer.queue.nonce: | ||||
|  | @ -217,7 +217,7 @@ func (peer *Peer) RoutineNonce() { | |||
| 					logDebug.Println("Clearing queue for", peer.String()) | ||||
| 					peer.FlushNonceQueue() | ||||
| 					goto NextPacket | ||||
| 				case <-peer.signal.stop.Wait(): | ||||
| 				case <-peer.routines.stop.Wait(): | ||||
| 					return | ||||
| 				} | ||||
| 			} | ||||
|  | @ -309,15 +309,20 @@ func (device *Device) RoutineEncryption() { | |||
|  * The routine terminates then the outbound queue is closed. | ||||
|  */ | ||||
| func (peer *Peer) RoutineSequentialSender() { | ||||
| 
 | ||||
| 	defer peer.routines.stopping.Done() | ||||
| 
 | ||||
| 	device := peer.device | ||||
| 
 | ||||
| 	logDebug := device.log.Debug | ||||
| 	logDebug.Println("Routine, sequential sender, started for", peer.String()) | ||||
| 
 | ||||
| 	peer.routines.starting.Done() | ||||
| 
 | ||||
| 	for { | ||||
| 		select { | ||||
| 
 | ||||
| 		case <-peer.signal.stop.Wait(): | ||||
| 		case <-peer.routines.stop.Wait(): | ||||
| 			logDebug.Println( | ||||
| 				"Routine, sequential sender, stopped for", peer.String()) | ||||
| 			return | ||||
|  |  | |||
|  | @ -4,7 +4,6 @@ import ( | |||
| 	"bytes" | ||||
| 	"encoding/binary" | ||||
| 	"math/rand" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
| ) | ||||
|  | @ -182,7 +181,10 @@ func (peer *Peer) sendNewHandshake() error { | |||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) { | ||||
| func (peer *Peer) RoutineTimerHandler() { | ||||
| 
 | ||||
| 	defer peer.routines.stopping.Done() | ||||
| 
 | ||||
| 	device := peer.device | ||||
| 
 | ||||
| 	logInfo := device.log.Info | ||||
|  | @ -203,15 +205,20 @@ func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) { | |||
| 		peer.timer.keepalivePersistent.Reset(duration) | ||||
| 	} | ||||
| 
 | ||||
| 	// signal that timers are reset
 | ||||
| 	// signal synchronised setup complete
 | ||||
| 
 | ||||
| 	ready.Done() | ||||
| 	peer.routines.starting.Done() | ||||
| 
 | ||||
| 	// handle timer events
 | ||||
| 
 | ||||
| 	for { | ||||
| 		select { | ||||
| 
 | ||||
| 		/* stopping */ | ||||
| 
 | ||||
| 		case <-peer.routines.stop.Wait(): | ||||
| 			return | ||||
| 
 | ||||
| 		/* timers */ | ||||
| 
 | ||||
| 		// keep-alive
 | ||||
|  | @ -312,9 +319,6 @@ func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) { | |||
| 
 | ||||
| 		/* signals */ | ||||
| 
 | ||||
| 		case <-peer.signal.stop.Wait(): | ||||
| 			return | ||||
| 
 | ||||
| 		case <-peer.signal.handshakeBegin.Wait(): | ||||
| 
 | ||||
| 			peer.signal.handshakeBegin.Disable() | ||||
|  |  | |||
|  | @ -45,14 +45,14 @@ func (device *Device) RoutineTUNEventReader() { | |||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if event&TUNEventUp != 0 { | ||||
| 		if event&TUNEventUp != 0 && !device.isUp.Get() { | ||||
| 			logInfo.Println("Interface set up") | ||||
| 			device.Up() | ||||
| 		} | ||||
| 
 | ||||
| 		if event&TUNEventDown != 0 { | ||||
| 		if event&TUNEventDown != 0 && device.isUp.Get() { | ||||
| 			logInfo.Println("Interface set down") | ||||
| 			device.Up() | ||||
| 			device.Down() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  |  | |||
							
								
								
									
										16
									
								
								src/uapi.go
									
									
									
									
									
								
							
							
						
						
									
										16
									
								
								src/uapi.go
									
									
									
									
									
								
							|  | @ -133,13 +133,27 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { | |||
| 				device.SetPrivateKey(sk) | ||||
| 
 | ||||
| 			case "listen_port": | ||||
| 
 | ||||
| 				// parse port number
 | ||||
| 
 | ||||
| 				port, err := strconv.ParseUint(value, 10, 16) | ||||
| 				if err != nil { | ||||
| 					logError.Println("Failed to parse listen_port:", err) | ||||
| 					return &IPCError{Code: ipcErrorInvalid} | ||||
| 				} | ||||
| 
 | ||||
| 				// update port and rebind
 | ||||
| 
 | ||||
| 				device.mutex.Lock() | ||||
| 				device.net.mutex.Lock() | ||||
| 
 | ||||
| 				device.net.port = uint16(port) | ||||
| 				if err := updateBind(device); err != nil { | ||||
| 				err = unsafeUpdateBind(device) | ||||
| 
 | ||||
| 				device.net.mutex.Unlock() | ||||
| 				device.mutex.Unlock() | ||||
| 
 | ||||
| 				if err != nil { | ||||
| 					logError.Println("Failed to set listen_port:", err) | ||||
| 					return &IPCError{Code: ipcErrorPortInUse} | ||||
| 				} | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue