conn: introduce new package that splits out the Bind and Endpoint types
The sticky socket code stays in the device package for now, as it reaches deeply into the peer list. This is the first step in an effort to split some code out of the very busy device package. Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
This commit is contained in:
parent
6aefb61355
commit
203554620d
|
@ -3,11 +3,10 @@
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package conn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
|
@ -18,17 +17,13 @@ const (
|
||||||
sockoptIPV6_UNICAST_IF = 31
|
sockoptIPV6_UNICAST_IF = 31
|
||||||
)
|
)
|
||||||
|
|
||||||
func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
func (bind *nativeBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
||||||
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
|
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
|
||||||
bytes := make([]byte, 4)
|
bytes := make([]byte, 4)
|
||||||
binary.BigEndian.PutUint32(bytes, interfaceIndex)
|
binary.BigEndian.PutUint32(bytes, interfaceIndex)
|
||||||
interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
|
interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
|
||||||
|
|
||||||
if device.net.bind == nil {
|
sysconn, err := bind.ipv4.SyscallConn()
|
||||||
return errors.New("Bind is not yet initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -41,12 +36,12 @@ func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bo
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
device.net.bind.(*nativeBind).blackhole4 = blackhole
|
bind.blackhole4 = blackhole
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
func (bind *nativeBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
||||||
sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
|
sysconn, err := bind.ipv6.SyscallConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -59,6 +54,6 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bo
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
device.net.bind.(*nativeBind).blackhole6 = blackhole
|
bind.blackhole6 = blackhole
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
101
conn/conn.go
Normal file
101
conn/conn.go
Normal file
|
@ -0,0 +1,101 @@
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Package conn implements WireGuard's network connections.
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
|
||||||
|
type Bind interface {
|
||||||
|
// LastMark reports the last mark set for this Bind.
|
||||||
|
LastMark() uint32
|
||||||
|
|
||||||
|
// SetMark sets the mark for each packet sent through this Bind.
|
||||||
|
// This mark is passed to the kernel as the socket option SO_MARK.
|
||||||
|
SetMark(mark uint32) error
|
||||||
|
|
||||||
|
// ReceiveIPv6 reads an IPv6 UDP packet into b.
|
||||||
|
//
|
||||||
|
// It reports the number of bytes read, n,
|
||||||
|
// the packet source address ep,
|
||||||
|
// and any error.
|
||||||
|
ReceiveIPv6(buff []byte) (n int, ep Endpoint, err error)
|
||||||
|
|
||||||
|
// ReceiveIPv4 reads an IPv4 UDP packet into b.
|
||||||
|
//
|
||||||
|
// It reports the number of bytes read, n,
|
||||||
|
// the packet source address ep,
|
||||||
|
// and any error.
|
||||||
|
ReceiveIPv4(b []byte) (n int, ep Endpoint, err error)
|
||||||
|
|
||||||
|
// Send writes a packet b to address ep.
|
||||||
|
Send(b []byte, ep Endpoint) error
|
||||||
|
|
||||||
|
// Close closes the Bind connection.
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateBind creates a Bind bound to a port.
|
||||||
|
//
|
||||||
|
// The value actualPort reports the actual port number the Bind
|
||||||
|
// object gets bound to.
|
||||||
|
func CreateBind(port uint16) (b Bind, actualPort uint16, err error) {
|
||||||
|
return createBind(port)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BindToInterface is implemented by Bind objects that support being
|
||||||
|
// tied to a single network interface.
|
||||||
|
type BindToInterface interface {
|
||||||
|
BindToInterface4(interfaceIndex uint32, blackhole bool) error
|
||||||
|
BindToInterface6(interfaceIndex uint32, blackhole bool) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// An Endpoint maintains the source/destination caching for a peer.
|
||||||
|
//
|
||||||
|
// dst : the remote address of a peer ("endpoint" in uapi terminology)
|
||||||
|
// src : the local address from which datagrams originate going to the peer
|
||||||
|
type Endpoint interface {
|
||||||
|
ClearSrc() // clears the source address
|
||||||
|
SrcToString() string // returns the local source address (ip:port)
|
||||||
|
DstToString() string // returns the destination address (ip:port)
|
||||||
|
DstToBytes() []byte // used for mac2 cookie calculations
|
||||||
|
DstIP() net.IP
|
||||||
|
SrcIP() net.IP
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseEndpoint(s string) (*net.UDPAddr, error) {
|
||||||
|
// ensure that the host is an IP address
|
||||||
|
|
||||||
|
host, _, err := net.SplitHostPort(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
|
||||||
|
// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
|
||||||
|
// trying to make sure with a small sanity test that this is a real IP address and
|
||||||
|
// not something that's likely to incur DNS lookups.
|
||||||
|
host = host[:i]
|
||||||
|
}
|
||||||
|
if ip := net.ParseIP(host); ip == nil {
|
||||||
|
return nil, errors.New("Failed to parse IP address: " + host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse address and port
|
||||||
|
|
||||||
|
addr, err := net.ResolveUDPAddr("udp", s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ip4 := addr.IP.To4()
|
||||||
|
if ip4 != nil {
|
||||||
|
addr.IP = ip4
|
||||||
|
}
|
||||||
|
return addr, err
|
||||||
|
}
|
|
@ -5,7 +5,7 @@
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package conn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
@ -67,16 +67,13 @@ func (e *NativeEndpoint) SrcToString() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
||||||
|
|
||||||
// listen
|
|
||||||
|
|
||||||
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
|
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// retrieve port
|
// Retrieve port.
|
||||||
|
// TODO(crawshaw): under what circumstances is this necessary?
|
||||||
laddr := conn.LocalAddr()
|
laddr := conn.LocalAddr()
|
||||||
uaddr, err := net.ResolveUDPAddr(
|
uaddr, err := net.ResolveUDPAddr(
|
||||||
laddr.Network(),
|
laddr.Network(),
|
||||||
|
@ -100,7 +97,7 @@ func extractErrno(err error) error {
|
||||||
return syscallErr.Err
|
return syscallErr.Err
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
|
func createBind(uport uint16) (Bind, uint16, error) {
|
||||||
var err error
|
var err error
|
||||||
var bind nativeBind
|
var bind nativeBind
|
||||||
|
|
||||||
|
@ -135,6 +132,8 @@ func (bind *nativeBind) Close() error {
|
||||||
return err2
|
return err2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (bind *nativeBind) LastMark() uint32 { return 0 }
|
||||||
|
|
||||||
func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
||||||
if bind.ipv4 == nil {
|
if bind.ipv4 == nil {
|
||||||
return 0, nil, syscall.EAFNOSUPPORT
|
return 0, nil, syscall.EAFNOSUPPORT
|
|
@ -3,18 +3,9 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||||
*
|
|
||||||
* This implements userspace semantics of "sticky sockets", modeled after
|
|
||||||
* WireGuard's kernelspace implementation. This is more or less a straight port
|
|
||||||
* of the sticky-sockets.c example code:
|
|
||||||
* https://git.zx2c4.com/wireguard-tools/tree/contrib/sticky-sockets/sticky-sockets.c
|
|
||||||
*
|
|
||||||
* Currently there is no way to achieve this within the net package:
|
|
||||||
* See e.g. https://github.com/golang/go/issues/17930
|
|
||||||
* So this code is remains platform dependent.
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package conn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -25,7 +16,6 @@ import (
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
"golang.zx2c4.com/wireguard/rwcancel"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -33,8 +23,8 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
type IPv4Source struct {
|
type IPv4Source struct {
|
||||||
src [4]byte
|
Src [4]byte
|
||||||
ifindex int32
|
Ifindex int32
|
||||||
}
|
}
|
||||||
|
|
||||||
type IPv6Source struct {
|
type IPv6Source struct {
|
||||||
|
@ -49,6 +39,10 @@ type NativeEndpoint struct {
|
||||||
isV6 bool
|
isV6 bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (endpoint *NativeEndpoint) Src4() *IPv4Source { return endpoint.src4() }
|
||||||
|
func (endpoint *NativeEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
|
||||||
|
func (endpoint *NativeEndpoint) IsV6() bool { return endpoint.isV6 }
|
||||||
|
|
||||||
func (endpoint *NativeEndpoint) src4() *IPv4Source {
|
func (endpoint *NativeEndpoint) src4() *IPv4Source {
|
||||||
return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
|
return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
|
||||||
}
|
}
|
||||||
|
@ -68,8 +62,6 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
|
||||||
type nativeBind struct {
|
type nativeBind struct {
|
||||||
sock4 int
|
sock4 int
|
||||||
sock6 int
|
sock6 int
|
||||||
netlinkSock int
|
|
||||||
netlinkCancel *rwcancel.RWCancel
|
|
||||||
lastMark uint32
|
lastMark uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,59 +103,25 @@ func CreateEndpoint(s string) (Endpoint, error) {
|
||||||
return nil, errors.New("Invalid IP address")
|
return nil, errors.New("Invalid IP address")
|
||||||
}
|
}
|
||||||
|
|
||||||
func createNetlinkRouteSocket() (int, error) {
|
func createBind(port uint16) (Bind, uint16, error) {
|
||||||
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
saddr := &unix.SockaddrNetlink{
|
|
||||||
Family: unix.AF_NETLINK,
|
|
||||||
Groups: unix.RTMGRP_IPV4_ROUTE,
|
|
||||||
}
|
|
||||||
err = unix.Bind(sock, saddr)
|
|
||||||
if err != nil {
|
|
||||||
unix.Close(sock)
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
return sock, nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
|
|
||||||
var err error
|
var err error
|
||||||
var bind nativeBind
|
var bind nativeBind
|
||||||
var newPort uint16
|
var newPort uint16
|
||||||
|
|
||||||
bind.netlinkSock, err = createNetlinkRouteSocket()
|
// Attempt ipv6 bind, update port if successful.
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock)
|
|
||||||
if err != nil {
|
|
||||||
unix.Close(bind.netlinkSock)
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
go bind.routineRouteListener(device)
|
|
||||||
|
|
||||||
// attempt ipv6 bind, update port if successful
|
|
||||||
|
|
||||||
bind.sock6, newPort, err = create6(port)
|
bind.sock6, newPort, err = create6(port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != syscall.EAFNOSUPPORT {
|
if err != syscall.EAFNOSUPPORT {
|
||||||
bind.netlinkCancel.Cancel()
|
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
port = newPort
|
port = newPort
|
||||||
}
|
}
|
||||||
|
|
||||||
// attempt ipv4 bind, update port if successful
|
// Attempt ipv4 bind, update port if successful.
|
||||||
|
|
||||||
bind.sock4, newPort, err = create4(port)
|
bind.sock4, newPort, err = create4(port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != syscall.EAFNOSUPPORT {
|
if err != syscall.EAFNOSUPPORT {
|
||||||
bind.netlinkCancel.Cancel()
|
|
||||||
unix.Close(bind.sock6)
|
unix.Close(bind.sock6)
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
@ -178,6 +136,10 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
|
||||||
return &bind, port, nil
|
return &bind, port, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (bind *nativeBind) LastMark() uint32 {
|
||||||
|
return bind.lastMark
|
||||||
|
}
|
||||||
|
|
||||||
func (bind *nativeBind) SetMark(value uint32) error {
|
func (bind *nativeBind) SetMark(value uint32) error {
|
||||||
if bind.sock6 != -1 {
|
if bind.sock6 != -1 {
|
||||||
err := unix.SetsockoptInt(
|
err := unix.SetsockoptInt(
|
||||||
|
@ -216,23 +178,19 @@ func closeUnblock(fd int) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *nativeBind) Close() error {
|
func (bind *nativeBind) Close() error {
|
||||||
var err1, err2, err3 error
|
var err1, err2 error
|
||||||
if bind.sock6 != -1 {
|
if bind.sock6 != -1 {
|
||||||
err1 = closeUnblock(bind.sock6)
|
err1 = closeUnblock(bind.sock6)
|
||||||
}
|
}
|
||||||
if bind.sock4 != -1 {
|
if bind.sock4 != -1 {
|
||||||
err2 = closeUnblock(bind.sock4)
|
err2 = closeUnblock(bind.sock4)
|
||||||
}
|
}
|
||||||
err3 = bind.netlinkCancel.Cancel()
|
|
||||||
|
|
||||||
if err1 != nil {
|
if err1 != nil {
|
||||||
return err1
|
return err1
|
||||||
}
|
}
|
||||||
if err2 != nil {
|
|
||||||
return err2
|
return err2
|
||||||
}
|
}
|
||||||
return err3
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
||||||
var end NativeEndpoint
|
var end NativeEndpoint
|
||||||
|
@ -278,10 +236,10 @@ func (bind *nativeBind) Send(buff []byte, end Endpoint) error {
|
||||||
func (end *NativeEndpoint) SrcIP() net.IP {
|
func (end *NativeEndpoint) SrcIP() net.IP {
|
||||||
if !end.isV6 {
|
if !end.isV6 {
|
||||||
return net.IPv4(
|
return net.IPv4(
|
||||||
end.src4().src[0],
|
end.src4().Src[0],
|
||||||
end.src4().src[1],
|
end.src4().Src[1],
|
||||||
end.src4().src[2],
|
end.src4().Src[2],
|
||||||
end.src4().src[3],
|
end.src4().Src[3],
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
return end.src6().src[:]
|
return end.src6().src[:]
|
||||||
|
@ -478,8 +436,8 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
|
||||||
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
|
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
|
||||||
},
|
},
|
||||||
unix.Inet4Pktinfo{
|
unix.Inet4Pktinfo{
|
||||||
Spec_dst: end.src4().src,
|
Spec_dst: end.src4().Src,
|
||||||
Ifindex: end.src4().ifindex,
|
Ifindex: end.src4().Ifindex,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -573,8 +531,8 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
||||||
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
|
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
|
||||||
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
|
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
|
||||||
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
|
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
|
||||||
end.src4().src = cmsg.pktinfo.Spec_dst
|
end.src4().Src = cmsg.pktinfo.Spec_dst
|
||||||
end.src4().ifindex = cmsg.pktinfo.Ifindex
|
end.src4().Ifindex = cmsg.pktinfo.Ifindex
|
||||||
}
|
}
|
||||||
|
|
||||||
return size, nil
|
return size, nil
|
||||||
|
@ -611,156 +569,3 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
||||||
|
|
||||||
return size, nil
|
return size, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *nativeBind) routineRouteListener(device *Device) {
|
|
||||||
type peerEndpointPtr struct {
|
|
||||||
peer *Peer
|
|
||||||
endpoint *Endpoint
|
|
||||||
}
|
|
||||||
var reqPeer map[uint32]peerEndpointPtr
|
|
||||||
var reqPeerLock sync.Mutex
|
|
||||||
|
|
||||||
defer unix.Close(bind.netlinkSock)
|
|
||||||
|
|
||||||
for msg := make([]byte, 1<<16); ; {
|
|
||||||
var err error
|
|
||||||
var msgn int
|
|
||||||
for {
|
|
||||||
msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
|
|
||||||
if err == nil || !rwcancel.RetryAfterError(err) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if !bind.netlinkCancel.ReadyRead() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
|
|
||||||
|
|
||||||
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
|
|
||||||
|
|
||||||
if uint(hdr.Len) > uint(len(remain)) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
switch hdr.Type {
|
|
||||||
case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
|
|
||||||
if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
|
|
||||||
if uint(len(remain)) < uint(hdr.Len) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
|
|
||||||
attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
|
|
||||||
for {
|
|
||||||
if uint(len(attr)) < uint(unix.SizeofRtAttr) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
|
|
||||||
if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
|
|
||||||
ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
|
|
||||||
reqPeerLock.Lock()
|
|
||||||
if reqPeer == nil {
|
|
||||||
reqPeerLock.Unlock()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
pePtr, ok := reqPeer[hdr.Seq]
|
|
||||||
reqPeerLock.Unlock()
|
|
||||||
if !ok {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
pePtr.peer.Lock()
|
|
||||||
if &pePtr.peer.endpoint != pePtr.endpoint {
|
|
||||||
pePtr.peer.Unlock()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx {
|
|
||||||
pePtr.peer.Unlock()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc()
|
|
||||||
pePtr.peer.Unlock()
|
|
||||||
}
|
|
||||||
attr = attr[attrhdr.Len:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
reqPeerLock.Lock()
|
|
||||||
reqPeer = make(map[uint32]peerEndpointPtr)
|
|
||||||
reqPeerLock.Unlock()
|
|
||||||
go func() {
|
|
||||||
device.peers.RLock()
|
|
||||||
i := uint32(1)
|
|
||||||
for _, peer := range device.peers.keyMap {
|
|
||||||
peer.RLock()
|
|
||||||
if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil {
|
|
||||||
peer.RUnlock()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 {
|
|
||||||
peer.RUnlock()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
nlmsg := struct {
|
|
||||||
hdr unix.NlMsghdr
|
|
||||||
msg unix.RtMsg
|
|
||||||
dsthdr unix.RtAttr
|
|
||||||
dst [4]byte
|
|
||||||
srchdr unix.RtAttr
|
|
||||||
src [4]byte
|
|
||||||
markhdr unix.RtAttr
|
|
||||||
mark uint32
|
|
||||||
}{
|
|
||||||
unix.NlMsghdr{
|
|
||||||
Type: uint16(unix.RTM_GETROUTE),
|
|
||||||
Flags: unix.NLM_F_REQUEST,
|
|
||||||
Seq: i,
|
|
||||||
},
|
|
||||||
unix.RtMsg{
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Dst_len: 32,
|
|
||||||
Src_len: 32,
|
|
||||||
},
|
|
||||||
unix.RtAttr{
|
|
||||||
Len: 8,
|
|
||||||
Type: unix.RTA_DST,
|
|
||||||
},
|
|
||||||
peer.endpoint.(*NativeEndpoint).dst4().Addr,
|
|
||||||
unix.RtAttr{
|
|
||||||
Len: 8,
|
|
||||||
Type: unix.RTA_SRC,
|
|
||||||
},
|
|
||||||
peer.endpoint.(*NativeEndpoint).src4().src,
|
|
||||||
unix.RtAttr{
|
|
||||||
Len: 8,
|
|
||||||
Type: unix.RTA_MARK,
|
|
||||||
},
|
|
||||||
uint32(bind.lastMark),
|
|
||||||
}
|
|
||||||
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
|
|
||||||
reqPeerLock.Lock()
|
|
||||||
reqPeer[i] = peerEndpointPtr{
|
|
||||||
peer: peer,
|
|
||||||
endpoint: &peer.endpoint,
|
|
||||||
}
|
|
||||||
reqPeerLock.Unlock()
|
|
||||||
peer.RUnlock()
|
|
||||||
i++
|
|
||||||
_, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
device.peers.RUnlock()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
remain = remain[hdr.Len:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -5,7 +5,7 @@
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package conn
|
||||||
|
|
||||||
func (bind *nativeBind) SetMark(mark uint32) error {
|
func (bind *nativeBind) SetMark(mark uint32) error {
|
||||||
return nil
|
return nil
|
|
@ -5,7 +5,7 @@
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package conn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"runtime"
|
"runtime"
|
|
@ -5,11 +5,15 @@
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import "errors"
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
|
)
|
||||||
|
|
||||||
type DummyDatagram struct {
|
type DummyDatagram struct {
|
||||||
msg []byte
|
msg []byte
|
||||||
endpoint Endpoint
|
endpoint conn.Endpoint
|
||||||
world bool // better type
|
world bool // better type
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,7 +29,7 @@ func (b *DummyBind) SetMark(v uint32) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) {
|
||||||
datagram, ok := <-b.in6
|
datagram, ok := <-b.in6
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, nil, errors.New("closed")
|
return 0, nil, errors.New("closed")
|
||||||
|
@ -34,7 +38,7 @@ func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
||||||
return len(datagram.msg), datagram.endpoint, nil
|
return len(datagram.msg), datagram.endpoint, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *DummyBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) {
|
||||||
datagram, ok := <-b.in4
|
datagram, ok := <-b.in4
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, nil, errors.New("closed")
|
return 0, nil, errors.New("closed")
|
||||||
|
@ -50,6 +54,6 @@ func (b *DummyBind) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *DummyBind) Send(buff []byte, end Endpoint) error {
|
func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
36
device/bindsocketshim.go
Normal file
36
device/bindsocketshim.go
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn.
|
||||||
|
func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
||||||
|
if device.net.bind == nil {
|
||||||
|
return errors.New("Bind is not yet initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
if iface, ok := device.net.bind.(conn.BindToInterface); ok {
|
||||||
|
return iface.BindToInterface4(interfaceIndex, blackhole)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn.
|
||||||
|
func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
||||||
|
if device.net.bind == nil {
|
||||||
|
return errors.New("Bind is not yet initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
if iface, ok := device.net.bind.(conn.BindToInterface); ok {
|
||||||
|
return iface.BindToInterface6(interfaceIndex, blackhole)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
187
device/conn.go
187
device/conn.go
|
@ -1,187 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"golang.org/x/net/ipv4"
|
|
||||||
"golang.org/x/net/ipv6"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
ConnRoutineNumber = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
|
|
||||||
*/
|
|
||||||
type Bind interface {
|
|
||||||
SetMark(value uint32) error
|
|
||||||
ReceiveIPv6(buff []byte) (int, Endpoint, error)
|
|
||||||
ReceiveIPv4(buff []byte) (int, Endpoint, error)
|
|
||||||
Send(buff []byte, end Endpoint) error
|
|
||||||
Close() error
|
|
||||||
}
|
|
||||||
|
|
||||||
/* An Endpoint maintains the source/destination caching for a peer
|
|
||||||
*
|
|
||||||
* dst : the remote address of a peer ("endpoint" in uapi terminology)
|
|
||||||
* src : the local address from which datagrams originate going to the peer
|
|
||||||
*/
|
|
||||||
type Endpoint interface {
|
|
||||||
ClearSrc() // clears the source address
|
|
||||||
SrcToString() string // returns the local source address (ip:port)
|
|
||||||
DstToString() string // returns the destination address (ip:port)
|
|
||||||
DstToBytes() []byte // used for mac2 cookie calculations
|
|
||||||
DstIP() net.IP
|
|
||||||
SrcIP() net.IP
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseEndpoint(s string) (*net.UDPAddr, error) {
|
|
||||||
// ensure that the host is an IP address
|
|
||||||
|
|
||||||
host, _, err := net.SplitHostPort(s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
|
|
||||||
// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
|
|
||||||
// trying to make sure with a small sanity test that this is a real IP address and
|
|
||||||
// not something that's likely to incur DNS lookups.
|
|
||||||
host = host[:i]
|
|
||||||
}
|
|
||||||
if ip := net.ParseIP(host); ip == nil {
|
|
||||||
return nil, errors.New("Failed to parse IP address: " + host)
|
|
||||||
}
|
|
||||||
|
|
||||||
// parse address and port
|
|
||||||
|
|
||||||
addr, err := net.ResolveUDPAddr("udp", s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ip4 := addr.IP.To4()
|
|
||||||
if ip4 != nil {
|
|
||||||
addr.IP = ip4
|
|
||||||
}
|
|
||||||
return addr, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func unsafeCloseBind(device *Device) error {
|
|
||||||
var err error
|
|
||||||
netc := &device.net
|
|
||||||
if netc.bind != nil {
|
|
||||||
err = netc.bind.Close()
|
|
||||||
netc.bind = nil
|
|
||||||
}
|
|
||||||
netc.stopping.Wait()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) BindSetMark(mark uint32) error {
|
|
||||||
|
|
||||||
device.net.Lock()
|
|
||||||
defer device.net.Unlock()
|
|
||||||
|
|
||||||
// check if modified
|
|
||||||
|
|
||||||
if device.net.fwmark == mark {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// update fwmark on existing bind
|
|
||||||
|
|
||||||
device.net.fwmark = mark
|
|
||||||
if device.isUp.Get() && device.net.bind != nil {
|
|
||||||
if err := device.net.bind.SetMark(mark); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// clear cached source addresses
|
|
||||||
|
|
||||||
device.peers.RLock()
|
|
||||||
for _, peer := range device.peers.keyMap {
|
|
||||||
peer.Lock()
|
|
||||||
defer peer.Unlock()
|
|
||||||
if peer.endpoint != nil {
|
|
||||||
peer.endpoint.ClearSrc()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
device.peers.RUnlock()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) BindUpdate() error {
|
|
||||||
|
|
||||||
device.net.Lock()
|
|
||||||
defer device.net.Unlock()
|
|
||||||
|
|
||||||
// close existing sockets
|
|
||||||
|
|
||||||
if err := unsafeCloseBind(device); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// open new sockets
|
|
||||||
|
|
||||||
if device.isUp.Get() {
|
|
||||||
|
|
||||||
// bind to new port
|
|
||||||
|
|
||||||
var err error
|
|
||||||
netc := &device.net
|
|
||||||
netc.bind, netc.port, err = CreateBind(netc.port, device)
|
|
||||||
if err != nil {
|
|
||||||
netc.bind = nil
|
|
||||||
netc.port = 0
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// set fwmark
|
|
||||||
|
|
||||||
if netc.fwmark != 0 {
|
|
||||||
err = netc.bind.SetMark(netc.fwmark)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// clear cached source addresses
|
|
||||||
|
|
||||||
device.peers.RLock()
|
|
||||||
for _, peer := range device.peers.keyMap {
|
|
||||||
peer.Lock()
|
|
||||||
defer peer.Unlock()
|
|
||||||
if peer.endpoint != nil {
|
|
||||||
peer.endpoint.ClearSrc()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
device.peers.RUnlock()
|
|
||||||
|
|
||||||
// start receiving routines
|
|
||||||
|
|
||||||
device.net.starting.Add(ConnRoutineNumber)
|
|
||||||
device.net.stopping.Add(ConnRoutineNumber)
|
|
||||||
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
|
|
||||||
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
|
|
||||||
device.net.starting.Wait()
|
|
||||||
|
|
||||||
device.log.Debug.Println("UDP bind has been updated")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) BindClose() error {
|
|
||||||
device.net.Lock()
|
|
||||||
err := unsafeCloseBind(device)
|
|
||||||
device.net.Unlock()
|
|
||||||
return err
|
|
||||||
}
|
|
142
device/device.go
142
device/device.go
|
@ -11,15 +11,14 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
"golang.zx2c4.com/wireguard/ratelimiter"
|
"golang.zx2c4.com/wireguard/ratelimiter"
|
||||||
|
"golang.zx2c4.com/wireguard/rwcancel"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
DeviceRoutineNumberPerCPU = 3
|
|
||||||
DeviceRoutineNumberAdditional = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
type Device struct {
|
type Device struct {
|
||||||
isUp AtomicBool // device is (going) up
|
isUp AtomicBool // device is (going) up
|
||||||
isClosed AtomicBool // device is closed? (acting as guard)
|
isClosed AtomicBool // device is closed? (acting as guard)
|
||||||
|
@ -39,7 +38,8 @@ type Device struct {
|
||||||
starting sync.WaitGroup
|
starting sync.WaitGroup
|
||||||
stopping sync.WaitGroup
|
stopping sync.WaitGroup
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
bind Bind // bind interface
|
bind conn.Bind // bind interface
|
||||||
|
netlinkCancel *rwcancel.RWCancel
|
||||||
port uint16 // listening port
|
port uint16 // listening port
|
||||||
fwmark uint32 // mark value (0 = disabled)
|
fwmark uint32 // mark value (0 = disabled)
|
||||||
}
|
}
|
||||||
|
@ -299,14 +299,16 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
|
||||||
cpus := runtime.NumCPU()
|
cpus := runtime.NumCPU()
|
||||||
device.state.starting.Wait()
|
device.state.starting.Wait()
|
||||||
device.state.stopping.Wait()
|
device.state.stopping.Wait()
|
||||||
device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
|
|
||||||
device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
|
|
||||||
for i := 0; i < cpus; i += 1 {
|
for i := 0; i < cpus; i += 1 {
|
||||||
|
device.state.starting.Add(3)
|
||||||
|
device.state.stopping.Add(3)
|
||||||
go device.RoutineEncryption()
|
go device.RoutineEncryption()
|
||||||
go device.RoutineDecryption()
|
go device.RoutineDecryption()
|
||||||
go device.RoutineHandshake()
|
go device.RoutineHandshake()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
device.state.starting.Add(2)
|
||||||
|
device.state.stopping.Add(2)
|
||||||
go device.RoutineReadFromTUN()
|
go device.RoutineReadFromTUN()
|
||||||
go device.RoutineTUNEventReader()
|
go device.RoutineTUNEventReader()
|
||||||
|
|
||||||
|
@ -413,3 +415,127 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
|
||||||
}
|
}
|
||||||
device.peers.RUnlock()
|
device.peers.RUnlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func unsafeCloseBind(device *Device) error {
|
||||||
|
var err error
|
||||||
|
netc := &device.net
|
||||||
|
if netc.netlinkCancel != nil {
|
||||||
|
netc.netlinkCancel.Cancel()
|
||||||
|
}
|
||||||
|
if netc.bind != nil {
|
||||||
|
err = netc.bind.Close()
|
||||||
|
netc.bind = nil
|
||||||
|
}
|
||||||
|
netc.stopping.Wait()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) BindSetMark(mark uint32) error {
|
||||||
|
|
||||||
|
device.net.Lock()
|
||||||
|
defer device.net.Unlock()
|
||||||
|
|
||||||
|
// check if modified
|
||||||
|
|
||||||
|
if device.net.fwmark == mark {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// update fwmark on existing bind
|
||||||
|
|
||||||
|
device.net.fwmark = mark
|
||||||
|
if device.isUp.Get() && device.net.bind != nil {
|
||||||
|
if err := device.net.bind.SetMark(mark); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear cached source addresses
|
||||||
|
|
||||||
|
device.peers.RLock()
|
||||||
|
for _, peer := range device.peers.keyMap {
|
||||||
|
peer.Lock()
|
||||||
|
defer peer.Unlock()
|
||||||
|
if peer.endpoint != nil {
|
||||||
|
peer.endpoint.ClearSrc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
device.peers.RUnlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) BindUpdate() error {
|
||||||
|
|
||||||
|
device.net.Lock()
|
||||||
|
defer device.net.Unlock()
|
||||||
|
|
||||||
|
// close existing sockets
|
||||||
|
|
||||||
|
if err := unsafeCloseBind(device); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// open new sockets
|
||||||
|
|
||||||
|
if device.isUp.Get() {
|
||||||
|
|
||||||
|
// bind to new port
|
||||||
|
|
||||||
|
var err error
|
||||||
|
netc := &device.net
|
||||||
|
netc.bind, netc.port, err = conn.CreateBind(netc.port)
|
||||||
|
if err != nil {
|
||||||
|
netc.bind = nil
|
||||||
|
netc.port = 0
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
|
||||||
|
if err != nil {
|
||||||
|
netc.bind.Close()
|
||||||
|
netc.bind = nil
|
||||||
|
netc.port = 0
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// set fwmark
|
||||||
|
|
||||||
|
if netc.fwmark != 0 {
|
||||||
|
err = netc.bind.SetMark(netc.fwmark)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear cached source addresses
|
||||||
|
|
||||||
|
device.peers.RLock()
|
||||||
|
for _, peer := range device.peers.keyMap {
|
||||||
|
peer.Lock()
|
||||||
|
defer peer.Unlock()
|
||||||
|
if peer.endpoint != nil {
|
||||||
|
peer.endpoint.ClearSrc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
device.peers.RUnlock()
|
||||||
|
|
||||||
|
// start receiving routines
|
||||||
|
|
||||||
|
device.net.starting.Add(2)
|
||||||
|
device.net.stopping.Add(2)
|
||||||
|
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
|
||||||
|
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
|
||||||
|
device.net.starting.Wait()
|
||||||
|
|
||||||
|
device.log.Debug.Println("UDP bind has been updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) BindClose() error {
|
||||||
|
device.net.Lock()
|
||||||
|
err := unsafeCloseBind(device)
|
||||||
|
device.net.Unlock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
|
@ -12,6 +12,8 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -24,7 +26,7 @@ type Peer struct {
|
||||||
keypairs Keypairs
|
keypairs Keypairs
|
||||||
handshake Handshake
|
handshake Handshake
|
||||||
device *Device
|
device *Device
|
||||||
endpoint Endpoint
|
endpoint conn.Endpoint
|
||||||
persistentKeepaliveInterval uint16
|
persistentKeepaliveInterval uint16
|
||||||
|
|
||||||
// These fields are accessed with atomic operations, which must be
|
// These fields are accessed with atomic operations, which must be
|
||||||
|
@ -290,7 +292,7 @@ func (peer *Peer) Stop() {
|
||||||
|
|
||||||
var RoamingDisabled bool
|
var RoamingDisabled bool
|
||||||
|
|
||||||
func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) {
|
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
|
||||||
if RoamingDisabled {
|
if RoamingDisabled {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,12 +17,13 @@ import (
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
)
|
)
|
||||||
|
|
||||||
type QueueHandshakeElement struct {
|
type QueueHandshakeElement struct {
|
||||||
msgType uint32
|
msgType uint32
|
||||||
packet []byte
|
packet []byte
|
||||||
endpoint Endpoint
|
endpoint conn.Endpoint
|
||||||
buffer *[MaxMessageSize]byte
|
buffer *[MaxMessageSize]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -33,7 +34,7 @@ type QueueInboundElement struct {
|
||||||
packet []byte
|
packet []byte
|
||||||
counter uint64
|
counter uint64
|
||||||
keypair *Keypair
|
keypair *Keypair
|
||||||
endpoint Endpoint
|
endpoint conn.Endpoint
|
||||||
}
|
}
|
||||||
|
|
||||||
func (elem *QueueInboundElement) Drop() {
|
func (elem *QueueInboundElement) Drop() {
|
||||||
|
@ -90,7 +91,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
|
||||||
* Every time the bind is updated a new routine is started for
|
* Every time the bind is updated a new routine is started for
|
||||||
* IPv4 and IPv6 (separately)
|
* IPv4 and IPv6 (separately)
|
||||||
*/
|
*/
|
||||||
func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
|
func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
|
||||||
|
|
||||||
logDebug := device.log.Debug
|
logDebug := device.log.Debug
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -108,7 +109,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
size int
|
size int
|
||||||
endpoint Endpoint
|
endpoint conn.Endpoint
|
||||||
)
|
)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
|
12
device/sticky_default.go
Normal file
12
device/sticky_default.go
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
// +build !linux
|
||||||
|
|
||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
|
"golang.zx2c4.com/wireguard/rwcancel"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
215
device/sticky_linux.go
Normal file
215
device/sticky_linux.go
Normal file
|
@ -0,0 +1,215 @@
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* This implements userspace semantics of "sticky sockets", modeled after
|
||||||
|
* WireGuard's kernelspace implementation. This is more or less a straight port
|
||||||
|
* of the sticky-sockets.c example code:
|
||||||
|
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
|
||||||
|
*
|
||||||
|
* Currently there is no way to achieve this within the net package:
|
||||||
|
* See e.g. https://github.com/golang/go/issues/17930
|
||||||
|
* So this code is remains platform dependent.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
|
"golang.zx2c4.com/wireguard/rwcancel"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
||||||
|
netlinkSock, err := createNetlinkRouteSocket()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock)
|
||||||
|
if err != nil {
|
||||||
|
unix.Close(netlinkSock)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
go device.routineRouteListener(bind, netlinkSock, netlinkCancel)
|
||||||
|
|
||||||
|
return netlinkCancel, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
|
||||||
|
type peerEndpointPtr struct {
|
||||||
|
peer *Peer
|
||||||
|
endpoint *conn.Endpoint
|
||||||
|
}
|
||||||
|
var reqPeer map[uint32]peerEndpointPtr
|
||||||
|
var reqPeerLock sync.Mutex
|
||||||
|
|
||||||
|
defer unix.Close(netlinkSock)
|
||||||
|
|
||||||
|
for msg := make([]byte, 1<<16); ; {
|
||||||
|
var err error
|
||||||
|
var msgn int
|
||||||
|
for {
|
||||||
|
msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0)
|
||||||
|
if err == nil || !rwcancel.RetryAfterError(err) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if !netlinkCancel.ReadyRead() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
|
||||||
|
|
||||||
|
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
|
||||||
|
|
||||||
|
if uint(hdr.Len) > uint(len(remain)) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
switch hdr.Type {
|
||||||
|
case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
|
||||||
|
if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
|
||||||
|
if uint(len(remain)) < uint(hdr.Len) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
|
||||||
|
attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
|
||||||
|
for {
|
||||||
|
if uint(len(attr)) < uint(unix.SizeofRtAttr) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
|
||||||
|
if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
|
||||||
|
ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
|
||||||
|
reqPeerLock.Lock()
|
||||||
|
if reqPeer == nil {
|
||||||
|
reqPeerLock.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
pePtr, ok := reqPeer[hdr.Seq]
|
||||||
|
reqPeerLock.Unlock()
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
pePtr.peer.Lock()
|
||||||
|
if &pePtr.peer.endpoint != pePtr.endpoint {
|
||||||
|
pePtr.peer.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if uint32(pePtr.peer.endpoint.(*conn.NativeEndpoint).Src4().Ifindex) == ifidx {
|
||||||
|
pePtr.peer.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
pePtr.peer.endpoint.(*conn.NativeEndpoint).ClearSrc()
|
||||||
|
pePtr.peer.Unlock()
|
||||||
|
}
|
||||||
|
attr = attr[attrhdr.Len:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
reqPeerLock.Lock()
|
||||||
|
reqPeer = make(map[uint32]peerEndpointPtr)
|
||||||
|
reqPeerLock.Unlock()
|
||||||
|
go func() {
|
||||||
|
device.peers.RLock()
|
||||||
|
i := uint32(1)
|
||||||
|
for _, peer := range device.peers.keyMap {
|
||||||
|
peer.RLock()
|
||||||
|
if peer.endpoint == nil {
|
||||||
|
peer.RUnlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
nativeEP, _ := peer.endpoint.(*conn.NativeEndpoint)
|
||||||
|
if nativeEP == nil {
|
||||||
|
peer.RUnlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 {
|
||||||
|
peer.RUnlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
nlmsg := struct {
|
||||||
|
hdr unix.NlMsghdr
|
||||||
|
msg unix.RtMsg
|
||||||
|
dsthdr unix.RtAttr
|
||||||
|
dst [4]byte
|
||||||
|
srchdr unix.RtAttr
|
||||||
|
src [4]byte
|
||||||
|
markhdr unix.RtAttr
|
||||||
|
mark uint32
|
||||||
|
}{
|
||||||
|
unix.NlMsghdr{
|
||||||
|
Type: uint16(unix.RTM_GETROUTE),
|
||||||
|
Flags: unix.NLM_F_REQUEST,
|
||||||
|
Seq: i,
|
||||||
|
},
|
||||||
|
unix.RtMsg{
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Dst_len: 32,
|
||||||
|
Src_len: 32,
|
||||||
|
},
|
||||||
|
unix.RtAttr{
|
||||||
|
Len: 8,
|
||||||
|
Type: unix.RTA_DST,
|
||||||
|
},
|
||||||
|
nativeEP.Dst4().Addr,
|
||||||
|
unix.RtAttr{
|
||||||
|
Len: 8,
|
||||||
|
Type: unix.RTA_SRC,
|
||||||
|
},
|
||||||
|
nativeEP.Src4().Src,
|
||||||
|
unix.RtAttr{
|
||||||
|
Len: 8,
|
||||||
|
Type: unix.RTA_MARK,
|
||||||
|
},
|
||||||
|
uint32(bind.LastMark()),
|
||||||
|
}
|
||||||
|
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
|
||||||
|
reqPeerLock.Lock()
|
||||||
|
reqPeer[i] = peerEndpointPtr{
|
||||||
|
peer: peer,
|
||||||
|
endpoint: &peer.endpoint,
|
||||||
|
}
|
||||||
|
reqPeerLock.Unlock()
|
||||||
|
peer.RUnlock()
|
||||||
|
i++
|
||||||
|
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
device.peers.RUnlock()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
remain = remain[hdr.Len:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createNetlinkRouteSocket() (int, error) {
|
||||||
|
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
saddr := &unix.SockaddrNetlink{
|
||||||
|
Family: unix.AF_NETLINK,
|
||||||
|
Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
|
||||||
|
}
|
||||||
|
err = unix.Bind(sock, saddr)
|
||||||
|
if err != nil {
|
||||||
|
unix.Close(sock)
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
return sock, nil
|
||||||
|
}
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
"golang.zx2c4.com/wireguard/ipc"
|
"golang.zx2c4.com/wireguard/ipc"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -306,7 +307,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
|
||||||
err := func() error {
|
err := func() error {
|
||||||
peer.Lock()
|
peer.Lock()
|
||||||
defer peer.Unlock()
|
defer peer.Unlock()
|
||||||
endpoint, err := CreateEndpoint(value)
|
endpoint, err := conn.CreateEndpoint(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue