Better common bits function

This commit is contained in:
Jason A. Donenfeld 2018-05-14 15:49:20 +02:00
parent 7f1c9d1cc2
commit 09235d48d8
2 changed files with 44 additions and 56 deletions

View file

@ -7,8 +7,10 @@ package main
import ( import (
"errors" "errors"
"math/bits"
"net" "net"
"sync" "sync"
"unsafe"
) )
type trieEntry struct { type trieEntry struct {
@ -23,62 +25,48 @@ type trieEntry struct {
bit_at_shift uint bit_at_shift uint
} }
/* Finds length of matching prefix func isLittleEndian() bool {
* one := uint32(1)
* TODO: Only use during insertion (xor + prefix mask for lookup) return *(*byte)(unsafe.Pointer(&one)) != 0
* Check out
* prefix_matches(struct allowedips_node *node, const u8 *key, u8 bits)
* https://git.zx2c4.com/WireGuard/commit/?h=jd/precomputed-prefix-match
*
* Assumption:
* len(ip1) == len(ip2)
* len(ip1) mod 4 = 0
*/
func commonBits(ip1 []byte, ip2 []byte) uint {
var i uint
size := uint(len(ip1))
for i = 0; i < size; i++ {
v := ip1[i] ^ ip2[i]
if v != 0 {
v >>= 1
if v == 0 {
return i*8 + 7
} }
v >>= 1 func swapU32(i uint32) uint32 {
if v == 0 { if !isLittleEndian() {
return i*8 + 6 return i
} }
v >>= 1 return bits.ReverseBytes32(i)
if v == 0 {
return i*8 + 5
} }
v >>= 1 func swapU64(i uint64) uint64 {
if v == 0 { if !isLittleEndian() {
return i*8 + 4 return i
} }
v >>= 1 return bits.ReverseBytes64(i)
if v == 0 {
return i*8 + 3
} }
v >>= 1 func commonBits(ip1 net.IP, ip2 net.IP) uint {
if v == 0 { size := len(ip1)
return i*8 + 2 if size == net.IPv4len {
a := (*uint32)(unsafe.Pointer(&ip1[0]))
b := (*uint32)(unsafe.Pointer(&ip2[0]))
x := *a ^ *b
return uint(bits.LeadingZeros32(swapU32(x)))
} else if size == net.IPv6len {
a := (*uint64)(unsafe.Pointer(&ip1[0]))
b := (*uint64)(unsafe.Pointer(&ip2[0]))
x := *a ^ *b
if x != 0 {
return uint(bits.LeadingZeros64(swapU64(x)))
} }
a = (*uint64)(unsafe.Pointer(&ip1[8]))
v >>= 1 b = (*uint64)(unsafe.Pointer(&ip2[8]))
if v == 0 { x = *a ^ *b
return i*8 + 1 return 64 + uint(bits.LeadingZeros64(swapU64(x)))
} else {
panic("Wrong size bit string")
} }
return i * 8
}
}
return i * 8
} }
func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {

View file

@ -106,7 +106,7 @@ func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) {
} }
/* Test ported from kernel implementation: /* Test ported from kernel implementation:
* selftest/routingtable.h * selftest/allowedips.h
*/ */
func TestTrieIPv4(t *testing.T) { func TestTrieIPv4(t *testing.T) {
a := &Peer{} a := &Peer{}
@ -192,7 +192,7 @@ func TestTrieIPv4(t *testing.T) {
} }
/* Test ported from kernel implementation: /* Test ported from kernel implementation:
* selftest/routingtable.h * selftest/allowedips.h
*/ */
func TestTrieIPv6(t *testing.T) { func TestTrieIPv6(t *testing.T) {
a := &Peer{} a := &Peer{}