global: use netip where possible now

There are more places where we'll need to add it later, when Go 1.18
comes out with support for it in the "net" package. Also, allowedips
still uses slices internally, which might be suboptimal.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2021-11-05 01:52:54 +01:00
parent de7c702ace
commit ef8d6804d7
22 changed files with 247 additions and 285 deletions

View file

@ -14,6 +14,7 @@ import (
"unsafe" "unsafe"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"golang.zx2c4.com/go118/netip"
) )
type ipv4Source struct { type ipv4Source struct {
@ -70,32 +71,30 @@ var _ Bind = (*LinuxSocketBind)(nil)
func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) { func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
var end LinuxSocketEndpoint var end LinuxSocketEndpoint
addr, err := parseEndpoint(s) e, err := netip.ParseAddrPort(s)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ipv4 := addr.IP.To4() if e.Addr().Is4() {
if ipv4 != nil {
dst := end.dst4() dst := end.dst4()
end.isV6 = false end.isV6 = false
dst.Port = addr.Port dst.Port = int(e.Port())
copy(dst.Addr[:], ipv4) dst.Addr = e.Addr().As4()
end.ClearSrc() end.ClearSrc()
return &end, nil return &end, nil
} }
ipv6 := addr.IP.To16() if e.Addr().Is6() {
if ipv6 != nil { zone, err := zoneToUint32(e.Addr().Zone())
zone, err := zoneToUint32(addr.Zone)
if err != nil { if err != nil {
return nil, err return nil, err
} }
dst := end.dst6() dst := end.dst6()
end.isV6 = true end.isV6 = true
dst.Port = addr.Port dst.Port = int(e.Port())
dst.ZoneId = zone dst.ZoneId = zone
copy(dst.Addr[:], ipv6[:]) dst.Addr = e.Addr().As16()
end.ClearSrc() end.ClearSrc()
return &end, nil return &end, nil
} }
@ -266,29 +265,19 @@ func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
} }
} }
func (end *LinuxSocketEndpoint) SrcIP() net.IP { func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
if !end.isV6 { if !end.isV6 {
return net.IPv4( return netip.AddrFrom4(end.src4().Src)
end.src4().Src[0],
end.src4().Src[1],
end.src4().Src[2],
end.src4().Src[3],
)
} else { } else {
return end.src6().src[:] return netip.AddrFrom16(end.src6().src)
} }
} }
func (end *LinuxSocketEndpoint) DstIP() net.IP { func (end *LinuxSocketEndpoint) DstIP() netip.Addr {
if !end.isV6 { if !end.isV6 {
return net.IPv4( return netip.AddrFrom4(end.dst4().Addr)
end.dst4().Addr[0],
end.dst4().Addr[1],
end.dst4().Addr[2],
end.dst4().Addr[3],
)
} else { } else {
return end.dst6().Addr[:] return netip.AddrFrom16(end.dst6().Addr)
} }
} }
@ -305,14 +294,13 @@ func (end *LinuxSocketEndpoint) SrcToString() string {
} }
func (end *LinuxSocketEndpoint) DstToString() string { func (end *LinuxSocketEndpoint) DstToString() string {
var udpAddr net.UDPAddr var port int
udpAddr.IP = end.DstIP()
if !end.isV6 { if !end.isV6 {
udpAddr.Port = end.dst4().Port port = end.dst4().Port
} else { } else {
udpAddr.Port = end.dst6().Port port = end.dst6().Port
} }
return udpAddr.String() return netip.AddrPortFrom(end.DstIP(), uint16(port)).String()
} }
func (end *LinuxSocketEndpoint) ClearDst() { func (end *LinuxSocketEndpoint) ClearDst() {

View file

@ -10,6 +10,8 @@ import (
"net" "net"
"sync" "sync"
"syscall" "syscall"
"golang.zx2c4.com/go118/netip"
) )
// StdNetBind is meant to be a temporary solution on platforms for which // StdNetBind is meant to be a temporary solution on platforms for which
@ -32,18 +34,23 @@ var _ Bind = (*StdNetBind)(nil)
var _ Endpoint = (*StdNetEndpoint)(nil) var _ Endpoint = (*StdNetEndpoint)(nil)
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
addr, err := parseEndpoint(s) e, err := netip.ParseAddrPort(s)
return (*StdNetEndpoint)(addr), err return (*StdNetEndpoint)(&net.UDPAddr{
IP: e.Addr().AsSlice(),
Port: int(e.Port()),
Zone: e.Addr().Zone(),
}), err
} }
func (*StdNetEndpoint) ClearSrc() {} func (*StdNetEndpoint) ClearSrc() {}
func (e *StdNetEndpoint) DstIP() net.IP { func (e *StdNetEndpoint) DstIP() netip.Addr {
return (*net.UDPAddr)(e).IP a, _ := netip.AddrFromSlice((*net.UDPAddr)(e).IP)
return a
} }
func (e *StdNetEndpoint) SrcIP() net.IP { func (e *StdNetEndpoint) SrcIP() netip.Addr {
return nil // not supported return netip.Addr{} // not supported
} }
func (e *StdNetEndpoint) DstToBytes() []byte { func (e *StdNetEndpoint) DstToBytes() []byte {

View file

@ -15,6 +15,7 @@ import (
"unsafe" "unsafe"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn/winrio" "golang.zx2c4.com/wireguard/conn/winrio"
) )
@ -128,18 +129,18 @@ func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
func (*WinRingEndpoint) ClearSrc() {} func (*WinRingEndpoint) ClearSrc() {}
func (e *WinRingEndpoint) DstIP() net.IP { func (e *WinRingEndpoint) DstIP() netip.Addr {
switch e.family { switch e.family {
case windows.AF_INET: case windows.AF_INET:
return append([]byte{}, e.data[2:6]...) return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
case windows.AF_INET6: case windows.AF_INET6:
return append([]byte{}, e.data[6:22]...) return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
} }
return nil return netip.Addr{}
} }
func (e *WinRingEndpoint) SrcIP() net.IP { func (e *WinRingEndpoint) SrcIP() netip.Addr {
return nil // not supported return netip.Addr{} // not supported
} }
func (e *WinRingEndpoint) DstToBytes() []byte { func (e *WinRingEndpoint) DstToBytes() []byte {
@ -161,15 +162,13 @@ func (e *WinRingEndpoint) DstToBytes() []byte {
func (e *WinRingEndpoint) DstToString() string { func (e *WinRingEndpoint) DstToString() string {
switch e.family { switch e.family {
case windows.AF_INET: case windows.AF_INET:
addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))} netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
return addr.String()
case windows.AF_INET6: case windows.AF_INET6:
var zone string var zone string
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 { if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
zone = strconv.FormatUint(uint64(scope), 10) zone = strconv.FormatUint(uint64(scope), 10)
} }
addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))} return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
return addr.String()
} }
return "" return ""
} }

View file

@ -10,8 +10,8 @@ import (
"math/rand" "math/rand"
"net" "net"
"os" "os"
"strconv"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn"
) )
@ -61,9 +61,9 @@ func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d
func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} } func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) } func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
func (c ChannelEndpoint) SrcIP() net.IP { return nil } func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
c.closeSignal = make(chan bool) c.closeSignal = make(chan bool)
@ -119,13 +119,9 @@ func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
} }
func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) { func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
_, port, err := net.SplitHostPort(s) addr, err := netip.ParseAddrPort(s)
if err != nil { if err != nil {
return nil, err return nil, err
} }
i, err := strconv.ParseUint(port, 10, 16) return ChannelEndpoint(addr.Port()), nil
if err != nil {
return nil, err
}
return ChannelEndpoint(i), nil
} }

View file

@ -9,10 +9,11 @@ package conn
import ( import (
"errors" "errors"
"fmt" "fmt"
"net"
"reflect" "reflect"
"runtime" "runtime"
"strings" "strings"
"golang.zx2c4.com/go118/netip"
) )
// A ReceiveFunc receives a single inbound packet from the network. // A ReceiveFunc receives a single inbound packet from the network.
@ -68,8 +69,8 @@ type Endpoint interface {
SrcToString() string // returns the local source address (ip:port) SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port) DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations DstToBytes() []byte // used for mac2 cookie calculations
DstIP() net.IP DstIP() netip.Addr
SrcIP() net.IP SrcIP() netip.Addr
} }
var ( var (
@ -119,33 +120,3 @@ func (fn ReceiveFunc) PrettyName() string {
} }
return name return name
} }
func parseEndpoint(s string) (*net.UDPAddr, error) {
// ensure that the host is an IP address
host, _, err := net.SplitHostPort(s)
if err != nil {
return nil, err
}
if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
// trying to make sure with a small sanity test that this is a real IP address and
// not something that's likely to incur DNS lookups.
host = host[:i]
}
if ip := net.ParseIP(host); ip == nil {
return nil, errors.New("Failed to parse IP address: " + host)
}
// parse address and port
addr, err := net.ResolveUDPAddr("udp", s)
if err != nil {
return nil, err
}
ip4 := addr.IP.To4()
if ip4 != nil {
addr.IP = ip4
}
return addr, err
}

View file

