diff --git a/src/constants.go b/src/constants.go index 2e484f3..053ba4f 100644 --- a/src/constants.go +++ b/src/constants.go @@ -21,5 +21,6 @@ const ( QueueInboundSize = 1024 QueueHandshakeSize = 1024 QueueHandshakeBusySize = QueueHandshakeSize / 8 - MinMessageSize = MessageTransportSize + MinMessageSize = MessageTransportSize // keep-alive + MaxMessageSize = 4096 ) diff --git a/src/device.go b/src/device.go index ff10e32..a317122 100644 --- a/src/device.go +++ b/src/device.go @@ -80,6 +80,7 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize) device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize) + device.queue.inbound = make(chan []byte, QueueInboundSize) // prepare signals @@ -94,6 +95,7 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { } go device.RoutineReadFromTUN(tun) go device.RoutineReceiveIncomming() + go device.RoutineWriteToTUN(tun) return device } diff --git a/src/handshake.go b/src/handshake.go index cf73e9b..88bb8cb 100644 --- a/src/handshake.go +++ b/src/handshake.go @@ -12,9 +12,11 @@ import ( * Used by initiator of handshake and with active keep-alive */ func (peer *Peer) SendKeepAlive() bool { + elem := peer.device.NewOutboundElement() + elem.packet = nil if len(peer.queue.nonce) == 0 { select { - case peer.queue.nonce <- []byte{}: + case peer.queue.nonce <- elem: return true default: return false @@ -60,11 +62,10 @@ func (peer *Peer) KeepKeyFreshSending() { */ func (peer *Peer) RoutineHandshakeInitiator() { device := peer.device - buffer := make([]byte, 1024) logger := device.log.Debug timeout := stoppedTimer() - var work *QueueOutboundElement + var elem *QueueOutboundElement logger.Println("Routine, handshake initator, started for peer", peer.id) @@ -94,25 +95,25 @@ func (peer *Peer) RoutineHandshakeInitiator() { // create initiation - if work != nil { - work.mutex.Lock() - work.packet = nil - work.mutex.Unlock() + if elem != nil { + elem.Drop() } - work = new(QueueOutboundElement) + elem = device.NewOutboundElement() + msg, err := device.CreateMessageInitiation(peer) if err != nil { device.log.Error.Println("Failed to create initiation message:", err) break } - // schedule for sending + // marshal & schedule for sending - writer := bytes.NewBuffer(buffer[:0]) + writer := bytes.NewBuffer(elem.data[:0]) binary.Write(writer, binary.LittleEndian, msg) - work.packet = writer.Bytes() - peer.mac.AddMacs(work.packet) - peer.InsertOutbound(work) + elem.packet = writer.Bytes() + peer.mac.AddMacs(elem.packet) + println(elem) + addToOutboundQueue(peer.queue.outbound, elem) if attempts == 0 { deadline = time.Now().Add(MaxHandshakeAttemptTime) @@ -132,9 +133,11 @@ func (peer *Peer) RoutineHandshakeInitiator() { return case <-peer.signal.handshakeCompleted: + device.log.Debug.Println("Handshake complete") break HandshakeLoop case <-timeout.C: + device.log.Debug.Println("Timeout") if deadline.Before(time.Now().Add(RekeyTimeout)) { peer.signal.flushNonceQueue <- struct{}{} if !peer.timer.sendKeepalive.Stop() { diff --git a/src/keypair.go b/src/keypair.go index 0fac5cb..3caa0c8 100644 --- a/src/keypair.go +++ b/src/keypair.go @@ -7,8 +7,7 @@ import ( ) type KeyPair struct { - recv cipher.AEAD - recvNonce uint64 + receive cipher.AEAD send cipher.AEAD sendNonce uint64 isInitiator bool diff --git a/src/noise_protocol.go b/src/noise_protocol.go index 5a62901..0258288 100644 --- a/src/noise_protocol.go +++ b/src/noise_protocol.go @@ -446,10 +446,10 @@ func (peer *Peer) NewKeyPair() *KeyPair { keyPair := new(KeyPair) keyPair.send, _ = chacha20poly1305.New(sendKey[:]) - keyPair.recv, _ = chacha20poly1305.New(recvKey[:]) + keyPair.receive, _ = chacha20poly1305.New(recvKey[:]) keyPair.sendNonce = 0 - keyPair.recvNonce = 0 keyPair.created = time.Now() + keyPair.isInitiator = isInitiator keyPair.localIndex = peer.handshake.localIndex keyPair.remoteIndex = peer.handshake.remoteIndex @@ -462,7 +462,7 @@ func (peer *Peer) NewKeyPair() *KeyPair { }) handshake.localIndex = 0 - // start timer for keypair + // TODO: start timer for keypair (clearing) // rotate key pairs @@ -473,7 +473,7 @@ func (peer *Peer) NewKeyPair() *KeyPair { if isInitiator { if kp.previous != nil { kp.previous.send = nil - kp.previous.recv = nil + kp.previous.receive = nil peer.device.indices.Delete(kp.previous.localIndex) } kp.previous = kp.current diff --git a/src/peer.go b/src/peer.go index 1c40598..77bd4df 100644 --- a/src/peer.go +++ b/src/peer.go @@ -35,7 +35,7 @@ type Peer struct { handshakeTimeout *time.Timer } queue struct { - nonce chan []byte // nonce / pre-handshake queue + nonce chan *QueueOutboundElement // nonce / pre-handshake queue outbound chan *QueueOutboundElement // sequential ordering of work inbound chan *QueueInboundElement // sequential ordering of work } @@ -78,9 +78,9 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer { // prepare queuing - peer.queue.nonce = make(chan []byte, QueueOutboundSize) - peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) + peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize) peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) + peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) // prepare signaling diff --git a/src/receive.go b/src/receive.go index 5afbf7f..50789a1 100644 --- a/src/receive.go +++ b/src/receive.go @@ -31,17 +31,39 @@ type QueueInboundElement struct { func (elem *QueueInboundElement) Drop() { atomic.StoreUint32(&elem.state, ElementStateDropped) - elem.mutex.Unlock() +} + +func (elem *QueueInboundElement) IsDropped() bool { + return atomic.LoadUint32(&elem.state) == ElementStateDropped +} + +func addToInboundQueue( + queue chan *QueueInboundElement, + element *QueueInboundElement, +) { + for { + select { + case queue <- element: + return + default: + select { + case old := <-queue: + old.Drop() + default: + } + } + } } func (device *Device) RoutineReceiveIncomming() { - var packet []byte debugLog := device.log.Debug debugLog.Println("Routine, receive incomming, started") errorLog := device.log.Error + var buffer []byte // unsliced buffer + for { // check if stopped @@ -54,28 +76,28 @@ func (device *Device) RoutineReceiveIncomming() { // read next datagram - if packet == nil { - packet = make([]byte, 1<<16) + if buffer == nil { + buffer = make([]byte, MaxMessageSize) } device.net.mutex.RLock() conn := device.net.conn device.net.mutex.RUnlock() + if conn == nil { + time.Sleep(time.Second) + continue + } conn.SetReadDeadline(time.Now().Add(time.Second)) - size, raddr, err := conn.ReadFromUDP(packet) - if err != nil { - continue - } - if size < MinMessageSize { + size, raddr, err := conn.ReadFromUDP(buffer) + if err != nil || size < MinMessageSize { continue } // handle packet - packet = packet[:size] - debugLog.Println("GOT:", packet) + packet := buffer[:size] msgType := binary.LittleEndian.Uint32(packet[:4]) func() { @@ -112,6 +134,7 @@ func (device *Device) RoutineReceiveIncomming() { // add to handshake queue + buffer = nil device.queue.handshake <- QueueHandshakeElement{ msgType: msgType, packet: packet, @@ -137,8 +160,6 @@ func (device *Device) RoutineReceiveIncomming() { case MessageTransportType: - debugLog.Println("DEBUG: Got transport") - // lookup key pair if len(packet) < MessageTransportSize { @@ -169,42 +190,15 @@ func (device *Device) RoutineReceiveIncomming() { work.state = ElementStateOkay work.mutex.Lock() - // add to parallel decryption queue + // add to decryption queues - func() { - for { - select { - case device.queue.decryption <- work: - return - default: - select { - case elem := <-device.queue.decryption: - elem.Drop() - default: - } - } - } - }() - - // add to sequential inbound queue - - func() { - for { - select { - case peer.queue.inbound <- work: - break - default: - select { - case elem := <-peer.queue.inbound: - elem.Drop() - default: - } - } - } - }() + addToInboundQueue(device.queue.decryption, work) + addToInboundQueue(peer.queue.inbound, work) + buffer = nil default: // unknown message type + debugLog.Println("Got unknown message from:", raddr) } }() } @@ -214,6 +208,9 @@ func (device *Device) RoutineDecryption() { var elem *QueueInboundElement var nonce [chacha20poly1305.NonceSize]byte + logDebug := device.log.Debug + logDebug.Println("Routine, decryption, started for device") + for { select { case elem = <-device.queue.decryption: @@ -223,31 +220,25 @@ func (device *Device) RoutineDecryption() { // check if dropped - state := atomic.LoadUint32(&elem.state) - if state != ElementStateOkay { + if elem.IsDropped() { + elem.mutex.Unlock() continue } // split message into fields - counter := binary.LittleEndian.Uint64( - elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent], - ) + counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] content := elem.packet[MessageTransportOffsetContent:] // decrypt with key-pair var err error - binary.LittleEndian.PutUint64(nonce[4:], counter) - elem.packet, err = elem.keyPair.recv.Open(elem.packet[:0], nonce[:], content, nil) + copy(nonce[4:], counter) + elem.counter = binary.LittleEndian.Uint64(counter) + elem.packet, err = elem.keyPair.receive.Open(elem.packet[:0], nonce[:], content, nil) if err != nil { elem.Drop() - continue } - - // release to consumer - - elem.counter = counter elem.mutex.Unlock() } } @@ -261,6 +252,7 @@ func (device *Device) RoutineHandshake() { logInfo := device.log.Info logError := device.log.Error logDebug := device.log.Debug + logDebug.Println("Routine, handshake routine, started for device") var elem QueueHandshakeElement @@ -332,13 +324,15 @@ func (device *Device) RoutineHandshake() { } sendSignal(peer.signal.handshakeCompleted) logDebug.Println("Recieved valid response message for peer", peer.id) - peer.NewKeyPair() + kp := peer.NewKeyPair() + if kp == nil { + logDebug.Println("Failed to derieve key-pair") + } peer.SendKeepAlive() default: device.log.Error.Println("Invalid message type in handshake queue") } - }() } } @@ -348,7 +342,6 @@ func (peer *Peer) RoutineSequentialReceiver() { device := peer.device logDebug := device.log.Debug - logDebug.Println("Routine, sequential receiver, started for peer", peer.id) for { @@ -359,20 +352,15 @@ func (peer *Peer) RoutineSequentialReceiver() { return case elem = <-peer.queue.inbound: } + elem.mutex.Lock() - - // check if dropped - - logDebug.Println("MESSSAGE:", elem) - - state := atomic.LoadUint32(&elem.state) - if state != ElementStateOkay { + if elem.IsDropped() { continue } // check for replay - // strip padding + // update timers // check for keep-alive @@ -380,26 +368,30 @@ func (peer *Peer) RoutineSequentialReceiver() { continue } + // strip padding + // insert into inbound TUN queue device.queue.inbound <- elem.packet - } + // update key material + } } func (device *Device) RoutineWriteToTUN(tun TUNDevice) { - for { - var packet []byte + logError := device.log.Error + logDebug := device.log.Debug + logDebug.Println("Routine, sequential tun writer, started") + for { select { case <-device.signal.stop: - case packet = <-device.queue.inbound: - } - - size, err := tun.Write(packet) - device.log.Debug.Println("DEBUG:", size, err) - if err != nil { - + return + case packet := <-device.queue.inbound: + _, err := tun.Write(packet) + if err != nil { + logError.Println("Failed to write packet to TUN device:", err) + } } } } diff --git a/src/send.go b/src/send.go index 3fe4733..4053669 100644 --- a/src/send.go +++ b/src/send.go @@ -25,14 +25,19 @@ import ( * * The sequential consumers will attempt to take the lock, * workers release lock when they have completed work on the packet. + * + * If the element is inserted into the "encryption queue", + * the content is preceeded by enough "junk" to contain the header + * (to allow the constuction of transport messages in-place) */ type QueueOutboundElement struct { state uint32 mutex sync.Mutex - packet []byte - nonce uint64 - keyPair *KeyPair - peer *Peer + data [MaxMessageSize]byte + packet []byte // slice of packet (sending) + nonce uint64 // nonce for encryption + keyPair *KeyPair // key-pair for encryption + peer *Peer // related peer } func (peer *Peer) FlushNonceQueue() { @@ -46,18 +51,9 @@ func (peer *Peer) FlushNonceQueue() { } } -func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) { - for { - select { - case peer.queue.outbound <- elem: - return - default: - select { - case <-peer.queue.outbound: - default: - } - } - } +func (device *Device) NewOutboundElement() *QueueOutboundElement { + elem := new(QueueOutboundElement) // TODO: profile, consider sync.Pool + return elem } func (elem *QueueOutboundElement) Drop() { @@ -68,53 +64,74 @@ func (elem *QueueOutboundElement) IsDropped() bool { return atomic.LoadUint32(&elem.state) == ElementStateDropped } +func addToOutboundQueue( + queue chan *QueueOutboundElement, + element *QueueOutboundElement, +) { + for { + select { + case queue <- element: + return + default: + select { + case old := <-queue: + old.Drop() + default: + } + } + } +} + /* Reads packets from the TUN and inserts * into nonce queue for peer * * Obs. Single instance per TUN device */ func (device *Device) RoutineReadFromTUN(tun TUNDevice) { - if tun.MTU() == 0 { - // Dummy + if tun == nil { + // dummy return } + elem := device.NewOutboundElement() + device.log.Debug.Println("Routine, TUN Reader: started") for { // read packet - packet := make([]byte, 1<<16) // TODO: Fix & avoid dynamic allocation - size, err := tun.Read(packet) + if elem == nil { + elem = device.NewOutboundElement() + } + + elem.packet = elem.data[MessageTransportHeaderSize:] + size, err := tun.Read(elem.packet) if err != nil { device.log.Error.Println("Failed to read packet from TUN device:", err) continue } - packet = packet[:size] - if len(packet) < IPv4headerSize { - device.log.Error.Println("Packet too short, length:", len(packet)) + elem.packet = elem.packet[:size] + if len(elem.packet) < IPv4headerSize { + device.log.Error.Println("Packet too short, length:", size) continue } // lookup peer var peer *Peer - switch packet[0] >> 4 { + switch elem.packet[0] >> 4 { case IPv4version: - dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] + dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] peer = device.routingTable.LookupIPv4(dst) - device.log.Debug.Println("New IPv4 packet:", packet, dst) case IPv6version: - dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] + dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] peer = device.routingTable.LookupIPv6(dst) - device.log.Debug.Println("New IPv6 packet:", packet, dst) default: device.log.Debug.Println("Receieved packet with unknown IP version") } if peer == nil { - device.log.Debug.Println("No peer configured for IP") continue } if peer.endpoint == nil { @@ -124,18 +141,9 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) { // insert into nonce/pre-handshake queue - for { - select { - case peer.queue.nonce <- packet: - default: - select { - case <-peer.queue.nonce: - default: - } - continue - } - break - } + addToOutboundQueue(peer.queue.nonce, elem) + elem = nil + } } @@ -148,8 +156,8 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) { * Obs. A single instance per peer */ func (peer *Peer) RoutineNonce() { - var packet []byte var keyPair *KeyPair + var elem *QueueOutboundElement device := peer.device logger := device.log.Debug @@ -163,9 +171,9 @@ func (peer *Peer) RoutineNonce() { // wait for packet - if packet == nil { + if elem == nil { select { - case packet = <-peer.queue.nonce: + case elem = <-peer.queue.nonce: case <-peer.signal.stop: return } @@ -198,7 +206,7 @@ func (peer *Peer) RoutineNonce() { case <-peer.signal.flushNonceQueue: logger.Println("Clearing queue for peer", peer.id) peer.FlushNonceQueue() - packet = nil + elem = nil goto NextPacket case <-peer.signal.stop: @@ -208,36 +216,20 @@ func (peer *Peer) RoutineNonce() { // process current packet - if packet != nil { + if elem != nil { // create work element - work := new(QueueOutboundElement) // TODO: profile, maybe use pool - work.keyPair = keyPair - work.packet = packet - work.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1 - work.peer = peer - work.mutex.Lock() + elem.keyPair = keyPair + elem.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1 + elem.peer = peer + elem.mutex.Lock() - packet = nil + // add to parallel processing and sequential consuming queue - // drop packets until there is space - - func() { - for { - select { - case peer.device.queue.encryption <- work: - return - default: - select { - case elem := <-peer.device.queue.encryption: - elem.Drop() - default: - } - } - } - }() - peer.queue.outbound <- work + addToOutboundQueue(device.queue.encryption, elem) + addToOutboundQueue(peer.queue.outbound, elem) + elem = nil } } }() @@ -257,42 +249,38 @@ func (device *Device) RoutineEncryption() { continue } - // pad packet + // populate header fields - padding := device.mtu - len(work.packet) - MessageTransportSize - if padding < 0 { - work.Drop() - continue - } + func() { + header := work.data[:MessageTransportHeaderSize] - for n := 0; n < padding; n += 1 { - work.packet = append(work.packet, 0) - } - content := work.packet[MessageTransportHeaderSize:] - copy(content, work.packet) + fieldType := header[0:4] + fieldReceiver := header[4:8] + fieldNonce := header[8:16] - // prepare header - - binary.LittleEndian.PutUint32(work.packet[:4], MessageTransportType) - binary.LittleEndian.PutUint32(work.packet[4:8], work.keyPair.remoteIndex) - binary.LittleEndian.PutUint64(work.packet[8:16], work.nonce) - - device.log.Debug.Println(work.packet, work.nonce) + binary.LittleEndian.PutUint32(fieldType, MessageTransportType) + binary.LittleEndian.PutUint32(fieldReceiver, work.keyPair.remoteIndex) + binary.LittleEndian.PutUint64(fieldNonce, work.nonce) + }() // encrypt content - binary.LittleEndian.PutUint64(nonce[4:], work.nonce) - work.keyPair.send.Seal( - content[:0], - nonce[:], - content, - nil, - ) - work.mutex.Unlock() + func() { + binary.LittleEndian.PutUint64(nonce[4:], work.nonce) + work.packet = work.keyPair.send.Seal( + work.packet[:0], + nonce[:], + work.packet, + nil, + ) + work.mutex.Unlock() + }() - device.log.Debug.Println(work.packet, work.nonce) + // reslice to include header - // initiate new handshake + work.packet = work.data[:MessageTransportHeaderSize+len(work.packet)] + + // refresh key if necessary work.peer.KeepKeyFreshSending() } @@ -340,8 +328,6 @@ func (peer *Peer) RoutineSequentialSender() { return } - logger.Println(work.packet) - _, err := device.net.conn.WriteToUDP(work.packet, peer.endpoint) if err != nil { return