device: remove nodes by peer in O(1) instead of O(n)

Now that we have parent pointers hooked up, we can simply go right to
the node and remove it in place, rather than having to recursively walk
the entire trie.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2021-06-03 15:40:09 +02:00
parent b41f4cc768
commit c382222eab
2 changed files with 86 additions and 76 deletions

View file

@ -85,30 +85,6 @@ func (node *trieEntry) removeFromPeerEntries() {
} }
} }
func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
if node == nil {
return node
}
// walk recursively
node.child[0] = node.child[0].removeByPeer(p)
node.child[1] = node.child[1].removeByPeer(p)
if node.peer != p {
return node
}
// remove peer & merge
node.removeFromPeerEntries()
node.peer = nil
if node.child[0] == nil {
return node.child[1]
}
return node.child[0]
}
func (node *trieEntry) choose(ip net.IP) byte { func (node *trieEntry) choose(ip net.IP) byte {
return (ip[node.bitAtByte] >> node.bitAtShift) & 1 return (ip[node.bitAtByte] >> node.bitAtShift) & 1
} }
@ -261,8 +237,38 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock() table.mutex.Lock()
defer table.mutex.Unlock() defer table.mutex.Unlock()
table.IPv4 = table.IPv4.removeByPeer(peer) var next *list.Element
table.IPv6 = table.IPv6.removeByPeer(peer) for elem := peer.trieEntries.Front(); elem != nil; elem = next {
next = elem.Next()
node := elem.Value.(*trieEntry)
node.removeFromPeerEntries()
node.peer = nil
if node.child[0] != nil && node.child[1] != nil {
continue
}
bit := 0
if node.child[0] == nil {
bit = 1
}
child := node.child[bit]
if child != nil {
child.parent = node.parent
}
*node.parent.parentBit = child
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
continue
}
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
if parent.peer != nil {
continue
}
child = parent.child[node.parent.parentBitType^1]
if child != nil {
child.parent = parent.parent
}
*parent.parent.parentBit = child
}
} }
func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) { func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {

View file

@ -7,6 +7,7 @@ package device
import ( import (
"math/rand" "math/rand"
"net"
"sort" "sort"
"testing" "testing"
) )
@ -64,68 +65,71 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
return nil return nil
} }
func TestTrieRandomIPv4(t *testing.T) { func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter {
var slow SlowRouter n := 0
for _, x := range r {
if x.peer != peer {
r[n] = x
n++
}
}
return r[:n]
}
func TestTrieRandom(t *testing.T) {
var slow4, slow6 SlowRouter
var peers []*Peer var peers []*Peer
var allowedIPs AllowedIPs var allowedIPs AllowedIPs
rand.Seed(1) rand.Seed(1)
const AddressLength = 4
for n := 0; n < NumberOfPeers; n++ { for n := 0; n < NumberOfPeers; n++ {
peers = append(peers, &Peer{}) peers = append(peers, &Peer{})
} }
for n := 0; n < NumberOfAddresses; n++ { for n := 0; n < NumberOfAddresses; n++ {
var addr [AddressLength]byte var addr4 [4]byte
rand.Read(addr[:]) rand.Read(addr4[:])
cidr := uint8(rand.Uint32() % (AddressLength * 8)) cidr := uint8(rand.Intn(32) + 1)
index := rand.Int() % NumberOfPeers index := rand.Intn(NumberOfPeers)
allowedIPs.Insert(addr[:], cidr, peers[index]) allowedIPs.Insert(addr4[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index]) slow4 = slow4.Insert(addr4[:], cidr, peers[index])
var addr6 [16]byte
rand.Read(addr6[:])
cidr = uint8(rand.Intn(128) + 1)
index = rand.Intn(NumberOfPeers)
allowedIPs.Insert(addr6[:], cidr, peers[index])
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
} }
for p := 0; ; p++ {
for n := 0; n < NumberOfTests; n++ { for n := 0; n < NumberOfTests; n++ {
var addr [AddressLength]byte var addr4 [4]byte
rand.Read(addr[:]) rand.Read(addr4[:])
peer1 := slow.Lookup(addr[:]) peer1 := slow4.Lookup(addr4[:])
peer2 := allowedIPs.LookupIPv4(addr[:]) peer2 := allowedIPs.LookupIPv4(addr4[:])
if peer1 != peer2 { if peer1 != peer2 {
t.Error("Trie did not match naive implementation, for:", addr) t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
}
}
}
func TestTrieRandomIPv6(t *testing.T) {
var slow SlowRouter
var peers []*Peer
var allowedIPs AllowedIPs
rand.Seed(1)
const AddressLength = 16
for n := 0; n < NumberOfPeers; n++ {
peers = append(peers, &Peer{})
}
for n := 0; n < NumberOfAddresses; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % NumberOfPeers
allowedIPs.Insert(addr[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index])
}
for n := 0; n < NumberOfTests; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
peer1 := slow.Lookup(addr[:])
peer2 := allowedIPs.LookupIPv6(addr[:])
if peer1 != peer2 {
t.Error("Trie did not match naive implementation, for:", addr)
} }
var addr6 [16]byte
rand.Read(addr6[:])
peer1 = slow6.Lookup(addr6[:])
peer2 = allowedIPs.LookupIPv6(addr6[:])
if peer1 != peer2 {
t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
}
}
if p >= len(peers) {
break
}
allowedIPs.RemoveByPeer(peers[p])
slow4 = slow4.RemoveByPeer(peers[p])
slow6 = slow6.RemoveByPeer(peers[p])
}
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
t.Error("Failed to remove all nodes from trie by peer")
} }
} }