wireguard-go/trie.go
Jason A. Donenfeld a040786645 global: Add SPDX tags and copyright header
Mathias should probably add his copyright headers to each file too.
2018-05-03 15:06:25 +02:00

234 lines
4 KiB
Go

/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
*/
package main
import (
"errors"
"net"
)
/* Binary trie
*
* Trie is a single node of a binary trie keyed by IP prefix.
*
* The net.IPs used here are not formatted the
* same way as those created by the "net" functions.
* Here the IPs are slices of either 4 or 16 bytes (not always 16).
*
* Synchronization is done separately.
* See: routing.go
*/
type Trie struct {
// prefix length of this node, in bits
cidr uint
// subtrees, indexed by the value (0 or 1) of the branching bit
child [2]*Trie
// address bytes of the prefix, stored exactly as passed to Insert
bits []byte
// peer attached to this prefix; nil for nodes that only branch
peer *Peer
// index of "branching" bit
// bit_at_byte is the byte that holds the bit (cidr / 8)
bit_at_byte uint
// bit_at_shift is the right shift that moves the bit to the
// lowest position (7 - cidr % 8)
bit_at_shift uint
}
/* commonBits returns the length, in bits, of the prefix shared by
 * ip1 and ip2, counted from the most significant bit of byte 0.
 * Identical inputs yield len(ip1) * 8.
 *
 * TODO: Only use during insertion (xor + prefix mask for lookup)
 * Check out
 * prefix_matches(struct allowedips_node *node, const u8 *key, u8 bits)
 * https://git.zx2c4.com/WireGuard/commit/?h=jd/precomputed-prefix-match
 *
 * Assumption:
 * len(ip1) == len(ip2)
 * len(ip1) mod 4 = 0
 *
 * Only ip1's length drives the loop, so a shorter ip2 causes an
 * out-of-range panic.
 */
