From 2c27ab205c992d3387574aa6d57780744d35d36f Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Sun, 13 May 2018 18:23:40 +0200 Subject: [PATCH] Rework index hashtable --- device.go | 6 ++--- index.go => indextable.go | 53 ++++++++++++++++++--------------------- keypair.go | 2 +- noise-protocol.go | 51 ++++++++++++++++--------------------- peer.go | 6 ++--- receive.go | 24 +++++++++--------- send.go | 20 +++++++-------- timers.go | 4 +-- 8 files changed, 78 insertions(+), 88 deletions(-) rename index.go => indextable.go (65%) diff --git a/device.go b/device.go index e127b5b..3db3609 100644 --- a/device.go +++ b/device.go @@ -56,8 +56,8 @@ type Device struct { // unprotected / "self-synchronising resources" - indices IndexTable - mac CookieChecker + indexTable IndexTable + mac CookieChecker rate struct { underLoadUntil atomic.Value @@ -283,7 +283,7 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device { // initialize noise & crypt-key routine - device.indices.Init() + device.indexTable.Init() device.routing.table.Reset() // setup buffer pool diff --git a/index.go b/indextable.go similarity index 65% rename from index.go rename to indextable.go index 4a78d55..2d947cd 100644 --- a/index.go +++ b/indextable.go @@ -7,18 +7,14 @@ package main import ( "crypto/rand" - "encoding/binary" "sync" + "unsafe" ) -/* Index=0 is reserved for unset indecies - * - */ - type IndexTableEntry struct { peer *Peer handshake *Handshake - keyPair *Keypair + keypair *Keypair } type IndexTable struct { @@ -27,34 +23,38 @@ type IndexTable struct { } func randUint32() (uint32, error) { - var buff [4]byte - _, err := rand.Read(buff[:]) - value := binary.LittleEndian.Uint32(buff[:]) - return value, err + var integer [4]byte + _, err := rand.Read(integer[:]) + return *(*uint32)(unsafe.Pointer(&integer[0])), err } func (table *IndexTable) Init() { table.mutex.Lock() + defer table.mutex.Unlock() table.table = make(map[uint32]IndexTableEntry) - table.mutex.Unlock() } func (table *IndexTable) Delete(index uint32) { - if index == 0 { + table.mutex.Lock() + defer table.mutex.Unlock() + delete(table.table, index) +} + +func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) { + table.mutex.Lock() + defer table.mutex.Unlock() + entry, ok := table.table[index] + if !ok { return } - table.mutex.Lock() - delete(table.table, index) - table.mutex.Unlock() + table.table[index] = IndexTableEntry{ + peer: entry.peer, + keypair: keypair, + handshake: nil, + } } -func (table *IndexTable) Insert(key uint32, value IndexTableEntry) { - table.mutex.Lock() - table.table[key] = value - table.mutex.Unlock() -} - -func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) { +func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake) (uint32, error) { for { // generate random index @@ -62,9 +62,6 @@ func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) { if err != nil { return index, err } - if index == 0 { - continue - } // check if index used @@ -75,7 +72,7 @@ func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) { continue } - // map index to handshake + // check again while locked table.mutex.Lock() _, found := table.table[index] @@ -85,8 +82,8 @@ func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) { } table.table[index] = IndexTableEntry{ peer: peer, - handshake: &peer.handshake, - keyPair: nil, + handshake: handshake, + keypair: nil, } table.mutex.Unlock() return index, nil diff --git a/keypair.go b/keypair.go index 07a183d..6f6f7c0 100644 --- a/keypair.go +++ b/keypair.go @@ -44,6 +44,6 @@ func (kp *Keypairs) Current() *Keypair { func (device *Device) DeleteKeypair(key *Keypair) { if key != nil { - device.indices.Delete(key.localIndex) + device.indexTable.Delete(key.localIndex) } } diff --git a/noise-protocol.go b/noise-protocol.go index 3abbe4b..82d553e 100644 --- a/noise-protocol.go +++ b/noise-protocol.go @@ -161,7 +161,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e defer handshake.mutex.Unlock() if isZero(handshake.precomputedStaticStatic[:]) { - return nil, errors.New("Static shared secret is zero") + return nil, errors.New("static shared secret is zero") } // create ephemeral key @@ -176,8 +176,8 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e // assign index - device.indices.Delete(handshake.localIndex) - handshake.localIndex, err = device.indices.NewIndex(peer) + device.indexTable.Delete(handshake.localIndex) + handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake) if err != nil { return nil, err @@ -328,14 +328,14 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error defer handshake.mutex.Unlock() if handshake.state != HandshakeInitiationConsumed { - return nil, errors.New("handshake initation must be consumed first") + return nil, errors.New("handshake initiation must be consumed first") } // assign index var err error - device.indices.Delete(handshake.localIndex) - handshake.localIndex, err = device.indices.NewIndex(peer) + device.indexTable.Delete(handshake.localIndex) + handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake) if err != nil { return nil, err } @@ -393,9 +393,9 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { return nil } - // lookup handshake by reciever + // lookup handshake by receiver - lookup := device.indices.Lookup(msg.Receiver) + lookup := device.indexTable.Lookup(msg.Receiver) handshake := lookup.handshake if handshake == nil { return nil @@ -528,35 +528,28 @@ func (peer *Peer) NewKeypair() *Keypair { // create AEAD instances - keyPair := new(Keypair) - keyPair.send, _ = chacha20poly1305.New(sendKey[:]) - keyPair.receive, _ = chacha20poly1305.New(recvKey[:]) + keypair := new(Keypair) + keypair.send, _ = chacha20poly1305.New(sendKey[:]) + keypair.receive, _ = chacha20poly1305.New(recvKey[:]) setZero(sendKey[:]) setZero(recvKey[:]) - keyPair.created = time.Now() - keyPair.sendNonce = 0 - keyPair.replayFilter.Init() - keyPair.isInitiator = isInitiator - keyPair.localIndex = peer.handshake.localIndex - keyPair.remoteIndex = peer.handshake.remoteIndex + keypair.created = time.Now() + keypair.sendNonce = 0 + keypair.replayFilter.Init() + keypair.isInitiator = isInitiator + keypair.localIndex = peer.handshake.localIndex + keypair.remoteIndex = peer.handshake.remoteIndex // remap index - device.indices.Insert( - handshake.localIndex, - IndexTableEntry{ - peer: peer, - keyPair: keyPair, - handshake: nil, - }, - ) + device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair) handshake.localIndex = 0 // rotate key pairs - kp := &peer.keyPairs + kp := &peer.keypairs kp.mutex.Lock() peer.timersSessionDerived() @@ -574,14 +567,14 @@ func (peer *Peer) NewKeypair() *Keypair { kp.previous = current } device.DeleteKeypair(previous) - kp.current = keyPair + kp.current = keypair } else { - kp.next = keyPair + kp.next = keypair device.DeleteKeypair(next) kp.previous = nil device.DeleteKeypair(previous) } kp.mutex.Unlock() - return keyPair + return keypair } diff --git a/peer.go b/peer.go index 242729e..f49f806 100644 --- a/peer.go +++ b/peer.go @@ -20,7 +20,7 @@ const ( type Peer struct { isRunning AtomicBool mutex sync.RWMutex - keyPairs Keypairs + keypairs Keypairs handshake Handshake device *Device endpoint Endpoint @@ -234,7 +234,7 @@ func (peer *Peer) Stop() { // clear key pairs - kp := &peer.keyPairs + kp := &peer.keypairs kp.mutex.Lock() device.DeleteKeypair(kp.previous) @@ -250,7 +250,7 @@ func (peer *Peer) Stop() { hs := &peer.handshake hs.mutex.Lock() - device.indices.Delete(hs.localIndex) + device.indexTable.Delete(hs.localIndex) hs.Clear() hs.mutex.Unlock() diff --git a/receive.go b/receive.go index 0f22a3f..60a2510 100644 --- a/receive.go +++ b/receive.go @@ -31,7 +31,7 @@ type QueueInboundElement struct { buffer *[MaxMessageSize]byte packet []byte counter uint64 - keyPair *Keypair + keypair *Keypair endpoint Endpoint } @@ -107,7 +107,7 @@ func (peer *Peer) keepKeyFreshReceiving() { if peer.timers.sentLastMinuteHandshake { return } - kp := peer.keyPairs.Current() + kp := peer.keypairs.Current() if kp != nil && kp.isInitiator && time.Now().Sub(kp.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { peer.timers.sentLastMinuteHandshake = true peer.SendHandshakeInitiation(false) @@ -183,15 +183,15 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { receiver := binary.LittleEndian.Uint32( packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], ) - value := device.indices.Lookup(receiver) - keyPair := value.keyPair - if keyPair == nil { + value := device.indexTable.Lookup(receiver) + keypair := value.keypair + if keypair == nil { continue } // check key-pair expiry - if keyPair.created.Add(RejectAfterTime).Before(time.Now()) { + if keypair.created.Add(RejectAfterTime).Before(time.Now()) { continue } @@ -201,7 +201,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { elem := &QueueInboundElement{ packet: packet, buffer: buffer, - keyPair: keyPair, + keypair: keypair, dropped: AtomicFalse, endpoint: endpoint, } @@ -296,7 +296,7 @@ func (device *Device) RoutineDecryption() { var err error elem.counter = binary.LittleEndian.Uint64(counter) - elem.packet, err = elem.keyPair.receive.Open( + elem.packet, err = elem.keypair.receive.Open( content[:0], nonce[:], content, @@ -358,7 +358,7 @@ func (device *Device) RoutineHandshake() { // lookup peer from index - entry := device.indices.Lookup(reply.Receiver) + entry := device.indexTable.Lookup(reply.Receiver) if entry.peer == nil { continue @@ -587,7 +587,7 @@ func (peer *Peer) RoutineSequentialReceiver() { // check for replay - if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) { + if !elem.keypair.replayFilter.ValidateCounter(elem.counter) { continue } @@ -599,9 +599,9 @@ func (peer *Peer) RoutineSequentialReceiver() { // check if using new key-pair - kp := &peer.keyPairs + kp := &peer.keypairs kp.mutex.Lock() //TODO: make this into an RW lock to reduce contention here for the equality check which is rarely true - if kp.next == elem.keyPair { + if kp.next == elem.keypair { old := kp.previous kp.previous = kp.current device.DeleteKeypair(old) diff --git a/send.go b/send.go index 1b35e27..35e0d00 100644 --- a/send.go +++ b/send.go @@ -47,7 +47,7 @@ type QueueOutboundElement struct { buffer *[MaxMessageSize]byte // slice holding the packet data packet []byte // slice of "buffer" (always!) nonce uint64 // nonce for encryption - keyPair *Keypair // key-pair for encryption + keypair *Keypair // key-pair for encryption peer *Peer // related peer } @@ -161,7 +161,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { * */ func (peer *Peer) keepKeyFreshSending() { - kp := peer.keyPairs.Current() + kp := peer.keypairs.Current() if kp == nil { return } @@ -260,7 +260,7 @@ func (peer *Peer) FlushNonceQueue() { * Obs. A single instance per peer */ func (peer *Peer) RoutineNonce() { - var keyPair *Keypair + var keypair *Keypair device := peer.device logDebug := device.log.Debug @@ -291,9 +291,9 @@ func (peer *Peer) RoutineNonce() { // wait for key pair for { - keyPair = peer.keyPairs.Current() - if keyPair != nil && keyPair.sendNonce < RejectAfterMessages { - if time.Now().Sub(keyPair.created) < RejectAfterTime { + keypair = peer.keypairs.Current() + if keypair != nil && keypair.sendNonce < RejectAfterMessages { + if time.Now().Sub(keypair.created) < RejectAfterTime { break } } @@ -328,12 +328,12 @@ func (peer *Peer) RoutineNonce() { // populate work element elem.peer = peer - elem.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1 + elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1 // double check in case of race condition added by future code if elem.nonce >= RejectAfterMessages { goto NextPacket } - elem.keyPair = keyPair + elem.keypair = keypair elem.dropped = AtomicFalse elem.mutex.Lock() @@ -392,7 +392,7 @@ func (device *Device) RoutineEncryption() { fieldNonce := header[8:16] binary.LittleEndian.PutUint32(fieldType, MessageTransportType) - binary.LittleEndian.PutUint32(fieldReceiver, elem.keyPair.remoteIndex) + binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) // pad content to multiple of 16 @@ -408,7 +408,7 @@ func (device *Device) RoutineEncryption() { // encrypt content and release to consumer binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) - elem.packet = elem.keyPair.send.Seal( + elem.packet = elem.keypair.send.Seal( header, nonce[:], elem.packet, diff --git a/timers.go b/timers.go index 5c72efd..9e633ee 100644 --- a/timers.go +++ b/timers.go @@ -108,7 +108,7 @@ func expiredZeroKeyMaterial(peer *Peer) { hs := &peer.handshake hs.mutex.Lock() - kp := &peer.keyPairs + kp := &peer.keypairs kp.mutex.Lock() if kp.previous != nil { @@ -125,7 +125,7 @@ func expiredZeroKeyMaterial(peer *Peer) { } kp.mutex.Unlock() - peer.device.indices.Delete(hs.localIndex) + peer.device.indexTable.Delete(hs.localIndex) hs.Clear() hs.mutex.Unlock() }