@ -12,6 +12,8 @@ import (
"net" "net"
"sync" "sync"
"unsafe" "unsafe"
"golang.zx2c4.com/go118/netip"
) )
type parentIndirection struct { type parentIndirection struct {
@ -26,7 +28,7 @@ type trieEntry struct {
cidr uint8 cidr uint8
bitAtByte uint8 bitAtByte uint8
bitAtShift uint8 bitAtShift uint8
bits net.IP bits []byte
perPeerElem *list.Element perPeerElem *list.Element
} }
@ -51,7 +53,7 @@ func swapU64(i uint64) uint64 {
return bits.ReverseBytes64(i) return bits.ReverseBytes64(i)
} }
func commonBits(ip1 net.IP, ip2 net.IP) uint8 { func commonBits(ip1, ip2 []byte) uint8 {
size := len(ip1) size := len(ip1)
if size == net.IPv4len { if size == net.IPv4len {
a := (*uint32)(unsafe.Pointer(&ip1[0])) a := (*uint32)(unsafe.Pointer(&ip1[0]))
@ -85,7 +87,7 @@ func (node *trieEntry) removeFromPeerEntries() {
} }
} }
func (node *trieEntry) choose(ip net.IP) byte { func (node *trieEntry) choose(ip []byte) byte {
return (ip[node.bitAtByte] >> node.bitAtShift) & 1 return (ip[node.bitAtByte] >> node.bitAtShift) & 1
} }
@ -104,7 +106,7 @@ func (node *trieEntry) zeroizePointers() {
node.parent.parentBit = nil node.parent.parentBit = nil
} }
func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) { func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
parent = node parent = node
if parent.cidr == cidr { if parent.cidr == cidr {
@ -117,7 +119,7 @@ func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry,
return return
} }
func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) { func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
if *trie.parentBit == nil { if *trie.parentBit == nil {
node := &trieEntry{ node := &trieEntry{
peer: peer, peer: peer,
@ -207,7 +209,7 @@ func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
} }
} }
func (node *trieEntry) lookup(ip net.IP) *Peer { func (node *trieEntry) lookup(ip []byte) *Peer {
var found *Peer var found *Peer
size := uint8(len(ip)) size := uint8(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr { for node != nil && commonBits(node.bits, ip) >= node.cidr {
@ -229,13 +231,14 @@ type AllowedIPs struct {
mutex sync.RWMutex mutex sync.RWMutex
} }
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) { func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
table.mutex.RLock() table.mutex.RLock()
defer table.mutex.RUnlock() defer table.mutex.RUnlock()
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
node := elem.Value.(*trieEntry) node := elem.Value.(*trieEntry)
if !cb(node.bits, node.cidr) { a, _ := netip.AddrFromSlice(node.bits)
if !cb(netip.PrefixFrom(a, int(node.cidr))) {
return return
} }
} }
@ -283,28 +286,29 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
} }
} }
func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) { func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
table.mutex.Lock() table.mutex.Lock()
defer table.mutex.Unlock() defer table.mutex.Unlock()
switch len(ip) { if prefix.Addr().Is6() {
case net.IPv6len: ip := prefix.Addr().As16()
parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer) parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
case net.IPv4len: } else if prefix.Addr().Is4() {
parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer) ip := prefix.Addr().As4()
default: parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
} else {
panic(errors.New("inserting unknown address type")) panic(errors.New("inserting unknown address type"))
} }
} }
func (table *AllowedIPs) Lookup(address []byte) *Peer { func (table *AllowedIPs) Lookup(ip []byte) *Peer {
table.mutex.RLock() table.mutex.RLock()
defer table.mutex.RUnlock() defer table.mutex.RUnlock()
switch len(address) { switch len(ip) {
case net.IPv6len: case net.IPv6len:
return table.IPv6.lookup(address) return table.IPv6.lookup(ip)
case net.IPv4len: case net.IPv4len:
return table.IPv4.lookup(address) return table.IPv4.lookup(ip)
default: default:
panic(errors.New("looking up unknown address type")) panic(errors.New("looking up unknown address type"))
} }

View file

@ -10,6 +10,8 @@ import (
"net" "net"
"sort" "sort"
"testing" "testing"
"golang.zx2c4.com/go118/netip"
) )
const ( const (
@ -93,14 +95,14 @@ func TestTrieRandom(t *testing.T) {
rand.Read(addr4[:]) rand.Read(addr4[:])
cidr := uint8(rand.Intn(32) + 1) cidr := uint8(rand.Intn(32) + 1)
index := rand.Intn(NumberOfPeers) index := rand.Intn(NumberOfPeers)
allowedIPs.Insert(addr4[:], cidr, peers[index]) allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
slow4 = slow4.Insert(addr4[:], cidr, peers[index]) slow4 = slow4.Insert(addr4[:], cidr, peers[index])
var addr6 [16]byte var addr6 [16]byte
rand.Read(addr6[:]) rand.Read(addr6[:])
cidr = uint8(rand.Intn(128) + 1) cidr = uint8(rand.Intn(128) + 1)
index = rand.Intn(NumberOfPeers) index = rand.Intn(NumberOfPeers)
allowedIPs.Insert(addr6[:], cidr, peers[index]) allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
slow6 = slow6.Insert(addr6[:], cidr, peers[index]) slow6 = slow6.Insert(addr6[:], cidr, peers[index])
} }

View file

@ -9,6 +9,8 @@ import (
"math/rand" "math/rand"
"net" "net"
"testing" "testing"
"golang.zx2c4.com/go118/netip"
) )
type testPairCommonBits struct { type testPairCommonBits struct {
@ -98,7 +100,7 @@ func TestTrieIPv4(t *testing.T) {
var allowedIPs AllowedIPs var allowedIPs AllowedIPs
insert := func(peer *Peer, a, b, c, d byte, cidr uint8) { insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
allowedIPs.Insert([]byte{a, b, c, d}, cidr, peer) allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
} }
assertEQ := func(peer *Peer, a, b, c, d byte) { assertEQ := func(peer *Peer, a, b, c, d byte) {
@ -208,7 +210,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...) addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...) addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...) addr = append(addr, expand(d)...)
allowedIPs.Insert(addr, cidr, peer) allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
} }
assertEQ := func(peer *Peer, a, b, c, d uint32) { assertEQ := func(peer *Peer, a, b, c, d uint32) {

View file

@ -11,7 +11,6 @@ import (
"fmt" "fmt"
"io" "io"
"math/rand" "math/rand"
"net"
"runtime" "runtime"
"runtime/pprof" "runtime/pprof"
"sync" "sync"
@ -19,6 +18,7 @@ import (
"testing" "testing"
"time" "time"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/conn/bindtest" "golang.zx2c4.com/wireguard/conn/bindtest"
"golang.zx2c4.com/wireguard/tun/tuntest" "golang.zx2c4.com/wireguard/tun/tuntest"
@ -96,7 +96,7 @@ type testPair [2]testPeer
type testPeer struct { type testPeer struct {
tun *tuntest.ChannelTUN tun *tuntest.ChannelTUN
dev *Device dev *Device
ip net.IP ip netip.Addr
} }
type SendDirection bool type SendDirection bool
@ -159,7 +159,7 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
for i := range pair { for i := range pair {
p := &pair[i] p := &pair[i]
p.tun = tuntest.NewChannelTUN() p.tun = tuntest.NewChannelTUN()
p.ip = net.IPv4(1, 0, 0, byte(i+1)) p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
level := LogLevelVerbose level := LogLevelVerbose
if _, ok := tb.(*testing.B); ok && !testing.Verbose() { if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
level = LogLevelError level = LogLevelError

View file

@ -7,47 +7,44 @@ package device
import ( import (
"math/rand" "math/rand"
"net"
"golang.zx2c4.com/go118/netip"
) )
type DummyEndpoint struct { type DummyEndpoint struct {
src [16]byte src, dst netip.Addr
dst [16]byte
} }
func CreateDummyEndpoint() (*DummyEndpoint, error) { func CreateDummyEndpoint() (*DummyEndpoint, error) {
var end DummyEndpoint var src, dst [16]byte
if _, err := rand.Read(end.src[:]); err != nil { if _, err := rand.Read(src[:]); err != nil {
return nil, err return nil, err
} }
_, err := rand.Read(end.dst[:]) _, err := rand.Read(dst[:])
return &end, err return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err
} }
func (e *DummyEndpoint) ClearSrc() {} func (e *DummyEndpoint) ClearSrc() {}
func (e *DummyEndpoint) SrcToString() string { func (e *DummyEndpoint) SrcToString() string {
var addr net.UDPAddr return netip.AddrPortFrom(e.SrcIP(), 1000).String()
addr.IP = e.SrcIP()
addr.Port = 1000
return addr.String()
} }
func (e *DummyEndpoint) DstToString() string { func (e *DummyEndpoint) DstToString() string {
var addr net.UDPAddr return netip.AddrPortFrom(e.DstIP(), 1000).String()
addr.IP = e.DstIP()
addr.Port = 1000
return addr.String()
} }
func (e *DummyEndpoint) SrcToBytes() []byte { func (e *DummyEndpoint) DstToBytes() []byte {
return e.src[:] out := e.DstIP().AsSlice()
out = append(out, byte(1000&0xff))
out = append(out, byte((1000>>8)&0xff))
return out
} }
func (e *DummyEndpoint) DstIP() net.IP { func (e *DummyEndpoint) DstIP() netip.Addr {
return e.dst[:] return e.dst
} }
func (e *DummyEndpoint) SrcIP() net.IP { func (e *DummyEndpoint) SrcIP() netip.Addr {
return e.src[:] return e.src
} }

View file

@ -17,7 +17,6 @@ import (
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn"
) )

View file

@ -18,6 +18,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/ipc"
) )
@ -121,8 +122,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)) sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval)) sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint8) bool { device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
sendf("allowed_ip=%s/%d", ip.String(), cidr) sendf("allowed_ip=%s", prefix.String())
return true return true
}) })
} }
@ -374,16 +375,14 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
case "allowed_ip": case "allowed_ip":
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer) device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
prefix, err := netip.ParsePrefix(value)
_, network, err := net.ParseCIDR(value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
} }
if peer.dummy { if peer.dummy {
return nil return nil
} }
ones, _ := network.Mask.Size() device.allowedips.Insert(prefix, peer.Peer)
device.allowedips.Insert(network.IP, uint8(ones), peer.Peer)
case "protocol_version": case "protocol_version":
if value != "1" { if value != "1" {

7
go.mod
View file

@ -3,8 +3,9 @@ module golang.zx2c4.com/wireguard
go 1.17 go 1.17
require ( require (
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa
golang.org/x/net v0.0.0-20211101193420-4a448f8816b3 golang.org/x/net v0.0.0-20211111083644-e5c967477495
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b golang.org/x/sys v0.0.0-20211110154304-99a53858aa08
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224
) )

13
go.sum
View file

@ -1,16 +1,19 @@
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg= golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa h1:idItI2DDfCokpg0N51B2VtiLdJ4vAuXC9fnCb2gACo4=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20211101193420-4a448f8816b3 h1:VrJZAjbekhoRn7n5FBujY31gboH+iB3pdLxn3gE9FjU= golang.org/x/net v0.0.0-20211111083644-e5c967477495 h1:cjxxlQm6d4kYbhpZ2ghvmI8xnq0AG+jXmzrhzfkyu5A=
golang.org/x/net v0.0.0-20211101193420-4a448f8816b3/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211111083644-e5c967477495/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b h1:1VkfZQv42XQlA/jchYumAnv1UPo6RgF9rJFkTgZIxO4=
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211110154304-99a53858aa08 h1:WecRHqgE09JBkh/584XIE6PMz5KKE/vER4izNUi30AQ=
golang.org/x/sys v0.0.0-20211110154304-99a53858aa08/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d h1:9+v0G0naRhLPOJEeJOL6NuXTtAHHwmkyZlgQJ0XcQ8I=
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY= golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY=
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=

View file

@ -6,9 +6,10 @@
package ratelimiter package ratelimiter
import ( import (
"net"
"sync" "sync"
"time" "time"
"golang.zx2c4.com/go118/netip"
) )
const ( const (
@ -30,8 +31,7 @@ type Ratelimiter struct {
timeNow func() time.Time timeNow func() time.Time
stopReset chan struct{} // send to reset, close to stop stopReset chan struct{} // send to reset, close to stop
tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry table map[netip.Addr]*RatelimiterEntry
tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
} }
func (rate *Ratelimiter) Close() { func (rate *Ratelimiter) Close() {
@ -57,8 +57,7 @@ func (rate *Ratelimiter) Init() {
} }
rate.stopReset = make(chan struct{}) rate.stopReset = make(chan struct{})
rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) rate.table = make(map[netip.Addr]*RatelimiterEntry)
rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
stopReset := rate.stopReset // store in case Init is called again. stopReset := rate.stopReset // store in case Init is called again.
@ -87,71 +86,39 @@ func (rate *Ratelimiter) cleanup() (empty bool) {
rate.mu.Lock() rate.mu.Lock()
defer rate.mu.Unlock() defer rate.mu.Unlock()
for key, entry := range rate.tableIPv4 { for key, entry := range rate.table {
entry.mu.Lock() entry.mu.Lock()
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
delete(rate.tableIPv4, key) delete(rate.table, key)
} }
entry.mu.Unlock() entry.mu.Unlock()
} }
for key, entry := range rate.tableIPv6 { return len(rate.table) == 0
entry.mu.Lock()
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
delete(rate.tableIPv6, key)
}
entry.mu.Unlock()
} }
return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
}
func (rate *Ratelimiter) Allow(ip net.IP) bool {
var entry *RatelimiterEntry var entry *RatelimiterEntry
var keyIPv4 [net.IPv4len]byte
var keyIPv6 [net.IPv6len]byte
// lookup entry // lookup entry
IPv4 := ip.To4()
IPv6 := ip.To16()
rate.mu.RLock() rate.mu.RLock()
entry = rate.table[ip]
if IPv4 != nil {
copy(keyIPv4[:], IPv4)
entry = rate.tableIPv4[keyIPv4]
} else {
copy(keyIPv6[:], IPv6)
entry = rate.tableIPv6[keyIPv6]
}
rate.mu.RUnlock() rate.mu.RUnlock()
// make new entry if not found // make new entry if not found
if entry == nil { if entry == nil {
entry = new(RatelimiterEntry) entry = new(RatelimiterEntry)
entry.tokens = maxTokens - packetCost entry.tokens = maxTokens - packetCost
entry.lastTime = rate.timeNow() entry.lastTime = rate.timeNow()
rate.mu.Lock() rate.mu.Lock()
if IPv4 != nil { rate.table[ip] = entry
rate.tableIPv4[keyIPv4] = entry if len(rate.table) == 1 {
if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
rate.stopReset <- struct{}{} rate.stopReset <- struct{}{}
} }
} else {
rate.tableIPv6[keyIPv6] = entry
if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
rate.stopReset <- struct{}{}
}
}
rate.mu.Unlock() rate.mu.Unlock()
return true return true
} }
// add tokens to entry // add tokens to entry
entry.mu.Lock() entry.mu.Lock()
now := rate.timeNow() now := rate.timeNow()
entry.tokens += now.Sub(entry.lastTime).Nanoseconds() entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
@ -161,7 +128,6 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
} }
// subtract cost of packet // subtract cost of packet
if entry.tokens > packetCost { if entry.tokens > packetCost {
entry.tokens -= packetCost entry.tokens -= packetCost
entry.mu.Unlock() entry.mu.Unlock()

View file

@ -6,9 +6,10 @@
package ratelimiter package ratelimiter
import ( import (
"net"
"testing" "testing"
"time" "time"
"golang.zx2c4.com/go118/netip"
) )
type result struct { type result struct {
@ -71,21 +72,21 @@ func TestRatelimiter(t *testing.T) {
text: "packet following 2 packet burst", text: "packet following 2 packet burst",
}) })
ips := []net.IP{ ips := []netip.Addr{
net.ParseIP("127.0.0.1"), netip.MustParseAddr("127.0.0.1"),
net.ParseIP("192.168.1.1"), netip.MustParseAddr("192.168.1.1"),
net.ParseIP("172.167.2.3"), netip.MustParseAddr("172.167.2.3"),
net.ParseIP("97.231.252.215"), netip.MustParseAddr("97.231.252.215"),
net.ParseIP("248.97.91.167"), netip.MustParseAddr("248.97.91.167"),
net.ParseIP("188.208.233.47"), netip.MustParseAddr("188.208.233.47"),
net.ParseIP("104.2.183.179"), netip.MustParseAddr("104.2.183.179"),
net.ParseIP("72.129.46.120"), netip.MustParseAddr("72.129.46.120"),
net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"), netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"), netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"), netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"), netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"), netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
} }
now := time.Now() now := time.Now()

