Fixed deadlock in index.go
This commit is contained in:
		
							parent
							
								
									dd4da93749
								
							
						
					
					
						commit
						c5d7efc246
					
				
					 8 changed files with 194 additions and 152 deletions
				
			
		
							
								
								
									
										162
									
								
								src/config.go
									
									
									
									
									
								
							
							
						
						
									
										162
									
								
								src/config.go
									
									
									
									
									
								
							|  | @ -8,39 +8,36 @@ import ( | |||
| 	"net" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync/atomic" | ||||
| 	"syscall" | ||||
| ) | ||||
| 
 | ||||
| // #include <errno.h>
 | ||||
| import "C" | ||||
| 
 | ||||
| /* TODO: More fine grained? | ||||
|  */ | ||||
| const ( | ||||
| 	ipcErrorNoPeer       = C.EPROTO | ||||
| 	ipcErrorNoKeyValue   = C.EPROTO | ||||
| 	ipcErrorInvalidKey   = C.EPROTO | ||||
| 	ipcErrorInvalidValue = C.EPROTO | ||||
| 	ipcErrorIO           = syscall.EIO | ||||
| 	ipcErrorNoPeer       = syscall.EPROTO | ||||
| 	ipcErrorNoKeyValue   = syscall.EPROTO | ||||
| 	ipcErrorInvalidKey   = syscall.EPROTO | ||||
| 	ipcErrorInvalidValue = syscall.EPROTO | ||||
| ) | ||||
| 
 | ||||
| type IPCError struct { | ||||
| 	Code int | ||||
| 	Code syscall.Errno | ||||
| } | ||||
| 
 | ||||
| func (s *IPCError) Error() string { | ||||
| 	return fmt.Sprintf("IPC error: %d", s.Code) | ||||
| } | ||||
| 
 | ||||
| func (s *IPCError) ErrorCode() int { | ||||
| 	return s.Code | ||||
| func (s *IPCError) ErrorCode() uintptr { | ||||
| 	return uintptr(s.Code) | ||||
| } | ||||
| 
 | ||||
| func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error { | ||||
| 
 | ||||
| 	device.mutex.RLock() | ||||
| 	defer device.mutex.RUnlock() | ||||
| func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { | ||||
| 
 | ||||
| 	// create lines
 | ||||
| 
 | ||||
| 	device.mutex.RLock() | ||||
| 
 | ||||
| 	lines := make([]string, 0, 100) | ||||
| 	send := func(line string) { | ||||
| 		lines = append(lines, line) | ||||
|  | @ -63,19 +60,25 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error { | |||
| 			} | ||||
| 			send(fmt.Sprintf("tx_bytes=%d", peer.txBytes)) | ||||
| 			send(fmt.Sprintf("rx_bytes=%d", peer.rxBytes)) | ||||
| 			send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval)) | ||||
| 			send(fmt.Sprintf("persistent_keepalive_interval=%d", | ||||
| 				atomic.LoadUint64(&peer.persistentKeepaliveInterval), | ||||
| 			)) | ||||
| 			for _, ip := range device.routingTable.AllowedIPs(peer) { | ||||
| 				send("allowed_ip=" + ip.String()) | ||||
| 			} | ||||
| 		}() | ||||
| 	} | ||||
| 
 | ||||
| 	device.mutex.RUnlock() | ||||
| 
 | ||||
| 	// send lines
 | ||||
| 
 | ||||
