From 1e620427bd01b1e897c57752359f7dbb28e34bff Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sat, 1 Jul 2017 23:29:22 +0200 Subject: [PATCH] Handshake negotiation functioning --- src/constants.go | 6 +- src/device.go | 19 +- src/handshake.go | 60 +------ src/misc.go | 19 ++ src/noise_protocol.go | 21 ++- src/peer.go | 15 +- src/receive.go | 404 ++++++++++++++++++++++++++++++++++++++++++ src/send.go | 44 +++-- src/tun.go | 2 +- src/tun_linux.go | 4 +- 10 files changed, 512 insertions(+), 82 deletions(-) create mode 100644 src/receive.go diff --git a/src/constants.go b/src/constants.go index 34217d2..2e484f3 100644 --- a/src/constants.go +++ b/src/constants.go @@ -17,5 +17,9 @@ const ( ) const ( - QueueOutboundSize = 1024 + QueueOutboundSize = 1024 + QueueInboundSize = 1024 + QueueHandshakeSize = 1024 + QueueHandshakeBusySize = QueueHandshakeSize / 8 + MinMessageSize = MessageTransportSize ) diff --git a/src/device.go b/src/device.go index a33e923..ff10e32 100644 --- a/src/device.go +++ b/src/device.go @@ -23,7 +23,13 @@ type Device struct { routingTable RoutingTable indices IndexTable queue struct { - encryption chan *QueueOutboundElement // parallel work queue + encryption chan *QueueOutboundElement + decryption chan *QueueInboundElement + handshake chan QueueHandshakeElement + inbound chan []byte // inbound queue for TUN + } + signal struct { + stop chan struct{} } peers map[NoisePublicKey]*Peer mac MacStateDevice @@ -56,6 +62,7 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { defer device.mutex.Unlock() device.log = NewLogger(logLevel) + device.mtu = tun.MTU() device.peers = make(map[NoisePublicKey]*Peer) device.indices.Init() device.routingTable.Reset() @@ -71,13 +78,22 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { // create queues device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize) + device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) + device.queue.decryption = make(chan 
*QueueInboundElement, QueueInboundSize) + + // prepare signals + + device.signal.stop = make(chan struct{}) // start workers for i := 0; i < runtime.NumCPU(); i += 1 { go device.RoutineEncryption() + go device.RoutineDecryption() + go device.RoutineHandshake() } go device.RoutineReadFromTUN(tun) + go device.RoutineReceiveIncomming() return device } @@ -115,5 +131,6 @@ func (device *Device) RemoveAllPeers() { func (device *Device) Close() { device.RemoveAllPeers() + close(device.signal.stop) close(device.queue.encryption) } diff --git a/src/handshake.go b/src/handshake.go index 806d213..cf73e9b 100644 --- a/src/handshake.go +++ b/src/handshake.go @@ -3,7 +3,6 @@ package main import ( "bytes" "encoding/binary" - "net" "sync/atomic" "time" ) @@ -24,14 +23,6 @@ func (peer *Peer) SendKeepAlive() bool { return true } -func StoppedTimer() *time.Timer { - timer := time.NewTimer(time.Hour) - if !timer.Stop() { - <-timer.C - } - return timer -} - /* Called when a new authenticated message has been send * * TODO: This might be done in a faster way @@ -71,7 +62,7 @@ func (peer *Peer) RoutineHandshakeInitiator() { device := peer.device buffer := make([]byte, 1024) logger := device.log.Debug - timeout := time.NewTimer(time.Hour) + timeout := stoppedTimer() var work *QueueOutboundElement @@ -129,13 +120,8 @@ func (peer *Peer) RoutineHandshakeInitiator() { // set timeout - if !timeout.Stop() { - select { - case <-timeout.C: - default: - } - } attempts += 1 + stopTimer(timeout) timeout.Reset(RekeyTimeout) device.log.Debug.Println("Handshake initiation attempt", attempts, "queued for peer", peer.id) @@ -163,45 +149,3 @@ func (peer *Peer) RoutineHandshakeInitiator() { logger.Println("Routine, handshake initator, stopped for peer", peer.id) } - -/* Handles incomming packets related to handshake - * - * - */ -func (device *Device) HandshakeWorker(queue chan struct { - msg []byte - msgType uint32 - addr *net.UDPAddr -}) { - for { - elem := <-queue - - switch elem.msgType { - case 
// stopTimer halts timer and leaves its channel empty, so that a later Reset
// cannot be satisfied by an expiry that fired before the stop. The drain is
// non-blocking: when the expiry was already consumed, nothing is read.
func stopTimer(timer *time.Timer) {
	if timer.Stop() {
		// stopped before firing, nothing was ever sent on the channel
		return
	}
	select {
	case <-timer.C:
	default:
	}
}

// stoppedTimer returns a timer that is not running and whose channel is
// empty; callers arm it with Reset when they need it.
func stoppedTimer() *time.Timer {
	idle := time.NewTimer(time.Hour)
	stopTimer(idle)
	return idle
}
-292,7 +300,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error var msg MessageResponse msg.Type = MessageResponseType msg.Sender = handshake.localIndex - msg.Reciever = handshake.remoteIndex + msg.Receiver = handshake.remoteIndex // create ephemeral key @@ -302,6 +310,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error } msg.Ephemeral = handshake.localEphemeral.publicKey() handshake.mixHash(msg.Ephemeral[:]) + handshake.mixKey(msg.Ephemeral[:]) func() { ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) @@ -334,7 +343,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { // lookup handshake by reciever - lookup := device.indices.Lookup(msg.Reciever) + lookup := device.indices.Lookup(msg.Receiver) handshake := lookup.handshake if handshake == nil { return nil @@ -359,7 +368,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { // finish 3-way DH hash = mixHash(handshake.hash, msg.Ephemeral[:]) - chainKey = handshake.chainKey + chainKey = mixKey(handshake.chainKey, msg.Ephemeral[:]) func() { ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) @@ -380,6 +389,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { aead, _ := chacha20poly1305.New(key[:]) _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) if err != nil { + device.log.Debug.Println("failed to open") return false } hash = mixHash(hash, msg.Empty[:]) @@ -438,6 +448,7 @@ func (peer *Peer) NewKeyPair() *KeyPair { keyPair.recv, _ = chacha20poly1305.New(recvKey[:]) keyPair.sendNonce = 0 keyPair.recvNonce = 0 + keyPair.created = time.Now() // remap index diff --git a/src/peer.go b/src/peer.go index e885cee..1c40598 100644 --- a/src/peer.go +++ b/src/peer.go @@ -37,6 +37,7 @@ type Peer struct { queue struct { nonce chan []byte // nonce / pre-handshake queue outbound chan *QueueOutboundElement // sequential ordering of work + inbound 
chan *QueueInboundElement // sequential ordering of work } mac MacStatePeer } @@ -47,11 +48,10 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer { peer := new(Peer) peer.mutex.Lock() defer peer.mutex.Unlock() - peer.device = device + peer.mac.Init(pk) - peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) - peer.queue.nonce = make(chan []byte, QueueOutboundSize) - peer.timer.sendKeepalive = StoppedTimer() + peer.device = device + peer.timer.sendKeepalive = stoppedTimer() // assign id for debugging @@ -76,6 +76,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer { handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic) handshake.mutex.Unlock() + // prepare queuing + + peer.queue.nonce = make(chan []byte, QueueOutboundSize) + peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) + peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) + // prepare signaling peer.signal.stop = make(chan struct{}) @@ -89,6 +95,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer { go peer.RoutineNonce() go peer.RoutineHandshakeInitiator() go peer.RoutineSequentialSender() + go peer.RoutineSequentialReceiver() return peer } diff --git a/src/receive.go b/src/receive.go new file mode 100644 index 0000000..ab28944 --- /dev/null +++ b/src/receive.go @@ -0,0 +1,404 @@ +package main + +import ( + "bytes" + "encoding/binary" + "golang.org/x/crypto/chacha20poly1305" + "net" + "sync" + "sync/atomic" + "time" +) + +const ( + ElementStateOkay = iota + ElementStateDropped +) + +type QueueHandshakeElement struct { + msgType uint32 + packet []byte + source *net.UDPAddr +} + +type QueueInboundElement struct { + state uint32 + mutex sync.Mutex + packet []byte + counter uint64 + keyPair *KeyPair +} + +func (elem *QueueInboundElement) Drop() { + atomic.StoreUint32(&elem.state, ElementStateDropped) + elem.mutex.Unlock() +} + +func (device *Device) RoutineReceiveIncomming() { + var packet 
[]byte + + debugLog := device.log.Debug + debugLog.Println("Routine, receive incomming, started") + + errorLog := device.log.Error + + for { + + // check if stopped + + select { + case <-device.signal.stop: + return + default: + } + + // read next datagram + + if packet == nil { + packet = make([]byte, 1<<16) + } + + device.net.mutex.RLock() + conn := device.net.conn + device.net.mutex.RUnlock() + + conn.SetReadDeadline(time.Now().Add(time.Second)) + + size, raddr, err := conn.ReadFromUDP(packet) + if err != nil { + continue + } + if size < MinMessageSize { + continue + } + + // handle packet + + packet = packet[:size] + msgType := binary.LittleEndian.Uint32(packet[:4]) + + func() { + switch msgType { + + case MessageInitiationType, MessageResponseType: + + // verify mac1 + + if !device.mac.CheckMAC1(packet) { + debugLog.Println("Received packet with invalid mac1") + return + } + + // check if busy, TODO: refine definition of "busy" + + busy := len(device.queue.handshake) > QueueHandshakeBusySize + if busy && !device.mac.CheckMAC2(packet, raddr) { + sender := binary.LittleEndian.Uint32(packet[4:8]) // "sender" follows "type" + reply, err := device.CreateMessageCookieReply(packet, sender, raddr) + if err != nil { + errorLog.Println("Failed to create cookie reply:", err) + return + } + writer := bytes.NewBuffer(packet[:0]) + binary.Write(writer, binary.LittleEndian, reply) + packet = writer.Bytes() + _, err = device.net.conn.WriteToUDP(packet, raddr) + if err != nil { + debugLog.Println("Failed to send cookie reply:", err) + } + return + } + + // add to handshake queue + + device.queue.handshake <- QueueHandshakeElement{ + msgType: msgType, + packet: packet, + source: raddr, + } + + case MessageCookieReplyType: + + // verify and update peer cookie state + + if len(packet) != MessageCookieReplySize { + return + } + + var reply MessageCookieReply + reader := bytes.NewReader(packet) + err := binary.Read(reader, binary.LittleEndian, &reply) + if err != nil { + 
debugLog.Println("Failed to decode cookie reply") + return + } + device.ConsumeMessageCookieReply(&reply) + + case MessageTransportType: + + debugLog.Println("DEBUG: Got transport") + + // lookup key pair + + if len(packet) < MessageTransportSize { + return + } + + receiver := binary.LittleEndian.Uint32( + packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], + ) + value := device.indices.Lookup(receiver) + keyPair := value.keyPair + if keyPair == nil { + return + } + + // check key-pair expiry + + if keyPair.created.Add(RejectAfterTime).Before(time.Now()) { + return + } + + // add to peer queue + + peer := value.peer + work := new(QueueInboundElement) + work.packet = packet + work.keyPair = keyPair + work.state = ElementStateOkay + work.mutex.Lock() + + // add to parallel decryption queue + + func() { + for { + select { + case device.queue.decryption <- work: + return + default: + select { + case elem := <-device.queue.decryption: + elem.Drop() + default: + } + } + } + }() + + // add to sequential inbound queue + + func() { + for { + select { + case peer.queue.inbound <- work: + break + default: + select { + case elem := <-peer.queue.inbound: + elem.Drop() + default: + } + } + } + }() + + default: + // unknown message type + } + }() + } +} + +func (device *Device) RoutineDecryption() { + var elem *QueueInboundElement + var nonce [chacha20poly1305.NonceSize]byte + + for { + select { + case elem = <-device.queue.decryption: + case <-device.signal.stop: + return + } + + // check if dropped + + state := atomic.LoadUint32(&elem.state) + if state != ElementStateOkay { + continue + } + + // split message into fields + + counter := binary.LittleEndian.Uint64( + elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent], + ) + content := elem.packet[MessageTransportOffsetContent:] + + // decrypt with key-pair + + var err error + binary.LittleEndian.PutUint64(nonce[4:], counter) + elem.packet, err = elem.keyPair.recv.Open(elem.packet[:0], 
nonce[:], content, nil) + if err != nil { + elem.Drop() + continue + } + + // release to consumer + + elem.counter = counter + elem.mutex.Unlock() + } +} + +/* Handles incomming packets related to handshake + * + * + */ +func (device *Device) RoutineHandshake() { + + logInfo := device.log.Info + logError := device.log.Error + logDebug := device.log.Debug + + var elem QueueHandshakeElement + + for { + select { + case elem = <-device.queue.handshake: + case <-device.signal.stop: + return + } + + func() { + + switch elem.msgType { + case MessageInitiationType: + + // unmarshal + + if len(elem.packet) != MessageInitiationSize { + return + } + + var msg MessageInitiation + reader := bytes.NewReader(elem.packet) + err := binary.Read(reader, binary.LittleEndian, &msg) + if err != nil { + logError.Println("Failed to decode initiation message") + return + } + + // consume initiation + + peer := device.ConsumeMessageInitiation(&msg) + if peer == nil { + logInfo.Println( + "Recieved invalid initiation message from", + elem.source.IP.String(), + elem.source.Port, + ) + return + } + logDebug.Println("Recieved valid initiation message for peer", peer.id) + + case MessageResponseType: + + // unmarshal + + if len(elem.packet) != MessageResponseSize { + return + } + + var msg MessageResponse + reader := bytes.NewReader(elem.packet) + err := binary.Read(reader, binary.LittleEndian, &msg) + if err != nil { + logError.Println("Failed to decode response message") + return + } + + // consume response + + peer := device.ConsumeMessageResponse(&msg) + if peer == nil { + logInfo.Println( + "Recieved invalid response message from", + elem.source.IP.String(), + elem.source.Port, + ) + return + } + sendSignal(peer.signal.handshakeCompleted) + logDebug.Println("Recieved valid response message for peer", peer.id) + peer.NewKeyPair() + peer.SendKeepAlive() + + default: + device.log.Error.Println("Invalid message type in handshake queue") + } + + }() + } +} + +func (peer *Peer) 
RoutineSequentialReceiver() { + var elem *QueueInboundElement + + device := peer.device + logDebug := device.log.Debug + + logDebug.Println("Routine, sequential receiver, started for peer", peer.id) + + for { + // wait for decryption + + select { + case <-peer.signal.stop: + return + case elem = <-peer.queue.inbound: + } + elem.mutex.Lock() + + // check if dropped + + logDebug.Println("MESSSAGE:", elem) + + state := atomic.LoadUint32(&elem.state) + if state != ElementStateOkay { + continue + } + + // check for replay + + // check for keep-alive + + if len(elem.packet) == 0 { + continue + } + + // insert into inbound TUN queue + + device.queue.inbound <- elem.packet + } + +} + +func (device *Device) RoutineWriteToTUN(tun TUNDevice) { + for { + var packet []byte + + select { + case <-device.signal.stop: + case packet = <-device.queue.inbound: + } + + device.log.Debug.Println("GOT:", packet) + + size, err := tun.Write(packet) + device.log.Debug.Println("DEBUG:", size, err) + if err != nil { + + } + } +} diff --git a/src/send.go b/src/send.go index d4f9342..7a10560 100644 --- a/src/send.go +++ b/src/send.go @@ -27,6 +27,7 @@ import ( * workers release lock when they have completed work on the packet. 
*/ type QueueOutboundElement struct { + state uint32 mutex sync.Mutex packet []byte nonce uint64 @@ -59,6 +60,14 @@ func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) { } } +func (elem *QueueOutboundElement) Drop() { + atomic.StoreUint32(&elem.state, ElementStateDropped) +} + +func (elem *QueueOutboundElement) IsDropped() bool { + return atomic.LoadUint32(&elem.state) == ElementStateDropped +} + /* Reads packets from the TUN and inserts * into nonce queue for peer * @@ -162,6 +171,8 @@ func (peer *Peer) RoutineNonce() { } } + logger.Println("PACKET:", packet) + // wait for key pair for { @@ -176,6 +187,7 @@ func (peer *Peer) RoutineNonce() { break } } + logger.Println("Key pair:", keyPair) sendSignal(peer.signal.handshakeBegin) logger.Println("Waiting for key-pair, peer", peer.id) @@ -205,10 +217,12 @@ func (peer *Peer) RoutineNonce() { work := new(QueueOutboundElement) // TODO: profile, maybe use pool work.keyPair = keyPair work.packet = packet - work.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) + work.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1 work.peer = peer work.mutex.Lock() + logger.Println("WORK:", work) + packet = nil // drop packets until there is space @@ -219,9 +233,11 @@ func (peer *Peer) RoutineNonce() { case peer.device.queue.encryption <- work: return default: - drop := <-peer.device.queue.encryption - drop.packet = nil - drop.mutex.Unlock() + select { + case elem := <-peer.device.queue.encryption: + elem.Drop() + default: + } } } }() @@ -241,18 +257,22 @@ func (peer *Peer) RoutineNonce() { func (device *Device) RoutineEncryption() { var nonce [chacha20poly1305.NonceSize]byte for work := range device.queue.encryption { + if work.IsDropped() { + continue + } // pad packet padding := device.mtu - len(work.packet) if padding < 0 { - // drop - work.packet = nil - work.mutex.Unlock() + work.Drop() + continue } + for n := 0; n < padding; n += 1 { work.packet = append(work.packet, 0) } + device.log.Debug.Println(work.packet) // encrypt 
@@ -288,6 +308,9 @@ func (peer *Peer) RoutineSequentialSender() { logger.Println("Routine, sequential sender, stopped for peer", peer.id) return case work := <-peer.queue.outbound: + if work.IsDropped() { + continue + } work.mutex.Lock() func() { if work.packet == nil { @@ -310,10 +333,12 @@ func (peer *Peer) RoutineSequentialSender() { return } - logger.Println("Sending packet for peer", peer.id, work.packet) + logger.Println(work.packet) _, err := device.net.conn.WriteToUDP(work.packet, peer.endpoint) - logger.Println("SEND:", peer.endpoint, err) + if err != nil { + return + } atomic.AddUint64(&peer.tx_bytes, uint64(len(work.packet))) // shift keep-alive timer @@ -323,7 +348,6 @@ func (peer *Peer) RoutineSequentialSender() { peer.timer.sendKeepalive.Reset(interval) } }() - work.mutex.Unlock() } } } diff --git a/src/tun.go b/src/tun.go index 594754a..60732c4 100644 --- a/src/tun.go +++ b/src/tun.go @@ -4,5 +4,5 @@ type TUNDevice interface { Read([]byte) (int, error) Write([]byte) (int, error) Name() string - MTU() uint + MTU() int } diff --git a/src/tun_linux.go b/src/tun_linux.go index db13fb0..a0bff81 100644 --- a/src/tun_linux.go +++ b/src/tun_linux.go @@ -24,14 +24,14 @@ const ( type NativeTun struct { fd *os.File name string - mtu uint + mtu int } func (tun *NativeTun) Name() string { return tun.name } -func (tun *NativeTun) MTU() uint { +func (tun *NativeTun) MTU() int { return tun.mtu }