View file

@ -1,4 +1,5 @@
//go:build ignore //go:build ignore
// +build ignore
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
@ -10,9 +11,9 @@ package main
import ( import (
"io" "io"
"log" "log"
"net"
"net/http" "net/http"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
@ -20,8 +21,8 @@ import (
func main() { func main() {
tun, tnet, err := netstack.CreateNetTUN( tun, tnet, err := netstack.CreateNetTUN(
[]net.IP{net.ParseIP("192.168.4.29")}, []netip.Addr{netip.MustParseAddr("192.168.4.29")},
[]net.IP{net.ParseIP("8.8.8.8")}, []netip.Addr{netip.MustParseAddr("8.8.8.8")},
1420) 1420)
if err != nil { if err != nil {
log.Panic(err) log.Panic(err)

View file

@ -1,4 +1,5 @@
//go:build ignore //go:build ignore
// +build ignore
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
@ -13,6 +14,7 @@ import (
"net" "net"
"net/http" "net/http"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
@ -20,8 +22,8 @@ import (
func main() { func main() {
tun, tnet, err := netstack.CreateNetTUN( tun, tnet, err := netstack.CreateNetTUN(
[]net.IP{net.ParseIP("192.168.4.29")}, []netip.Addr{netip.MustParseAddr("192.168.4.29")},
[]net.IP{net.ParseIP("8.8.8.8"), net.ParseIP("8.8.4.4")}, []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")},
1420, 1420,
) )
if err != nil { if err != nil {

View file

@ -6,6 +6,7 @@ require (
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6 golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 // indirect golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 // indirect
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect
golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22 golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22
gvisor.dev/gvisor v0.0.0-20211020211948-f76a604701b6 gvisor.dev/gvisor v0.0.0-20211020211948-f76a604701b6
) )

View file

@ -805,6 +805,10 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5 h1:mV4w4F7AtWXoDNkko9odoTdWpNwyDh8jx+S1fOZKDLg=
golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53 h1:nFvpdzrHF9IPo9xPgayHWObCATpQYKky8VSSdt9lf9E=
golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22 h1:ytS28bw9HtZVDRMDxviC6ryCJuccw+zXhh04u2IRWJw= golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22 h1:ytS28bw9HtZVDRMDxviC6ryCJuccw+zXhh04u2IRWJw=
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22/go.mod h1:a057zjmoc00UN7gVkaJt2sXVK523kMJcogDTEvPIasg= golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22/go.mod h1:a057zjmoc00UN7gVkaJt2sXVK523kMJcogDTEvPIasg=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=

View file

@ -18,6 +18,7 @@ import (
"strings" "strings"
"time" "time"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
"golang.org/x/net/dns/dnsmessage" "golang.org/x/net/dns/dnsmessage"
@ -38,7 +39,7 @@ type netTun struct {
events chan tun.Event events chan tun.Event
incomingPacket chan buffer.VectorisedView incomingPacket chan buffer.VectorisedView
mtu int mtu int
dnsServers []net.IP dnsServers []netip.Addr
hasV4, hasV6 bool hasV4, hasV6 bool
} }
type endpoint netTun type endpoint netTun
@ -94,7 +95,7 @@ func (*endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
} }
func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Net, error) { func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
opts := stack.Options{ opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
@ -112,25 +113,23 @@ func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Ne
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
} }
for _, ip := range localAddresses { for _, ip := range localAddresses {
if ip4 := ip.To4(); ip4 != nil { var protoNumber tcpip.NetworkProtocolNumber
protoAddr := tcpip.ProtocolAddress{ if ip.Is4() {
Protocol: ipv4.ProtocolNumber, protoNumber = ipv4.ProtocolNumber
AddressWithPrefix: tcpip.Address(ip4).WithPrefix(), } else if ip.Is6() {
protoNumber = ipv6.ProtocolNumber
} }
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
if tcpipErr != nil {
return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip4, tcpipErr)
}
dev.hasV4 = true
} else {
protoAddr := tcpip.ProtocolAddress{ protoAddr := tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber, Protocol: protoNumber,
AddressWithPrefix: tcpip.Address(ip).WithPrefix(), AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(),
} }
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
if tcpipErr != nil { if tcpipErr != nil {
return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
} }
if ip.Is4() {
dev.hasV4 = true
} else if ip.Is6() {
dev.hasV6 = true dev.hasV6 = true
} }
} }
@ -202,62 +201,83 @@ func (tun *netTun) MTU() (int, error) {
return tun.mtu, nil return tun.mtu, nil
} }
func convertToFullAddr(ip net.IP, port int) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
if ip4 := ip.To4(); ip4 != nil { var protoNumber tcpip.NetworkProtocolNumber
return tcpip.FullAddress{ if endpoint.Addr().Is4() {
NIC: 1, protoNumber = ipv4.ProtocolNumber
Addr: tcpip.Address(ip4),
Port: uint16(port),
}, ipv4.ProtocolNumber
} else { } else {
protoNumber = ipv6.ProtocolNumber
}
return tcpip.FullAddress{ return tcpip.FullAddress{
NIC: 1, NIC: 1,
Addr: tcpip.Address(ip), Addr: tcpip.Address(endpoint.Addr().AsSlice()),
Port: uint16(port), Port: endpoint.Port(),
}, ipv6.ProtocolNumber }, protoNumber
} }
func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
fa, pn := convertToFullAddr(addr)
return gonet.DialContextTCP(ctx, net.stack, fa, pn)
} }
func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) { func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
if addr == nil { if addr == nil {
panic("todo: deal with auto addr semantics for nil addr") return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
} }
fa, pn := convertToFullAddr(addr.IP, addr.Port) return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
return gonet.DialContextTCP(ctx, net.stack, fa, pn) }
func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
fa, pn := convertToFullAddr(addr)
return gonet.DialTCP(net.stack, fa, pn)
} }
func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
if addr == nil { if addr == nil {
panic("todo: deal with auto addr semantics for nil addr") return net.DialTCPAddrPort(netip.AddrPort{})
} }
fa, pn := convertToFullAddr(addr.IP, addr.Port) return net.DialTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
return gonet.DialTCP(net.stack, fa, pn) }
func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
fa, pn := convertToFullAddr(addr)
return gonet.ListenTCP(net.stack, fa, pn)
} }
func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) { func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
if addr == nil { if addr == nil {
panic("todo: deal with auto addr semantics for nil addr") return net.ListenTCPAddrPort(netip.AddrPort{})
} }
fa, pn := convertToFullAddr(addr.IP, addr.Port) return net.ListenTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
return gonet.ListenTCP(net.stack, fa, pn)
} }
func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
var lfa, rfa *tcpip.FullAddress var lfa, rfa *tcpip.FullAddress
var pn tcpip.NetworkProtocolNumber var pn tcpip.NetworkProtocolNumber
if laddr != nil { if laddr.IsValid() || laddr.Port() > 0 {
var addr tcpip.FullAddress var addr tcpip.FullAddress
addr, pn = convertToFullAddr(laddr.IP, laddr.Port) addr, pn = convertToFullAddr(laddr)
lfa = &addr lfa = &addr
} }
if raddr != nil { if raddr.IsValid() || raddr.Port() > 0 {
var addr tcpip.FullAddress var addr tcpip.FullAddress
addr, pn = convertToFullAddr(raddr.IP, raddr.Port) addr, pn = convertToFullAddr(raddr)
rfa = &addr rfa = &addr
} }
return gonet.DialUDP(net.stack, lfa, rfa, pn) return gonet.DialUDP(net.stack, lfa, rfa, pn)
} }
func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
var la, ra netip.AddrPort
if laddr != nil {
la = netip.AddrPortFrom(netip.AddrFromSlice(laddr.IP), uint16(laddr.Port))
}
if raddr != nil {
ra = netip.AddrPortFrom(netip.AddrFromSlice(raddr.IP), uint16(raddr.Port))
}
return net.DialUDPAddrPort(la, ra)
}
var ( var (
errNoSuchHost = errors.New("no such host") errNoSuchHost = errors.New("no such host")
errLameReferral = errors.New("lame referral") errLameReferral = errors.New("lame referral")
@ -433,7 +453,7 @@ func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []by
return p, h, nil return p, h, nil
} }
func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
q.Class = dnsmessage.ClassINET q.Class = dnsmessage.ClassINET
id, udpReq, tcpReq, err := newRequest(q) id, udpReq, tcpReq, err := newRequest(q)
if err != nil { if err != nil {
@ -447,9 +467,9 @@ func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Quest
var c net.Conn var c net.Conn
var err error var err error
if useUDP { if useUDP {
c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: server, Port: 53}) c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53))
} else { } else {
c, err = tnet.DialContextTCP(ctx, &net.TCPAddr{IP: server, Port: 53}) c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53))
} }
if err != nil { if err != nil {
@ -600,8 +620,8 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
zlen = zidx zlen = zidx
} }
} }
if ip := net.ParseIP(host[:zlen]); ip != nil { if ip, err := netip.ParseAddr(host[:zlen]); err == nil {
return []string{host[:zlen]}, nil return []string{ip.String()}, nil
} }
if !isDomainName(host) { if !isDomainName(host) {
@ -612,7 +632,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
server string server string
error error
} }
var addrsV4, addrsV6 []net.IP var addrsV4, addrsV6 []netip.Addr
lanes := 0 lanes := 0
if tnet.hasV4 { if tnet.hasV4 {
lanes++ lanes++
@ -667,7 +687,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
} }
break loop break loop
} }
addrsV4 = append(addrsV4, net.IP(a.A[:])) addrsV4 = append(addrsV4, netip.AddrFrom4(a.A))
case dnsmessage.TypeAAAA: case dnsmessage.TypeAAAA:
aaaa, err := result.p.AAAAResource() aaaa, err := result.p.AAAAResource()
@ -679,7 +699,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
} }
break loop break loop
} }
addrsV6 = append(addrsV6, net.IP(aaaa.AAAA[:])) addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA))
default: default:
if err := result.p.SkipAnswer(); err != nil { if err := result.p.SkipAnswer(); err != nil {
@ -695,7 +715,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
} }
} }
// We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled // We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled
var addrs []net.IP var addrs []netip.Addr
if tnet.hasV6 { if tnet.hasV6 {
addrs = append(addrsV6, addrsV4...) addrs = append(addrsV6, addrsV4...)
} else { } else {
@ -764,12 +784,11 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
if err != nil { if err != nil {
return nil, &net.OpError{Op: "dial", Err: err} return nil, &net.OpError{Op: "dial", Err: err}
} }
var addrs []net.IP var addrs []netip.AddrPort
for _, addr := range allAddr { for _, addr := range allAddr {
if strings.IndexByte(addr, ':') != -1 && acceptV6 { ip, err := netip.ParseAddr(addr)
addrs = append(addrs, net.ParseIP(addr)) if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) {
} else if strings.IndexByte(addr, '.') != -1 && acceptV4 { addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)))
addrs = append(addrs, net.ParseIP(addr))
} }
} }
if len(addrs) == 0 && len(allAddr) != 0 { if len(addrs) == 0 && len(allAddr) != 0 {
@ -808,9 +827,9 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
var c net.Conn var c net.Conn
if useUDP { if useUDP {
c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: addr, Port: port}) c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
} else { } else {
c, err = tnet.DialContextTCP(dialCtx, &net.TCPAddr{IP: addr, Port: port}) c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
} }
if err == nil { if err == nil {
return c, nil return c, nil

View file

@ -8,13 +8,13 @@ package tuntest
import ( import (
"encoding/binary" "encoding/binary"
"io" "io"
"net"
"os" "os"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
) )
func Ping(dst, src net.IP) []byte { func Ping(dst, src netip.Addr) []byte {
localPort := uint16(1337) localPort := uint16(1337)
seq := uint16(0) seq := uint16(0)
@ -40,7 +40,7 @@ func checksum(buf []byte, initial uint16) uint16 {
return ^uint16(v) return ^uint16(v)
} }
func genICMPv4(payload []byte, dst, src net.IP) []byte { func genICMPv4(payload []byte, dst, src netip.Addr) []byte {
const ( const (
icmpv4ProtocolNumber = 1 icmpv4ProtocolNumber = 1
icmpv4Echo = 8 icmpv4Echo = 8
@ -70,8 +70,8 @@ func genICMPv4(payload []byte, dst, src net.IP) []byte {
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length) binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
ip[8] = ttl ip[8] = ttl
ip[9] = icmpv4ProtocolNumber ip[9] = icmpv4ProtocolNumber
copy(ip[12:], src.To4()) copy(ip[12:], src.AsSlice())
copy(ip[16:], dst.To4()) copy(ip[16:], dst.AsSlice())
chksum = ^checksum(ip[:], 0) chksum = ^checksum(ip[:], 0)
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum) binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)