From c382222eab9e3814f4df75fd25f8e9e31484b5e0 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Thu, 3 Jun 2021 15:40:09 +0200 Subject: [PATCH] device: remove nodes by peer in O(1) instead of O(n) Now that we have parent pointers hooked up, we can simply go right to the node and remove it in place, rather than having to recursively walk the entire trie. Signed-off-by: Jason A. Donenfeld --- device/allowedips.go | 58 +++++++++--------- device/allowedips_rand_test.go | 104 +++++++++++++++++---------------- 2 files changed, 86 insertions(+), 76 deletions(-) diff --git a/device/allowedips.go b/device/allowedips.go index d613121..7af9fc7 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -85,30 +85,6 @@ func (node *trieEntry) removeFromPeerEntries() { } } -func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { - if node == nil { - return node - } - - // walk recursively - - node.child[0] = node.child[0].removeByPeer(p) - node.child[1] = node.child[1].removeByPeer(p) - - if node.peer != p { - return node - } - - // remove peer & merge - - node.removeFromPeerEntries() - node.peer = nil - if node.child[0] == nil { - return node.child[1] - } - return node.child[0] -} - func (node *trieEntry) choose(ip net.IP) byte { return (ip[node.bitAtByte] >> node.bitAtShift) & 1 } @@ -261,8 +237,38 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() - table.IPv4 = table.IPv4.removeByPeer(peer) - table.IPv6 = table.IPv6.removeByPeer(peer) + var next *list.Element + for elem := peer.trieEntries.Front(); elem != nil; elem = next { + next = elem.Next() + node := elem.Value.(*trieEntry) + + node.removeFromPeerEntries() + node.peer = nil + if node.child[0] != nil && node.child[1] != nil { + continue + } + bit := 0 + if node.child[0] == nil { + bit = 1 + } + child := node.child[bit] + if child != nil { + child.parent = node.parent + } + *node.parent.parentBit = child + if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { + continue + } + parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) + if parent.peer != nil { + continue + } + child = parent.child[node.parent.parentBitType^1] + if child != nil { + child.parent = parent.parent + } + *parent.parent.parentBit = child + } } func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) { diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go index 48a5bcd..c5f80fe 100644 --- a/device/allowedips_rand_test.go +++ b/device/allowedips_rand_test.go @@ -7,6 +7,7 @@ package device import ( "math/rand" + "net" "sort" "testing" ) @@ -64,68 +65,71 @@ func (r SlowRouter) Lookup(addr []byte) *Peer { return nil } -func TestTrieRandomIPv4(t *testing.T) { - var slow SlowRouter +func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter { + n := 0 + for _, x := range r { + if x.peer != peer { + r[n] = x + n++ + } + } + return r[:n] +} + +func TestTrieRandom(t *testing.T) { + var slow4, slow6 SlowRouter var peers []*Peer var allowedIPs AllowedIPs rand.Seed(1) - const AddressLength = 4 - for n := 0; n < NumberOfPeers; n++ { peers = append(peers, &Peer{}) } for n := 0; n < NumberOfAddresses; n++ { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint8(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % NumberOfPeers - allowedIPs.Insert(addr[:], cidr, peers[index]) - slow = slow.Insert(addr[:], cidr, peers[index]) + var addr4 [4]byte + rand.Read(addr4[:]) + cidr := uint8(rand.Intn(32) + 1) + index := rand.Intn(NumberOfPeers) + allowedIPs.Insert(addr4[:], cidr, peers[index]) + slow4 = slow4.Insert(addr4[:], cidr, peers[index]) + + var addr6 [16]byte + rand.Read(addr6[:]) + cidr = uint8(rand.Intn(128) + 1) + index = rand.Intn(NumberOfPeers) + allowedIPs.Insert(addr6[:], cidr, peers[index]) + slow6 = slow6.Insert(addr6[:], cidr, peers[index]) } - for n := 0; n < NumberOfTests; n++ { - var addr [AddressLength]byte - rand.Read(addr[:]) - peer1 := slow.Lookup(addr[:]) - peer2 := allowedIPs.LookupIPv4(addr[:]) - if peer1 != peer2 { - t.Error("Trie did not match naive implementation, for:", addr) - } - } -} - -func TestTrieRandomIPv6(t *testing.T) { - var slow SlowRouter - var peers []*Peer - var allowedIPs AllowedIPs - - rand.Seed(1) - - const AddressLength = 16 - - for n := 0; n < NumberOfPeers; n++ { - peers = append(peers, &Peer{}) - } - - for n := 0; n < NumberOfAddresses; n++ { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint8(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % NumberOfPeers - allowedIPs.Insert(addr[:], cidr, peers[index]) - slow = slow.Insert(addr[:], cidr, peers[index]) - } - - for n := 0; n < NumberOfTests; n++ { - var addr [AddressLength]byte - rand.Read(addr[:]) - peer1 := slow.Lookup(addr[:]) - peer2 := allowedIPs.LookupIPv6(addr[:]) - if peer1 != peer2 { - t.Error("Trie did not match naive implementation, for:", addr) + for p := 0; ; p++ { + for n := 0; n < NumberOfTests; n++ { + var addr4 [4]byte + rand.Read(addr4[:]) + peer1 := slow4.Lookup(addr4[:]) + peer2 := allowedIPs.LookupIPv4(addr4[:]) + if peer1 != peer2 { + t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2) + } + + var addr6 [16]byte + rand.Read(addr6[:]) + peer1 = slow6.Lookup(addr6[:]) + peer2 = allowedIPs.LookupIPv6(addr6[:]) + if peer1 != peer2 { + t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2) + } } + if p >= len(peers) { + break + } + allowedIPs.RemoveByPeer(peers[p]) + slow4 = slow4.RemoveByPeer(peers[p]) + slow6 = slow6.RemoveByPeer(peers[p]) + } + + if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { + t.Error("Failed to remove all nodes from trie by peer") } }