| 	for _, line := range lines { | ||||
| 		_, err := socket.WriteString(line + "\n") | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 			return &IPCError{ | ||||
| 				Code: ipcErrorIO, | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
|  | @ -83,13 +86,14 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error { | |||
| } | ||||
| 
 | ||||
| func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { | ||||
| 	logger := device.log.Debug | ||||
| 	scanner := bufio.NewScanner(socket) | ||||
| 	logError := device.log.Error | ||||
| 	logDebug := device.log.Debug | ||||
| 
 | ||||
| 	var peer *Peer | ||||
| 	for scanner.Scan() { | ||||
| 
 | ||||
| 		// Parse line
 | ||||
| 		// parse line
 | ||||
| 
 | ||||
| 		line := scanner.Text() | ||||
| 		if line == "" { | ||||
|  | @ -97,7 +101,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { | |||
| 		} | ||||
| 		parts := strings.Split(line, "=") | ||||
| 		if len(parts) != 2 { | ||||
| 			device.log.Debug.Println(parts) | ||||
| 			return &IPCError{Code: ipcErrorNoKeyValue} | ||||
| 		} | ||||
| 		key := parts[0] | ||||
|  | @ -105,7 +108,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { | |||
| 
 | ||||
| 		switch key { | ||||
| 
 | ||||
| 		/* Interface configuration */ | ||||
| 		/* interface configuration */ | ||||
| 
 | ||||
| 		case "private_key": | ||||
| 			if value == "" { | ||||
|  | @ -116,7 +119,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { | |||
| 				var sk NoisePrivateKey | ||||
| 				err := sk.FromHex(value) | ||||
| 				if err != nil { | ||||
| 					logger.Println("Failed to set private_key:", err) | ||||
| 					logError.Println("Failed to set private_key:", err) | ||||
| 					return &IPCError{Code: ipcErrorInvalidValue} | ||||
| 				} | ||||
| 				device.SetPrivateKey(sk) | ||||
|  | @ -126,22 +129,26 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { | |||
| 			var port int | ||||
| 			_, err := fmt.Sscanf(value, "%d", &port) | ||||
| 			if err != nil || port > (1<<16) || port < 0 { | ||||
| 				logger.Println("Failed to set listen_port:", err) | ||||
| 				logError.Println("Failed to set listen_port:", err) | ||||
| 				return &IPCError{Code: ipcErrorInvalidValue} | ||||
| 			} | ||||
| 			device.net.mutex.Lock() | ||||
| 			device.net.addr.Port = port | ||||
| 			device.net.conn, err = net.ListenUDP("udp", device.net.addr) | ||||
| 			device.net.mutex.Unlock() | ||||
| 			if err != nil { | ||||
| 				logError.Println("Failed to create UDP listener:", err) | ||||
| 				return &IPCError{Code: ipcErrorInvalidValue} | ||||
| 			} | ||||
| 
 | ||||
| 		case "fwmark": | ||||
| 			logger.Println("FWMark not handled yet") | ||||
| 			logError.Println("FWMark not handled yet") | ||||
| 
 | ||||
| 		case "public_key": | ||||
| 			var pubKey NoisePublicKey | ||||
| 			err := pubKey.FromHex(value) | ||||
| 			if err != nil { | ||||
| 				logger.Println("Failed to get peer by public_key:", err) | ||||
| 				logError.Println("Failed to get peer by public_key:", err) | ||||
| 				return &IPCError{Code: ipcErrorInvalidValue} | ||||
| 			} | ||||
| 			device.mutex.RLock() | ||||
|  | @ -153,22 +160,23 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { | |||
| 				peer = device.NewPeer(pubKey) | ||||
| 			} | ||||
| 			if peer == nil { | ||||
| 				panic(errors.New("bug: failed to find peer")) | ||||
| 				panic(errors.New("bug: failed to find / create peer")) | ||||
| 			} | ||||
| 
 | ||||
| 		case "replace_peers": | ||||
| 			if value == "true" { | ||||
| 				device.RemoveAllPeers() | ||||
| 			} else { | ||||
| 				logger.Println("Failed to set replace_peers, invalid value:", value) | ||||
| 				logError.Println("Failed to set replace_peers, invalid value:", value) | ||||
| 				return &IPCError{Code: ipcErrorInvalidValue} | ||||
| 			} | ||||
| 
 | ||||
| 		default: | ||||
| 			/* Peer configuration */ | ||||
| 
 | ||||
| 			/* peer configuration */ | ||||
| 
 | ||||
| 			if peer == nil { | ||||
| 				logger.Println("No peer referenced, before peer operation") | ||||
| 				logError.Println("No peer referenced, before peer operation") | ||||
| 				return &IPCError{Code: ipcErrorNoPeer} | ||||
| 			} | ||||
| 
 | ||||
|  | @ -178,7 +186,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { | |||
| 				peer.mutex.Lock() | ||||
| 				device.RemovePeer(peer.handshake.remoteStatic) | ||||
| 				peer.mutex.Unlock() | ||||
| 				logger.Println("Remove peer") | ||||
| 				logDebug.Println("Removing", peer.String()) | ||||
| 				peer = nil | ||||
| 
 | ||||
| 			case "preshared_key": | ||||
|  | @ -188,14 +196,14 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { | |||
| 					return peer.handshake.presharedKey.FromHex(value) | ||||
| 				}() | ||||
| 				if err != nil { | ||||
| 					logger.Println("Failed to set preshared_key:", err) | ||||
| 					logError.Println("Failed to set preshared_key:", err) | ||||
| 					return &IPCError{Code: ipcErrorInvalidValue} | ||||
| 				} | ||||
| 
 | ||||
| 			case "endpoint": | ||||
| 				addr, err := net.ResolveUDPAddr("udp", value) | ||||
| 				if err != nil { | ||||
| 					logger.Println("Failed to set endpoint:", value) | ||||
| 					logError.Println("Failed to set endpoint:", value) | ||||
| 					return &IPCError{Code: ipcErrorInvalidValue} | ||||
| 				} | ||||
| 				peer.mutex.Lock() | ||||
|  | @ -205,35 +213,34 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { | |||
| 			case "persistent_keepalive_interval": | ||||
| 				secs, err := strconv.ParseInt(value, 10, 64) | ||||
| 				if secs < 0 || err != nil { | ||||
| 					logger.Println("Failed to set persistent_keepalive_interval:", err) | ||||
| 					logError.Println("Failed to set persistent_keepalive_interval:", err) | ||||
| 					return &IPCError{Code: ipcErrorInvalidValue} | ||||
| 				} | ||||
| 				peer.mutex.Lock() | ||||
| 				peer.persistentKeepaliveInterval = uint64(secs) | ||||
| 				peer.mutex.Unlock() | ||||
| 				atomic.StoreUint64( | ||||
| 					&peer.persistentKeepaliveInterval, | ||||
| 					uint64(secs), | ||||
| 				) | ||||
| 
 | ||||
| 			case "replace_allowed_ips": | ||||
| 				if value == "true" { | ||||
| 					device.routingTable.RemovePeer(peer) | ||||
| 				} else { | ||||
| 					logger.Println("Failed to set replace_allowed_ips, invalid value:", value) | ||||
| 					logError.Println("Failed to set replace_allowed_ips, invalid value:", value) | ||||
| 					return &IPCError{Code: ipcErrorInvalidValue} | ||||
| 				} | ||||
| 
 | ||||
| 			case "allowed_ip": | ||||
| 				_, network, err := net.ParseCIDR(value) | ||||
| 				if err != nil { | ||||
| 					logger.Println("Failed to set allowed_ip:", err) | ||||
| 					logError.Println("Failed to set allowed_ip:", err) | ||||
| 					return &IPCError{Code: ipcErrorInvalidValue} | ||||
| 				} | ||||
| 				ones, _ := network.Mask.Size() | ||||
| 				logger.Println(network, ones, network.IP) | ||||
| 				logError.Println(network, ones, network.IP) | ||||
| 				device.routingTable.Insert(network.IP, uint(ones), peer) | ||||
| 
 | ||||
| 			/* Invalid key */ | ||||
| 
 | ||||
| 			default: | ||||
| 				logger.Println("Invalid key:", key) | ||||
| 				logError.Println("Invalid UAPI key:", key) | ||||
| 				return &IPCError{Code: ipcErrorInvalidKey} | ||||
| 			} | ||||
| 		} | ||||
|  | @ -244,46 +251,45 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { | |||
| 
 | ||||
| func ipcHandle(device *Device, socket net.Conn) { | ||||
| 
 | ||||
| 	func() { | ||||
| 		buffered := func(s io.ReadWriter) *bufio.ReadWriter { | ||||
| 			reader := bufio.NewReader(s) | ||||
| 			writer := bufio.NewWriter(s) | ||||
| 			return bufio.NewReadWriter(reader, writer) | ||||
| 		}(socket) | ||||
| 	defer socket.Close() | ||||
| 
 | ||||
| 		defer buffered.Flush() | ||||
| 	buffered := func(s io.ReadWriter) *bufio.ReadWriter { | ||||
| 		reader := bufio.NewReader(s) | ||||
| 		writer := bufio.NewWriter(s) | ||||
| 		return bufio.NewReadWriter(reader, writer) | ||||
| 	}(socket) | ||||
| 
 | ||||
| 		op, err := buffered.ReadString('\n') | ||||
| 	defer buffered.Flush() | ||||
| 
 | ||||
| 	op, err := buffered.ReadString('\n') | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	switch op { | ||||
| 
 | ||||
| 	case "set=1\n": | ||||
| 		device.log.Debug.Println("Config, set operation") | ||||
| 		err := ipcSetOperation(device, buffered) | ||||
| 		if err != nil { | ||||
| 			return | ||||
| 			fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode()) | ||||
| 		} else { | ||||
| 			fmt.Fprintf(buffered, "errno=0\n\n") | ||||
| 		} | ||||
| 		return | ||||
| 
 | ||||
| 		switch op { | ||||
| 
 | ||||
| 		case "set=1\n": | ||||
| 			device.log.Debug.Println("Config, set operation") | ||||
| 			err := ipcSetOperation(device, buffered) | ||||
| 			if err != nil { | ||||
| 				fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode()) | ||||
| 			} else { | ||||
| 				fmt.Fprintf(buffered, "errno=0\n\n") | ||||
| 			} | ||||
| 			break | ||||
| 
 | ||||
| 		case "get=1\n": | ||||
| 			device.log.Debug.Println("Config, get operation") | ||||
| 			err := ipcGetOperation(device, buffered) | ||||
| 			if err != nil { | ||||
| 				fmt.Fprintf(buffered, "errno=1\n\n") // fix
 | ||||
| 			} else { | ||||
| 				fmt.Fprintf(buffered, "errno=0\n\n") | ||||
| 			} | ||||
| 			break | ||||
| 
 | ||||
| 		default: | ||||
| 			device.log.Info.Println("Invalid UAPI operation:", op) | ||||
| 	case "get=1\n": | ||||
| 		device.log.Debug.Println("Config, get operation") | ||||
| 		err := ipcGetOperation(device, buffered) | ||||
| 		if err != nil { | ||||
| 			fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode()) | ||||
| 		} else { | ||||
| 			fmt.Fprintf(buffered, "errno=0\n\n") | ||||
| 		} | ||||
| 	}() | ||||
| 		return | ||||
| 
 | ||||
| 	socket.Close() | ||||
| 	default: | ||||
| 		device.log.Error.Println("Invalid UAPI operation:", op) | ||||
| 
 | ||||
| 	} | ||||
| } | ||||
|  |  | |||
|  | @ -78,7 +78,6 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { | |||
| 	defer device.mutex.Unlock() | ||||
| 
 | ||||
| 	device.log = NewLogger(logLevel) | ||||
| 	// device.mtu = tun.MTU()
 | ||||
| 	device.peers = make(map[NoisePublicKey]*Peer) | ||||
| 	device.indices.Init() | ||||
| 	device.ratelimiter.Init() | ||||
|  | @ -131,12 +130,21 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { | |||
| 
 | ||||
| func (device *Device) RoutineMTUUpdater(tun TUNDevice) { | ||||
| 	logError := device.log.Error | ||||
| 	for ; ; time.Sleep(time.Second) { | ||||
| 	for ; ; time.Sleep(5 * time.Second) { | ||||
| 
 | ||||
| 		// load updated MTU
 | ||||
| 
 | ||||
| 		mtu, err := tun.MTU() | ||||
| 		if err != nil { | ||||
| 			logError.Println("Failed to load updated MTU of device:", err) | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		// upper bound of mtu
 | ||||
| 
 | ||||
| 		if mtu+MessageTransportSize > MaxMessageSize { | ||||
| 			mtu = MaxMessageSize - MessageTransportSize | ||||
| 		} | ||||
| 		atomic.StoreInt32(&device.mtu, int32(mtu)) | ||||
| 	} | ||||
| } | ||||
|  |  | |||
|  | @ -6,8 +6,6 @@ import ( | |||
| ) | ||||
| 
 | ||||
| /* Index=0 is reserved for unset indecies | ||||
|  * | ||||
|  * TODO: Rethink map[id] -> peer VS map[id] -> handshake and handshake <ref> peer | ||||
|  * | ||||
|  */ | ||||
| 
 | ||||
|  | @ -72,12 +70,12 @@ func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) { | |||
| 
 | ||||
| 		table.mutex.RLock() | ||||
| 		_, ok := table.table[index] | ||||
| 		table.mutex.RUnlock() | ||||
| 		if ok { | ||||
| 			continue | ||||
| 		} | ||||
| 		table.mutex.RUnlock() | ||||
| 
 | ||||
| 		// replace index
 | ||||
| 		// map index to handshake
 | ||||
| 
 | ||||
| 		table.mutex.Lock() | ||||
| 		_, found := table.table[index] | ||||
|  |  | |||
							
								
								
									
										20
									
								
								src/main.go
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								src/main.go
									
									
									
									
									
								
							|  | @ -17,12 +17,14 @@ func main() { | |||
| 	} | ||||
| 
 | ||||
| 	switch os.Args[1] { | ||||
| 
 | ||||
| 	case "-f", "--foreground": | ||||
| 		foreground = true | ||||
| 		if len(os.Args) != 3 { | ||||
| 			return | ||||
| 		} | ||||
| 		interfaceName = os.Args[2] | ||||
| 
 | ||||
| 	default: | ||||
| 		foreground = false | ||||
| 		if len(os.Args) != 2 { | ||||
|  | @ -48,8 +50,8 @@ func main() { | |||
| 	// open TUN device
 | ||||
| 
 | ||||
| 	tun, err := CreateTUN(interfaceName) | ||||
| 	log.Println(tun, err) | ||||
| 	if err != nil { | ||||
| 		log.Println("Failed to create tun device:", err) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
|  | @ -69,11 +71,15 @@ func main() { | |||
| 	} | ||||
| 	defer uapi.Close() | ||||
| 
 | ||||
| 	for { | ||||
| 		conn, err := uapi.Accept() | ||||
| 		if err != nil { | ||||
| 			logError.Fatal("accept error:", err) | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			conn, err := uapi.Accept() | ||||
| 			if err != nil { | ||||
| 				logError.Fatal("UAPI accept error:", err) | ||||
| 			} | ||||
| 			go ipcHandle(device, conn) | ||||
| 		} | ||||
| 		go ipcHandle(device, conn) | ||||
| 	} | ||||
| 	}() | ||||
| 
 | ||||
| 	device.Wait() | ||||
| } | ||||
|  |  | |||
|  | @ -459,7 +459,8 @@ func (peer *Peer) NewKeyPair() *KeyPair { | |||
| 
 | ||||
| 	// remap index
 | ||||
| 
 | ||||
| 	peer.device.indices.Insert(handshake.localIndex, IndexTableEntry{ | ||||
| 	indices := &peer.device.indices | ||||
| 	indices.Insert(handshake.localIndex, IndexTableEntry{ | ||||
| 		peer:      peer, | ||||
| 		keyPair:   keyPair, | ||||
| 		handshake: nil, | ||||
|  | @ -476,7 +477,7 @@ func (peer *Peer) NewKeyPair() *KeyPair { | |||
| 			if kp.previous != nil { | ||||
| 				kp.previous.send = nil | ||||
| 				kp.previous.receive = nil | ||||
| 				peer.device.indices.Delete(kp.previous.localIndex) | ||||
| 				indices.Delete(kp.previous.localIndex) | ||||
| 			} | ||||
| 			kp.previous = kp.current | ||||
| 			kp.current = keyPair | ||||
|  |  | |||
|  | @ -212,18 +212,18 @@ func (device *Device) RoutineReceiveIncomming() { | |||
| 				// add to peer queue
 | ||||
| 
 | ||||
| 				peer := value.peer | ||||
| 				work := &QueueInboundElement{ | ||||
| 				elem := &QueueInboundElement{ | ||||
| 					packet:  packet, | ||||
| 					buffer:  buffer, | ||||
| 					keyPair: keyPair, | ||||
| 					dropped: AtomicFalse, | ||||
| 				} | ||||
| 				work.mutex.Lock() | ||||
| 				elem.mutex.Lock() | ||||
| 
 | ||||
| 				// add to decryption queues
 | ||||
| 
 | ||||
| 				device.addToInboundQueue(device.queue.decryption, work) | ||||
| 				device.addToInboundQueue(peer.queue.inbound, work) | ||||
| 				device.addToInboundQueue(device.queue.decryption, elem) | ||||
| 				device.addToInboundQueue(peer.queue.inbound, elem) | ||||
| 				buffer = nil | ||||
| 
 | ||||
| 			default: | ||||
|  |  | |||
							
								
								
									
										81
									
								
								src/send.go
									
									
									
									
									
								
							
							
						
						
									
										81
									
								
								src/send.go
									
									
									
									
									
								
							|  | @ -270,50 +270,65 @@ func (peer *Peer) RoutineNonce() { | |||
|  * Obs. One instance per core | ||||
|  */ | ||||
| func (device *Device) RoutineEncryption() { | ||||
| 
 | ||||
| 	var elem *QueueOutboundElement | ||||
| 	var nonce [chacha20poly1305.NonceSize]byte | ||||
| 	for work := range device.queue.encryption { | ||||
| 
 | ||||
| 	logDebug := device.log.Debug | ||||
| 	logDebug.Println("Routine, encryption worker, started") | ||||
| 
 | ||||
| 	for { | ||||
| 
 | ||||
| 		// fetch next element
 | ||||
| 
 | ||||
| 		select { | ||||
| 		case elem = <-device.queue.encryption: | ||||
| 		case <-device.signal.stop: | ||||
| 			logDebug.Println("Routine, encryption worker, stopped") | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		// check if dropped
 | ||||
| 
 | ||||
| 		if work.IsDropped() { | ||||
| 		if elem.IsDropped() { | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		// populate header fields
 | ||||
| 
 | ||||
| 		header := work.buffer[:MessageTransportHeaderSize] | ||||
| 		header := elem.buffer[:MessageTransportHeaderSize] | ||||
| 
 | ||||
| 		fieldType := header[0:4] | ||||
| 		fieldReceiver := header[4:8] | ||||
| 		fieldNonce := header[8:16] | ||||
| 
 | ||||
| 		binary.LittleEndian.PutUint32(fieldType, MessageTransportType) | ||||
| 		binary.LittleEndian.PutUint32(fieldReceiver, work.keyPair.remoteIndex) | ||||
| 		binary.LittleEndian.PutUint64(fieldNonce, work.nonce) | ||||
| 		binary.LittleEndian.PutUint32(fieldReceiver, elem.keyPair.remoteIndex) | ||||
| 		binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) | ||||
| 
 | ||||
| 		// pad content to MTU size
 | ||||
| 
 | ||||
| 		mtu := int(atomic.LoadInt32(&device.mtu)) | ||||
| 		for i := len(work.packet); i < mtu; i++ { | ||||
| 			work.packet = append(work.packet, 0) | ||||
| 		for i := len(elem.packet); i < mtu; i++ { | ||||
| 			elem.packet = append(elem.packet, 0) | ||||
| 		} | ||||
| 
 | ||||
| 		// encrypt content
 | ||||
| 
 | ||||
| 		binary.LittleEndian.PutUint64(nonce[4:], work.nonce) | ||||
| 		work.packet = work.keyPair.send.Seal( | ||||
| 			work.packet[:0], | ||||
| 		binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) | ||||
| 		elem.packet = elem.keyPair.send.Seal( | ||||
| 			elem.packet[:0], | ||||
| 			nonce[:], | ||||
| 			work.packet, | ||||
| 			elem.packet, | ||||
| 			nil, | ||||
| 		) | ||||
| 		length := MessageTransportHeaderSize + len(work.packet) | ||||
| 		work.packet = work.buffer[:length] | ||||
| 		work.mutex.Unlock() | ||||
| 		length := MessageTransportHeaderSize + len(elem.packet) | ||||
| 		elem.packet = elem.buffer[:length] | ||||
| 		elem.mutex.Unlock() | ||||
| 
 | ||||
| 		// refresh key if necessary
 | ||||
| 
 | ||||
| 		work.peer.KeepKeyFreshSending() | ||||
| 		elem.peer.KeepKeyFreshSending() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | @ -334,49 +349,43 @@ func (peer *Peer) RoutineSequentialSender() { | |||
| 			logDebug.Println("Routine, sequential sender, stopped for", peer.String()) | ||||
| 			return | ||||
| 
 | ||||
| 		case work := <-peer.queue.outbound: | ||||
| 			work.mutex.Lock() | ||||
| 		case elem := <-peer.queue.outbound: | ||||
| 			elem.mutex.Lock() | ||||
| 
 | ||||
| 			func() { | ||||
| 
 | ||||
| 				// return buffer to pool after processing
 | ||||
| 
 | ||||
| 				defer device.PutMessageBuffer(work.buffer) | ||||
| 				if work.IsDropped() { | ||||
| 				if elem.IsDropped() { | ||||
| 					return | ||||
| 				} | ||||
| 
 | ||||
| 				// send to endpoint
 | ||||
| 				// get endpoint and connection
 | ||||
| 
 | ||||
| 				peer.mutex.RLock() | ||||
| 				defer peer.mutex.RUnlock() | ||||
| 
 | ||||
| 				if peer.endpoint == nil { | ||||
| 				endpoint := peer.endpoint | ||||
| 				peer.mutex.RUnlock() | ||||
| 				if endpoint == nil { | ||||
| 					logDebug.Println("No endpoint for", peer.String()) | ||||
| 					return | ||||
| 				} | ||||
| 
 | ||||
| 				device.net.mutex.RLock() | ||||
| 				defer device.net.mutex.RUnlock() | ||||
| 
 | ||||
| 				if device.net.conn == nil { | ||||
| 				conn := device.net.conn | ||||
| 				device.net.mutex.RUnlock() | ||||
| 				if conn == nil { | ||||
| 					logDebug.Println("No source for device") | ||||
| 					return | ||||
| 				} | ||||
| 
 | ||||
| 				// send message and return buffer to pool
 | ||||
| 				// send message and refresh keys
 | ||||
| 
 | ||||
| 				_, err := device.net.conn.WriteToUDP(work.packet, peer.endpoint) | ||||
| 				_, err := conn.WriteToUDP(elem.packet, endpoint) | ||||
| 				if err != nil { | ||||
| 					return | ||||
| 				} | ||||
| 
 | ||||
| 				atomic.AddUint64(&peer.txBytes, uint64(len(work.packet))) | ||||
| 
 | ||||
| 				// reset keep-alive
 | ||||
| 
 | ||||
| 				atomic.AddUint64(&peer.txBytes, uint64(len(elem.packet))) | ||||
| 				peer.TimerResetKeepalive() | ||||
| 			}() | ||||
| 
 | ||||
| 			device.PutMessageBuffer(elem.buffer) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  |  | |||
|  | @ -138,6 +138,7 @@ func (peer *Peer) BeginHandshakeInitiation() (*QueueOutboundElement, error) { | |||
| 
 | ||||
| func (peer *Peer) RoutineTimerHandler() { | ||||
| 	device := peer.device | ||||
| 	indices := &device.indices | ||||
| 
 | ||||
| 	logDebug := device.log.Debug | ||||
| 	logDebug.Println("Routine, timer handler, started for peer", peer.String()) | ||||
|  | @ -170,29 +171,42 @@ func (peer *Peer) RoutineTimerHandler() { | |||
| 
 | ||||
| 			logDebug.Println("Clearing all key material for", peer.String()) | ||||
| 
 | ||||
| 			// zero out key pairs
 | ||||
| 			kp := &peer.keyPairs | ||||
| 			kp.mutex.Lock() | ||||
| 
 | ||||
| 			func() { | ||||
| 				kp := &peer.keyPairs | ||||
| 				kp.mutex.Lock() | ||||
| 				// best we can do is wait for GC :( ?
 | ||||
| 				kp.current = nil | ||||
| 				kp.previous = nil | ||||
| 				kp.next = nil | ||||
| 				kp.mutex.Unlock() | ||||
| 			}() | ||||
| 			hs := &peer.handshake | ||||
| 			hs.mutex.Lock() | ||||
| 
 | ||||
| 			// unmap local indecies
 | ||||
| 
 | ||||
| 			indices.mutex.Lock() | ||||
| 			if kp.previous != nil { | ||||
| 				delete(indices.table, kp.previous.localIndex) | ||||
| 			} | ||||
| 			if kp.current != nil { | ||||
| 				delete(indices.table, kp.current.localIndex) | ||||
| 			} | ||||
| 			if kp.next != nil { | ||||
| 				delete(indices.table, kp.next.localIndex) | ||||
| 			} | ||||
| 			delete(indices.table, hs.localIndex) | ||||
| 			indices.mutex.Unlock() | ||||
| 
 | ||||
| 			// zero out key pairs (TODO: better than wait for GC)
 | ||||
| 
 | ||||
| 			kp.current = nil | ||||
| 			kp.previous = nil | ||||
| 			kp.next = nil | ||||
| 			kp.mutex.Unlock() | ||||
| 
 | ||||
| 			// zero out handshake
 | ||||
| 
 | ||||
| 			func() { | ||||
| 				hs := &peer.handshake | ||||
| 				hs.mutex.Lock() | ||||
| 				hs.localEphemeral = NoisePrivateKey{} | ||||
| 				hs.remoteEphemeral = NoisePublicKey{} | ||||
| 				hs.chainKey = [blake2s.Size]byte{} | ||||
| 				hs.hash = [blake2s.Size]byte{} | ||||
| 				hs.mutex.Unlock() | ||||
| 			}() | ||||
| 			hs.localIndex = 0 | ||||
| 			hs.localEphemeral = NoisePrivateKey{} | ||||
| 			hs.remoteEphemeral = NoisePublicKey{} | ||||
| 			hs.chainKey = [blake2s.Size]byte{} | ||||
| 			hs.hash = [blake2s.Size]byte{} | ||||
| 			hs.mutex.Unlock() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue