diff --git a/device/allowedips.go b/device/allowedips.go index cb6df0a..b6f096a 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -6,6 +6,7 @@ package device import ( + "container/list" "errors" "math/bits" "net" @@ -14,14 +15,13 @@ import ( ) type trieEntry struct { - child [2]*trieEntry - peer *Peer - bits net.IP - cidr uint - bit_at_byte uint - bit_at_shift uint - nextEntryForPeer *trieEntry - pprevEntryForPeer **trieEntry + child [2]*trieEntry + peer *Peer + bits net.IP + cidr uint + bit_at_byte uint + bit_at_shift uint + perPeerElem *list.Element } func isLittleEndian() bool { @@ -69,28 +69,14 @@ func commonBits(ip1 net.IP, ip2 net.IP) uint { } func (node *trieEntry) addToPeerEntries() { - p := node.peer - first := p.firstTrieEntry - node.nextEntryForPeer = first - if first != nil { - first.pprevEntryForPeer = &node.nextEntryForPeer - } - p.firstTrieEntry = node - node.pprevEntryForPeer = &p.firstTrieEntry + node.perPeerElem = node.peer.trieEntries.PushBack(node) } func (node *trieEntry) removeFromPeerEntries() { - if node.pprevEntryForPeer == nil { - return + if node.perPeerElem != nil { + node.peer.trieEntries.Remove(node.perPeerElem) + node.perPeerElem = nil } - next := node.nextEntryForPeer - pprev := node.pprevEntryForPeer - *pprev = next - if next != nil { - next.pprevEntryForPeer = pprev - } - node.nextEntryForPeer = nil - node.pprevEntryForPeer = nil } func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { @@ -226,7 +212,8 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint table.mutex.RLock() defer table.mutex.RUnlock() - for node := peer.firstTrieEntry; node != nil; node = node.nextEntryForPeer { + for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { + node := elem.Value.(*trieEntry) if !cb(node.bits, node.cidr) { return } diff --git a/device/peer.go b/device/peer.go index a3b428a..499888d 100644 --- a/device/peer.go +++ b/device/peer.go @@ -6,6 +6,7 @@ package device import ( + "container/list" "encoding/base64" "errors" "fmt" @@ -17,15 +18,13 @@ import ( ) type Peer struct { - isRunning AtomicBool - sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer - keypairs Keypairs - handshake Handshake - device *Device - endpoint conn.Endpoint - persistentKeepaliveInterval uint32 // accessed atomically - firstTrieEntry *trieEntry - stopping sync.WaitGroup // routines pending stop + isRunning AtomicBool + sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer + keypairs Keypairs + handshake Handshake + device *Device + endpoint conn.Endpoint + stopping sync.WaitGroup // routines pending stop // These fields are accessed with atomic operations, which must be // 64-bit aligned even on 32-bit platforms. Go guarantees that an @@ -61,7 +60,9 @@ type Peer struct { inbound *autodrainingInboundQueue // sequential ordering of tun writing } - cookieGenerator CookieGenerator + cookieGenerator CookieGenerator + trieEntries list.List + persistentKeepaliveInterval uint32 // accessed atomically } func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {