global: use netip where possible now

There are more places where we'll need to add it later, when Go 1.18
comes out with support for it in the "net" package. Also, allowedips
still uses slices internally, which might be suboptimal.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
master
Jason A. Donenfeld 1 year ago
parent de7c702ace
commit ef8d6804d7
  1. 50
      conn/bind_linux.go
  2. 19
      conn/bind_std.go
  3. 19
      conn/bind_windows.go
  4. 14
      conn/bindtest/bindtest.go
  5. 37
      conn/conn.go
  6. 42
      device/allowedips.go
  7. 6
      device/allowedips_rand_test.go
  8. 6
      device/allowedips_test.go
  9. 6
      device/device_test.go
  10. 39
      device/endpoint_test.go
  11. 1
      device/receive.go
  12. 11
      device/uapi.go
  13. 7
      go.mod
  14. 13
      go.sum
  15. 58
      ratelimiter/ratelimiter.go
  16. 33
      ratelimiter/ratelimiter_test.go
  17. 7
      tun/netstack/examples/http_client.go
  18. 6
      tun/netstack/examples/http_server.go
  19. 1
      tun/netstack/go.mod
  20. 4
      tun/netstack/go.sum
  21. 143
      tun/netstack/tun.go
  22. 10
      tun/tuntest/tuntest.go

@ -14,6 +14,7 @@ import (
"unsafe"
"golang.org/x/sys/unix"
"golang.zx2c4.com/go118/netip"
)
type ipv4Source struct {
@ -70,32 +71,30 @@ var _ Bind = (*LinuxSocketBind)(nil)
func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
var end LinuxSocketEndpoint
addr, err := parseEndpoint(s)
e, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
ipv4 := addr.IP.To4()
if ipv4 != nil {
if e.Addr().Is4() {
dst := end.dst4()
end.isV6 = false
dst.Port = addr.Port
copy(dst.Addr[:], ipv4)
dst.Port = int(e.Port())
dst.Addr = e.Addr().As4()
end.ClearSrc()
return &end, nil
}
ipv6 := addr.IP.To16()
if ipv6 != nil {
zone, err := zoneToUint32(addr.Zone)
if e.Addr().Is6() {
zone, err := zoneToUint32(e.Addr().Zone())
if err != nil {
return nil, err
}
dst := end.dst6()
end.isV6 = true
dst.Port = addr.Port
dst.Port = int(e.Port())
dst.ZoneId = zone
copy(dst.Addr[:], ipv6[:])
dst.Addr = e.Addr().As16()
end.ClearSrc()
return &end, nil
}
@ -266,29 +265,19 @@ func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
}
}
func (end *LinuxSocketEndpoint) SrcIP() net.IP {
func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
if !end.isV6 {
return net.IPv4(
end.src4().Src[0],
end.src4().Src[1],
end.src4().Src[2],
end.src4().Src[3],
)
return netip.AddrFrom4(end.src4().Src)
} else {
return end.src6().src[:]
return netip.AddrFrom16(end.src6().src)
}
}
func (end *LinuxSocketEndpoint) DstIP() net.IP {
func (end *LinuxSocketEndpoint) DstIP() netip.Addr {
if !end.isV6 {
return net.IPv4(
end.dst4().Addr[0],
end.dst4().Addr[1],
end.dst4().Addr[2],
end.dst4().Addr[3],
)
return netip.AddrFrom4(end.dst4().Addr)
} else {
return end.dst6().Addr[:]
return netip.AddrFrom16(end.dst6().Addr)
}
}
@ -305,14 +294,13 @@ func (end *LinuxSocketEndpoint) SrcToString() string {
}
func (end *LinuxSocketEndpoint) DstToString() string {
var udpAddr net.UDPAddr
udpAddr.IP = end.DstIP()
var port int
if !end.isV6 {
udpAddr.Port = end.dst4().Port
port = end.dst4().Port
} else {
udpAddr.Port = end.dst6().Port
port = end.dst6().Port
}
return udpAddr.String()
return netip.AddrPortFrom(end.DstIP(), uint16(port)).String()
}
func (end *LinuxSocketEndpoint) ClearDst() {

@ -10,6 +10,8 @@ import (
"net"
"sync"
"syscall"
"golang.zx2c4.com/go118/netip"
)
// StdNetBind is meant to be a temporary solution on platforms for which
@ -32,18 +34,23 @@ var _ Bind = (*StdNetBind)(nil)
var _ Endpoint = (*StdNetEndpoint)(nil)
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
addr, err := parseEndpoint(s)
return (*StdNetEndpoint)(addr), err
e, err := netip.ParseAddrPort(s)
return (*StdNetEndpoint)(&net.UDPAddr{
IP: e.Addr().AsSlice(),
Port: int(e.Port()),
Zone: e.Addr().Zone(),
}), err
}
func (*StdNetEndpoint) ClearSrc() {}
func (e *StdNetEndpoint) DstIP() net.IP {
return (*net.UDPAddr)(e).IP
func (e *StdNetEndpoint) DstIP() netip.Addr {
a, _ := netip.AddrFromSlice((*net.UDPAddr)(e).IP)
return a
}
func (e *StdNetEndpoint) SrcIP() net.IP {
return nil // not supported
func (e *StdNetEndpoint) SrcIP() netip.Addr {
return netip.Addr{} // not supported
}
func (e *StdNetEndpoint) DstToBytes() []byte {

@ -15,6 +15,7 @@ import (
"unsafe"
"golang.org/x/sys/windows"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn/winrio"
)
@ -128,18 +129,18 @@ func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
func (*WinRingEndpoint) ClearSrc() {}
func (e *WinRingEndpoint) DstIP() net.IP {
func (e *WinRingEndpoint) DstIP() netip.Addr {
switch e.family {
case windows.AF_INET:
return append([]byte{}, e.data[2:6]...)
return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
case windows.AF_INET6:
return append([]byte{}, e.data[6:22]...)
return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
}
return nil
return netip.Addr{}
}
func (e *WinRingEndpoint) SrcIP() net.IP {
return nil // not supported
func (e *WinRingEndpoint) SrcIP() netip.Addr {
return netip.Addr{} // not supported
}
func (e *WinRingEndpoint) DstToBytes() []byte {
@ -161,15 +162,13 @@ func (e *WinRingEndpoint) DstToBytes() []byte {
func (e *WinRingEndpoint) DstToString() string {
switch e.family {
case windows.AF_INET:
addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
return addr.String()
netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
case windows.AF_INET6:
var zone string
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
zone = strconv.FormatUint(uint64(scope), 10)
}
addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
return addr.String()
return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
}
return ""
}

@ -10,8 +10,8 @@ import (
"math/rand"
"net"
"os"
"strconv"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn"
)
@ -61,9 +61,9 @@ func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d
func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) }
func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
func (c ChannelEndpoint) SrcIP() net.IP { return nil }
func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
c.closeSignal = make(chan bool)
@ -119,13 +119,9 @@ func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
}
func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
_, port, err := net.SplitHostPort(s)
addr, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
i, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return nil, err
}
return ChannelEndpoint(i), nil
return ChannelEndpoint(addr.Port()), nil
}

@ -9,10 +9,11 @@ package conn
import (
"errors"
"fmt"
"net"
"reflect"
"runtime"
"strings"
"golang.zx2c4.com/go118/netip"
)
// A ReceiveFunc receives a single inbound packet from the network.
@ -68,8 +69,8 @@ type Endpoint interface {
SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations
DstIP() net.IP
SrcIP() net.IP
DstIP() netip.Addr
SrcIP() netip.Addr
}
var (
@ -119,33 +120,3 @@ func (fn ReceiveFunc) PrettyName() string {
}
return name
}
func parseEndpoint(s string) (*net.UDPAddr, error) {
// ensure that the host is an IP address
host, _, err := net.SplitHostPort(s)
if err != nil {
return nil, err
}
if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
// trying to make sure with a small sanity test that this is a real IP address and
// not something that's likely to incur DNS lookups.
host = host[:i]
}
if ip := net.ParseIP(host); ip == nil {
return nil, errors.New("Failed to parse IP address: " + host)
}
// parse address and port
addr, err := net.ResolveUDPAddr("udp", s)
if err != nil {
return nil, err
}
ip4 := addr.IP.To4()
if ip4 != nil {
addr.IP = ip4
}
return addr, err
}

@ -12,6 +12,8 @@ import (
"net"
"sync"
"unsafe"
"golang.zx2c4.com/go118/netip"
)
type parentIndirection struct {
@ -26,7 +28,7 @@ type trieEntry struct {
cidr uint8
bitAtByte uint8
bitAtShift uint8
bits net.IP
bits []byte
perPeerElem *list.Element
}
@ -51,7 +53,7 @@ func swapU64(i uint64) uint64 {
return bits.ReverseBytes64(i)
}
func commonBits(ip1 net.IP, ip2 net.IP) uint8 {
func commonBits(ip1, ip2 []byte) uint8 {
size := len(ip1)
if size == net.IPv4len {
a := (*uint32)(unsafe.Pointer(&ip1[0]))
@ -85,7 +87,7 @@ func (node *trieEntry) removeFromPeerEntries() {
}
}
func (node *trieEntry) choose(ip net.IP) byte {
func (node *trieEntry) choose(ip []byte) byte {
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
}
@ -104,7 +106,7 @@ func (node *trieEntry) zeroizePointers() {
node.parent.parentBit = nil
}
func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
parent = node
if parent.cidr == cidr {
@ -117,7 +119,7 @@ func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry,
return
}
func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
if *trie.parentBit == nil {
node := &trieEntry{
peer: peer,
@ -207,7 +209,7 @@ func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
}
}
func (node *trieEntry) lookup(ip net.IP) *Peer {
func (node *trieEntry) lookup(ip []byte) *Peer {
var found *Peer
size := uint8(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr {
@ -229,13 +231,14 @@ type AllowedIPs struct {
mutex sync.RWMutex
}
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) {
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
table.mutex.RLock()
defer table.mutex.RUnlock()
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
node := elem.Value.(*trieEntry)
if !cb(node.bits, node.cidr) {
a, _ := netip.AddrFromSlice(node.bits)
if !cb(netip.PrefixFrom(a, int(node.cidr))) {
return
}
}
@ -283,28 +286,29 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
}
}
func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
switch len(ip) {
case net.IPv6len:
parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer)
case net.IPv4len:
parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer)
default:
if prefix.Addr().Is6() {
ip := prefix.Addr().As16()
parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
} else if prefix.Addr().Is4() {
ip := prefix.Addr().As4()
parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
} else {
panic(errors.New("inserting unknown address type"))
}
}
func (table *AllowedIPs) Lookup(address []byte) *Peer {
func (table *AllowedIPs) Lookup(ip []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
switch len(address) {
switch len(ip) {
case net.IPv6len:
return table.IPv6.lookup(address)
return table.IPv6.lookup(ip)
case net.IPv4len:
return table.IPv4.lookup(address)
return table.IPv4.lookup(ip)
default:
panic(errors.New("looking up unknown address type"))
}

@ -10,6 +10,8 @@ import (
"net"
"sort"
"testing"
"golang.zx2c4.com/go118/netip"
)
const (
@ -93,14 +95,14 @@ func TestTrieRandom(t *testing.T) {
rand.Read(addr4[:])
cidr := uint8(rand.Intn(32) + 1)
index := rand.Intn(NumberOfPeers)
allowedIPs.Insert(addr4[:], cidr, peers[index])
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
var addr6 [16]byte
rand.Read(addr6[:])
cidr = uint8(rand.Intn(128) + 1)
index = rand.Intn(NumberOfPeers)
allowedIPs.Insert(addr6[:], cidr, peers[index])
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
}

@ -9,6 +9,8 @@ import (
"math/rand"
"net"
"testing"
"golang.zx2c4.com/go118/netip"
)
type testPairCommonBits struct {
@ -98,7 +100,7 @@ func TestTrieIPv4(t *testing.T) {
var allowedIPs AllowedIPs
insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
allowedIPs.Insert([]byte{a, b, c, d}, cidr, peer)
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d byte) {
@ -208,7 +210,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
allowedIPs.Insert(addr, cidr, peer)
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d uint32) {

@ -11,7 +11,6 @@ import (
"fmt"
"io"
"math/rand"
"net"
"runtime"
"runtime/pprof"
"sync"
@ -19,6 +18,7 @@ import (
"testing"
"time"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/conn/bindtest"
"golang.zx2c4.com/wireguard/tun/tuntest"
@ -96,7 +96,7 @@ type testPair [2]testPeer
type testPeer struct {
tun *tuntest.ChannelTUN
dev *Device
ip net.IP
ip netip.Addr
}
type SendDirection bool
@ -159,7 +159,7 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
for i := range pair {
p := &pair[i]
p.tun = tuntest.NewChannelTUN()
p.ip = net.IPv4(1, 0, 0, byte(i+1))
p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
level := LogLevelVerbose
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
level = LogLevelError

@ -7,47 +7,44 @@ package device
import (
"math/rand"
"net"
"golang.zx2c4.com/go118/netip"
)
type DummyEndpoint struct {
src [16]byte
dst [16]byte
src, dst netip.Addr
}
func CreateDummyEndpoint() (*DummyEndpoint, error) {
var end DummyEndpoint
if _, err := rand.Read(end.src[:]); err != nil {
var src, dst [16]byte
if _, err := rand.Read(src[:]); err != nil {
return nil, err
}
_, err := rand.Read(end.dst[:])
return &end, err
_, err := rand.Read(dst[:])
return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err
}
func (e *DummyEndpoint) ClearSrc() {}
func (e *DummyEndpoint) SrcToString() string {
var addr net.UDPAddr
addr.IP = e.SrcIP()
addr.Port = 1000
return addr.String()
return netip.AddrPortFrom(e.SrcIP(), 1000).String()
}
func (e *DummyEndpoint) DstToString() string {
var addr net.UDPAddr
addr.IP = e.DstIP()
addr.Port = 1000
return addr.String()
return netip.AddrPortFrom(e.DstIP(), 1000).String()
}
func (e *DummyEndpoint) SrcToBytes() []byte {
return e.src[:]
func (e *DummyEndpoint) DstToBytes() []byte {
out := e.DstIP().AsSlice()
out = append(out, byte(1000&0xff))
out = append(out, byte((1000>>8)&0xff))
return out
}
func (e *DummyEndpoint) DstIP() net.IP {
return e.dst[:]
func (e *DummyEndpoint) DstIP() netip.Addr {
return e.dst
}
func (e *DummyEndpoint) SrcIP() net.IP {
return e.src[:]
func (e *DummyEndpoint) SrcIP() netip.Addr {
return e.src
}

@ -17,7 +17,6 @@ import (
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/conn"
)

@ -18,6 +18,7 @@ import (
"sync/atomic"
"time"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/ipc"
)
@ -121,8 +122,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint8) bool {
sendf("allowed_ip=%s/%d", ip.String(), cidr)
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
sendf("allowed_ip=%s", prefix.String())
return true
})
}
@ -374,16 +375,14 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
case "allowed_ip":
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
_, network, err := net.ParseCIDR(value)
prefix, err := netip.ParsePrefix(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
}
if peer.dummy {
return nil
}
ones, _ := network.Mask.Size()
device.allowedips.Insert(network.IP, uint8(ones), peer.Peer)
device.allowedips.Insert(prefix, peer.Peer)
case "protocol_version":
if value != "1" {

@ -3,8 +3,9 @@ module golang.zx2c4.com/wireguard
go 1.17
require (
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519
golang.org/x/net v0.0.0-20211101193420-4a448f8816b3
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b
golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa
golang.org/x/net v0.0.0-20211111083644-e5c967477495
golang.org/x/sys v0.0.0-20211110154304-99a53858aa08
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224
)

@ -1,16 +1,19 @@
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa h1:idItI2DDfCokpg0N51B2VtiLdJ4vAuXC9fnCb2gACo4=
golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20211101193420-4a448f8816b3 h1:VrJZAjbekhoRn7n5FBujY31gboH+iB3pdLxn3gE9FjU=
golang.org/x/net v0.0.0-20211101193420-4a448f8816b3/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211111083644-e5c967477495 h1:cjxxlQm6d4kYbhpZ2ghvmI8xnq0AG+jXmzrhzfkyu5A=
golang.org/x/net v0.0.0-20211111083644-e5c967477495/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b h1:1VkfZQv42XQlA/jchYumAnv1UPo6RgF9rJFkTgZIxO4=
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211110154304-99a53858aa08 h1:WecRHqgE09JBkh/584XIE6PMz5KKE/vER4izNUi30AQ=
golang.org/x/sys v0.0.0-20211110154304-99a53858aa08/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d h1:9+v0G0naRhLPOJEeJOL6NuXTtAHHwmkyZlgQJ0XcQ8I=
golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY=
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=

@ -6,9 +6,10 @@
package ratelimiter
import (
"net"
"sync"
"time"
"golang.zx2c4.com/go118/netip"
)
const (
@ -30,8 +31,7 @@ type Ratelimiter struct {
timeNow func() time.Time
stopReset chan struct{} // send to reset, close to stop
tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
table map[netip.Addr]*RatelimiterEntry
}
func (rate *Ratelimiter) Close() {
@ -57,8 +57,7 @@ func (rate *Ratelimiter) Init() {
}
rate.stopReset = make(chan struct{})
rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
rate.table = make(map[netip.Addr]*RatelimiterEntry)
stopReset := rate.stopReset // store in case Init is called again.
@ -87,71 +86,39 @@ func (rate *Ratelimiter) cleanup() (empty bool) {
rate.mu.Lock()
defer rate.mu.Unlock()
for key, entry := range rate.tableIPv4 {
for key, entry := range rate.table {
entry.mu.Lock()
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
delete(rate.tableIPv4, key)
delete(rate.table, key)
}
entry.mu.Unlock()
}
for key, entry := range rate.tableIPv6 {
entry.mu.Lock()
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
delete(rate.tableIPv6, key)
}
entry.mu.Unlock()
}
return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
return len(rate.table) == 0
}
func (rate *Ratelimiter) Allow(ip net.IP) bool {
func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
var entry *RatelimiterEntry
var keyIPv4 [net.IPv4len]byte
var keyIPv6 [net.IPv6len]byte
// lookup entry
IPv4 := ip.To4()
IPv6 := ip.To16()
rate.mu.RLock()
if IPv4 != nil {
copy(keyIPv4[:], IPv4)
entry = rate.tableIPv4[keyIPv4]
} else {
copy(keyIPv6[:], IPv6)
entry = rate.tableIPv6[keyIPv6]
}
entry = rate.table[ip]
rate.mu.RUnlock()
// make new entry if not found
if entry == nil {
entry = new(RatelimiterEntry)
entry.tokens = maxTokens - packetCost
entry.lastTime = rate.timeNow()
rate.mu.Lock()
if IPv4 != nil {
rate.tableIPv4[keyIPv4] = entry
if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
rate.stopReset <- struct{}{}
}
} else {
rate.tableIPv6[keyIPv6] = entry
if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
rate.stopReset <- struct{}{}
}
rate.table[ip] = entry
if len(rate.table) == 1 {
rate.stopReset <- struct{}{}
}
rate.mu.Unlock()
return true
}
// add tokens to entry
entry.mu.Lock()
now := rate.timeNow()
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
@ -161,7 +128,6 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
}
// subtract cost of packet
if entry.tokens > packetCost {
entry.tokens -= packetCost
entry.mu.Unlock()

@ -6,9 +6,10 @@
package ratelimiter
import (
"net"
"testing"
"time"
"golang.zx2c4.com/go118/netip"
)
type result struct {
@ -71,21 +72,21 @@ func TestRatelimiter(t *testing.T) {
text: "packet following 2 packet burst",
})
ips := []net.IP{
net.ParseIP("127.0.0.1"),
net.ParseIP("192.168.1.1"),
net.ParseIP("172.167.2.3"),
net.ParseIP("97.231.252.215"),
net.ParseIP("248.97.91.167"),
net.ParseIP("188.208.233.47"),
net.ParseIP("104.2.183.179"),
net.ParseIP("72.129.46.120"),
net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
ips := []netip.Addr{
netip.MustParseAddr("127.0.0.1"),
netip.MustParseAddr("192.168.1.1"),
netip.MustParseAddr("172.167.2.3"),
netip.MustParseAddr("97.231.252.215"),
netip.MustParseAddr("248.97.91.167"),
netip.MustParseAddr("188.208.233.47"),
netip.MustParseAddr("104.2.183.179"),
netip.MustParseAddr("72.129.46.120"),
netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
}
now := time.Now()

@ -1,4 +1,5 @@
//go:build ignore
// +build ignore
/* SPDX-License-Identifier: MIT
*
@ -10,9 +11,9 @@ package main
import (
"io"
"log"
"net"
"net/http"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
@ -20,8 +21,8 @@ import (
func main() {
tun, tnet, err := netstack.CreateNetTUN(
[]net.IP{net.ParseIP("192.168.4.29")},
[]net.IP{net.ParseIP("8.8.8.8")},
[]netip.Addr{netip.MustParseAddr("192.168.4.29")},
[]netip.Addr{netip.MustParseAddr("8.8.8.8")},
1420)
if err != nil {
log.Panic(err)

@ -1,4 +1,5 @@
//go:build ignore
// +build ignore
/* SPDX-License-Identifier: MIT
*
@ -13,6 +14,7 @@ import (
"net"
"net/http"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
@ -20,8 +22,8 @@ import (
func main() {
tun, tnet, err := netstack.CreateNetTUN(
[]net.IP{net.ParseIP("192.168.4.29")},
[]net.IP{net.ParseIP("8.8.8.8"), net.ParseIP("8.8.4.4")},
[]netip.Addr{netip.MustParseAddr("192.168.4.29")},
[]netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")},
1420,
)
if err != nil {

@ -6,6 +6,7 @@ require (
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 // indirect
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect
golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22
gvisor.dev/gvisor v0.0.0-20211020211948-f76a604701b6
)

@ -805,6 +805,10 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5 h1:mV4w4F7AtWXoDNkko9odoTdWpNwyDh8jx+S1fOZKDLg=
golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53 h1:nFvpdzrHF9IPo9xPgayHWObCATpQYKky8VSSdt9lf9E=
golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22 h1:ytS28bw9HtZVDRMDxviC6ryCJuccw+zXhh04u2IRWJw=
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22/go.mod h1:a057zjmoc00UN7gVkaJt2sXVK523kMJcogDTEvPIasg=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=

@ -18,6 +18,7 @@ import (
"strings"
"time"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/tun"
"golang.org/x/net/dns/dnsmessage"
@ -38,7 +39,7 @@ type netTun struct {
events chan tun.Event
incomingPacket chan buffer.VectorisedView
mtu int
dnsServers []net.IP
dnsServers []netip.Addr
hasV4, hasV6 bool
}
type endpoint netTun
@ -94,7 +95,7 @@ func (*endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
}
func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Net, error) {
func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
@ -112,25 +113,23 @@ func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Ne
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
}
for _, ip := range localAddresses {
if ip4 := ip.To4(); ip4 != nil {
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.Address(ip4).WithPrefix(),
}
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
if tcpipErr != nil {
return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip4, tcpipErr)
}
var protoNumber tcpip.NetworkProtocolNumber
if ip.Is4() {
protoNumber = ipv4.ProtocolNumber
} else if ip.Is6() {
protoNumber = ipv6.ProtocolNumber
}
protoAddr := tcpip.ProtocolAddress{
Protocol: protoNumber,
AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(),
}
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
if tcpipErr != nil {
return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
}
if ip.Is4() {
dev.hasV4 = true
} else {
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: tcpip.Address(ip).WithPrefix(),
}
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
if tcpipErr != nil {
return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
}
} else if ip.Is6() {
dev.hasV6 = true
}
}
@ -202,62 +201,83 @@ func (tun *netTun) MTU() (int, error) {
return tun.mtu, nil
}
func convertToFullAddr(ip net.IP, port int) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
if ip4 := ip.To4(); ip4 != nil {
return tcpip.FullAddress{
NIC: 1,
Addr: tcpip.Address(ip4),
Port: uint16(port),
}, ipv4.ProtocolNumber
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
var protoNumber tcpip.NetworkProtocolNumber
if endpoint.Addr().Is4() {
protoNumber = ipv4.ProtocolNumber
} else {
return tcpip.FullAddress{
NIC: 1,
Addr: tcpip.Address(ip),
Port: uint16(port),
}, ipv6.ProtocolNumber
protoNumber = ipv6.ProtocolNumber
}
return tcpip.FullAddress{
NIC: 1,
Addr: tcpip.Address(endpoint.Addr().AsSlice()),
Port: endpoint.Port(),
}, protoNumber
}
func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
fa, pn := convertToFullAddr(addr)
return gonet.DialContextTCP(ctx, net.stack, fa, pn)
}
func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
if addr == nil {
panic("todo: deal with auto addr semantics for nil addr")
return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
}
fa, pn := convertToFullAddr(addr.IP, addr.Port)
return gonet.DialContextTCP(ctx, net.stack, fa, pn)
return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
}
func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
fa, pn := convertToFullAddr(addr)
return gonet.DialTCP(net.stack, fa, pn)
}
func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
if addr == nil {
panic("todo: deal with auto addr semantics for nil addr")
return net.DialTCPAddrPort(netip.AddrPort{})
}
fa, pn := convertToFullAddr(addr.IP, addr.Port)
return gonet.DialTCP(net.stack, fa, pn)
return net.DialTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
}
func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
fa, pn := convertToFullAddr(addr)
return gonet.ListenTCP(net.stack, fa, pn)
}
func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
if addr == nil {
panic("todo: deal with auto addr semantics for nil addr")
return net.ListenTCPAddrPort(netip.AddrPort{})
}
fa, pn := convertToFullAddr(addr.IP, addr.Port)
return gonet.ListenTCP(net.stack, fa, pn)
return net.ListenTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
}
func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
var lfa, rfa *tcpip.FullAddress
var pn tcpip.NetworkProtocolNumber
if laddr != nil {
if laddr.IsValid() || laddr.Port() > 0 {
var addr tcpip.FullAddress
addr, pn = convertToFullAddr(laddr.IP, laddr.Port)
addr, pn = convertToFullAddr(laddr)
lfa = &addr
}
if raddr != nil {
if raddr.IsValid() || raddr.Port() > 0 {
var addr tcpip.FullAddress
addr, pn = convertToFullAddr(raddr.IP, raddr.Port)
addr, pn = convertToFullAddr(raddr)
rfa = &addr
}
return gonet.DialUDP(net.stack, lfa, rfa, pn)
}
func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
var la, ra netip.AddrPort
if laddr != nil {
la = netip.AddrPortFrom(netip.AddrFromSlice(laddr.IP), uint16(laddr.Port))
}
if raddr != nil {
ra = netip.AddrPortFrom(netip.AddrFromSlice(raddr.IP), uint16(raddr.Port))
}
return net.DialUDPAddrPort(la, ra)
}
var (
errNoSuchHost = errors.New("no such host")
errLameReferral = errors.New("lame referral")
@ -433,7 +453,7 @@ func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []by
return p, h, nil
}
func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
q.Class = dnsmessage.ClassINET
id, udpReq, tcpReq, err := newRequest(q)
if err != nil {
@ -447,9 +467,9 @@ func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Quest
var c net.Conn
var err error
if useUDP {
c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: server, Port: 53})
c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53))
} else {
c, err = tnet.DialContextTCP(ctx, &net.TCPAddr{IP: server, Port: 53})
c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53))
}
if err != nil {
@ -600,8 +620,8 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
zlen = zidx
}
}
if ip := net.ParseIP(host[:zlen]); ip != nil {
return []string{host[:zlen]}, nil
if ip, err := netip.ParseAddr(host[:zlen]); err == nil {
return []string{ip.String()}, nil
}
if !isDomainName(host) {
@ -612,7 +632,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
server string
error
}
var addrsV4, addrsV6 []net.IP
var addrsV4, addrsV6 []netip.Addr
lanes := 0
if tnet.hasV4 {
lanes++
@ -667,7 +687,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
break loop
}
addrsV4 = append(addrsV4, net.IP(a.A[:]))
addrsV4 = append(addrsV4, netip.AddrFrom4(a.A))
case dnsmessage.TypeAAAA:
aaaa, err := result.p.AAAAResource()
@ -679,7 +699,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
break loop
}
addrsV6 = append(addrsV6, net.IP(aaaa.AAAA[:]))
addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA))
default:
if err := result.p.SkipAnswer(); err != nil {
@ -695,7 +715,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
}
// We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled
var addrs []net.IP
var addrs []netip.Addr
if tnet.hasV6 {
addrs = append(addrsV6, addrsV4...)
} else {
@ -764,12 +784,11 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
if err != nil {
return nil, &net.OpError