func commonBits(ip1 []byte, ip2 []byte) uint {
	for idx := range ip1 {
		diff := ip1[idx] ^ ip2[idx]
		if diff == 0 {
			// whole byte matches, keep scanning
			continue
		}

		// First differing byte: add one for every leading bit that
		// still agrees. diff is non-zero, so the probe always stops.
		matched := uint(idx) * 8
		for probe := byte(0x80); diff&probe == 0; probe >>= 1 {
			matched++
		}
		return matched
	}
	return uint(len(ip1)) * 8
}
// RemovePeer detaches p from every node of the trie rooted at node and
// returns the root of the resulting trie (nil when nothing is left).
// Calling it on a nil node returns nil.
//
// Both subtrees are processed first. Afterwards a node that carries no
// peer is handled according to how many children it still has:
//   - none: the node is dropped and nil is returned,
//   - one: the node is replaced by that child,
//   - two: the node is kept as a pure branching node.
//
// Keeping the two-children case is essential: returning only one child
// there would silently discard every prefix stored under the other one.
func (node *Trie) RemovePeer(p *Peer) *Trie {
	if node == nil {
		return nil
	}

	// walk recursively
	node.child[0] = node.child[0].RemovePeer(p)
	node.child[1] = node.child[1].RemovePeer(p)

	if node.peer == p {
		node.peer = nil
	}
	if node.peer != nil {
		// still owned by some other peer, must stay in place
		return node
	}

	// peer-less node: merge it away unless it is needed as a branch
	switch {
	case node.child[0] == nil:
		return node.child[1]
	case node.child[1] == nil:
		return node.child[0]
	default:
		return node
	}
}
// choose returns the value (0 or 1) of the bit of ip found at this node's
// branching position; the result is used as an index into node.child.
// ip must be longer than node.bit_at_byte bytes, otherwise indexing panics.
func (node *Trie) choose(ip net.IP) byte {
	octet := ip[node.bit_at_byte]
	shifted := octet >> node.bit_at_shift
	return shifted & 1
}
// Insert stores the prefix ip/cidr with the given peer in the trie rooted
// at node and returns the (possibly new) root of that trie.
//
// Four situations are distinguished:
//   - node is nil: a fresh node becomes the root;
//   - the new prefix lies inside node's prefix: either node's peer is
//     overwritten (same length) or insertion continues in the matching child;
//   - the new prefix is a shorter prefix of node: the fresh node becomes the
//     parent of node;
//   - the prefixes diverge: a peer-less node is created at the first
//     differing bit, with node and the fresh node as its two children.
//
// ip is stored as given (it is not copied), so the caller must not modify
// it afterwards.
func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
	// entry builds a childless node for ip with the given prefix length.
	entry := func(prefixLen uint, owner *Peer) *Trie {
		return &Trie{
			bits:         ip,
			peer:         owner,
			cidr:         prefixLen,
			bit_at_byte:  prefixLen / 8,
			bit_at_shift: 7 - (prefixLen % 8),
		}
	}

	// at leaf
	if node == nil {
		return entry(cidr, peer)
	}

	shared := commonBits(node.bits, ip)

	// the new prefix is covered by node: update in place or go deeper
	if node.cidr <= cidr && shared >= node.cidr {
		if node.cidr != cidr {
			next := node.choose(ip)
			node.child[next] = node.child[next].Insert(ip, cidr, peer)
		} else {
			node.peer = peer
		}
		return node
	}

	// node has to be split
	fresh := entry(cidr, peer)
	branchAt := min(cidr, shared)

	// the new prefix is shorter than node's and contains it
	if branchAt == cidr {
		fresh.child[fresh.choose(node.bits)] = node
		return fresh
	}

	// prefixes diverge at bit branchAt: hang both under a new parent
	fork := entry(branchAt, nil)
	side := fork.choose(ip)
	fork.child[side] = fresh
	fork.child[side^1] = node
	return fork
}
// Lookup walks the trie from node towards ip and returns the peer of the
// last (i.e. longest) matching prefix that has a peer attached, or nil if
// no such prefix exists.
//
// The walk ends when a child is missing, when a node's prefix no longer
// matches ip, or when a node's branching bit would lie past the end of ip
// (a full-length prefix).
func (node *Trie) Lookup(ip net.IP) *Peer {
	var match *Peer
	addrLen := uint(len(ip))

	cur := node
	for cur != nil {
		if commonBits(cur.bits, ip) < cur.cidr {
			// ip is not inside this node's prefix
			break
		}
		if cur.peer != nil {
			match = cur.peer
		}
		if cur.bit_at_byte == addrLen {
			// no further bit to branch on
			break
		}
		cur = cur.child[cur.choose(ip)]
	}
	return match
}
// Count returns the number of prefixes stored in the trie rooted at node,
// i.e. the number of nodes that have a peer attached. Nodes that exist
// only to branch (peer == nil) are not counted. A nil trie counts as 0.
//
// Previously the node itself was never added to the sum of its children,
// so the result was 0 for every trie.
func (node *Trie) Count() uint {
	if node == nil {
		return 0
	}

	total := node.child[0].Count() + node.child[1].Count()
	if node.peer != nil {
		total++
	}
	return total
}
// AllowedIPs appends one net.IPNet to results for every node of the trie
// rooted at node whose peer is p, and returns the extended slice. Nodes
// are visited in pre-order: the node itself, then child[0], then child[1].
//
// 4-byte addresses are converted with net.IPv4; 16-byte addresses are
// placed in the result as-is, so the returned IP shares memory with the
// trie node. Any other address length is treated as an internal bug and
// causes a panic.
func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet {
	if node == nil {
		return results
	}

	if node.peer == p {
		var addr net.IP
		switch len(node.bits) {
		case net.IPv4len:
			addr = net.IPv4(node.bits[0], node.bits[1], node.bits[2], node.bits[3])
		case net.IPv6len:
			addr = node.bits
		default:
			panic(errors.New("bug: unexpected address length"))
		}
		results = append(results, net.IPNet{
			IP:   addr,
			Mask: net.CIDRMask(int(node.cidr), len(node.bits)*8),
		})
	}

	for _, sub := range node.child {
		results = sub.AllowedIPs(p, results)
	}
	return results
}