all: make conn.Bind.Open return a slice of receive functions
Instead of hard-coding exactly two sources from which to receive packets (an IPv4 source and an IPv6 source), allow the conn.Bind to specify a set of sources. Beneficial consequences: * If there's no IPv6 support on a system, conn.Bind.Open can choose not to return a receive function for it, which is simpler than tracking that state in the bind. This simplification removes existing data races from both conn.StdNetBind and bindtest.ChannelBind. * If there are more than two sources on a system, the conn.Bind no longer needs to add a separate muxing layer. Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
This commit is contained in:
parent
8ed83e0427
commit
10533c3e73
|
@ -55,10 +55,11 @@ func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
|
||||||
|
|
||||||
// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
|
// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
|
||||||
type LinuxSocketBind struct {
|
type LinuxSocketBind struct {
|
||||||
|
// mu guards sock4 and sock6 and the associated fds.
|
||||||
|
// As long as someone holds mu (read or write), the associated fds are valid.
|
||||||
|
mu sync.RWMutex
|
||||||
sock4 int
|
sock4 int
|
||||||
sock6 int
|
sock6 int
|
||||||
lastMark uint32
|
|
||||||
closing sync.RWMutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
|
func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
|
||||||
|
@ -102,54 +103,67 @@ func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
|
||||||
return nil, errors.New("invalid IP address")
|
return nil, errors.New("invalid IP address")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *LinuxSocketBind) Open(port uint16) (uint16, error) {
|
func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) {
|
||||||
|
bind.mu.Lock()
|
||||||
|
defer bind.mu.Unlock()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
var newPort uint16
|
var newPort uint16
|
||||||
var tries int
|
var tries int
|
||||||
|
|
||||||
if bind.sock4 != -1 || bind.sock6 != -1 {
|
if bind.sock4 != -1 || bind.sock6 != -1 {
|
||||||
return 0, ErrBindAlreadyOpen
|
return nil, 0, ErrBindAlreadyOpen
|
||||||
}
|
}
|
||||||
|
|
||||||
originalPort := port
|
originalPort := port
|
||||||
|
|
||||||
again:
|
again:
|
||||||
port = originalPort
|
port = originalPort
|
||||||
|
var sock4, sock6 int
|
||||||
// Attempt ipv6 bind, update port if successful.
|
// Attempt ipv6 bind, update port if successful.
|
||||||
bind.sock6, newPort, err = create6(port)
|
sock6, newPort, err = create6(port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != syscall.EAFNOSUPPORT {
|
if !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
return 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
port = newPort
|
port = newPort
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attempt ipv4 bind, update port if successful.
|
// Attempt ipv4 bind, update port if successful.
|
||||||
bind.sock4, newPort, err = create4(port)
|
sock4, newPort, err = create4(port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if originalPort == 0 && err == syscall.EADDRINUSE && tries < 100 {
|
if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
|
||||||
unix.Close(bind.sock6)
|
unix.Close(sock6)
|
||||||
tries++
|
tries++
|
||||||
goto again
|
goto again
|
||||||
}
|
}
|
||||||
if err != syscall.EAFNOSUPPORT {
|
if !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
unix.Close(bind.sock6)
|
unix.Close(sock6)
|
||||||
return 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
port = newPort
|
port = newPort
|
||||||
}
|
}
|
||||||
|
|
||||||
if bind.sock4 == -1 && bind.sock6 == -1 {
|
var fns []ReceiveFunc
|
||||||
return 0, syscall.EAFNOSUPPORT
|
if sock4 != -1 {
|
||||||
|
fns = append(fns, makeReceiveIPv4(sock4))
|
||||||
|
bind.sock4 = sock4
|
||||||
}
|
}
|
||||||
return port, nil
|
if sock6 != -1 {
|
||||||
|
fns = append(fns, makeReceiveIPv6(sock6))
|
||||||
|
bind.sock6 = sock6
|
||||||
|
}
|
||||||
|
if len(fns) == 0 {
|
||||||
|
return nil, 0, syscall.EAFNOSUPPORT
|
||||||
|
}
|
||||||
|
return fns, port, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *LinuxSocketBind) SetMark(value uint32) error {
|
func (bind *LinuxSocketBind) SetMark(value uint32) error {
|
||||||
bind.closing.RLock()
|
bind.mu.RLock()
|
||||||
defer bind.closing.RUnlock()
|
defer bind.mu.RUnlock()
|
||||||
|
|
||||||
if bind.sock6 != -1 {
|
if bind.sock6 != -1 {
|
||||||
err := unix.SetsockoptInt(
|
err := unix.SetsockoptInt(
|
||||||
|
@ -177,21 +191,24 @@ func (bind *LinuxSocketBind) SetMark(value uint32) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bind.lastMark = value
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *LinuxSocketBind) Close() error {
|
func (bind *LinuxSocketBind) Close() error {
|
||||||
var err1, err2 error
|
// Take a readlock to shut down the sockets...
|
||||||
bind.closing.RLock()
|
bind.mu.RLock()
|
||||||
if bind.sock6 != -1 {
|
if bind.sock6 != -1 {
|
||||||
unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
|
unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
|
||||||
}
|
}
|
||||||
if bind.sock4 != -1 {
|
if bind.sock4 != -1 {
|
||||||
unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
|
unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
|
||||||
}
|
}
|
||||||
bind.closing.RUnlock()
|
bind.mu.RUnlock()
|
||||||
bind.closing.Lock()
|
// ...and a write lock to close the fd.
|
||||||
|
// This ensures that no one else is using the fd.
|
||||||
|
bind.mu.Lock()
|
||||||
|
defer bind.mu.Unlock()
|
||||||
|
var err1, err2 error
|
||||||
if bind.sock6 != -1 {
|
if bind.sock6 != -1 {
|
||||||
err1 = unix.Close(bind.sock6)
|
err1 = unix.Close(bind.sock6)
|
||||||
bind.sock6 = -1
|
bind.sock6 = -1
|
||||||
|
@ -200,7 +217,6 @@ func (bind *LinuxSocketBind) Close() error {
|
||||||
err2 = unix.Close(bind.sock4)
|
err2 = unix.Close(bind.sock4)
|
||||||
bind.sock4 = -1
|
bind.sock4 = -1
|
||||||
}
|
}
|
||||||
bind.closing.Unlock()
|
|
||||||
|
|
||||||
if err1 != nil {
|
if err1 != nil {
|
||||||
return err1
|
return err1
|
||||||
|
@ -208,46 +224,29 @@ func (bind *LinuxSocketBind) Close() error {
|
||||||
return err2
|
return err2
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *LinuxSocketBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
func makeReceiveIPv6(sock int) ReceiveFunc {
|
||||||
bind.closing.RLock()
|
return func(buff []byte) (int, Endpoint, error) {
|
||||||
defer bind.closing.RUnlock()
|
|
||||||
|
|
||||||
var end LinuxSocketEndpoint
|
var end LinuxSocketEndpoint
|
||||||
if bind.sock6 == -1 {
|
n, err := receive6(sock, buff, &end)
|
||||||
return 0, nil, net.ErrClosed
|
|
||||||
}
|
|
||||||
n, err := receive6(
|
|
||||||
bind.sock6,
|
|
||||||
buff,
|
|
||||||
&end,
|
|
||||||
)
|
|
||||||
return n, &end, err
|
return n, &end, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *LinuxSocketBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
|
||||||
bind.closing.RLock()
|
|
||||||
defer bind.closing.RUnlock()
|
|
||||||
|
|
||||||
var end LinuxSocketEndpoint
|
|
||||||
if bind.sock4 == -1 {
|
|
||||||
return 0, nil, net.ErrClosed
|
|
||||||
}
|
}
|
||||||
n, err := receive4(
|
|
||||||
bind.sock4,
|
func makeReceiveIPv4(sock int) ReceiveFunc {
|
||||||
buff,
|
return func(buff []byte) (int, Endpoint, error) {
|
||||||
&end,
|
var end LinuxSocketEndpoint
|
||||||
)
|
n, err := receive4(sock, buff, &end)
|
||||||
return n, &end, err
|
return n, &end, err
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
|
func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
|
||||||
bind.closing.RLock()
|
|
||||||
defer bind.closing.RUnlock()
|
|
||||||
|
|
||||||
nend, ok := end.(*LinuxSocketEndpoint)
|
nend, ok := end.(*LinuxSocketEndpoint)
|
||||||
if !ok {
|
if !ok {
|
||||||
return ErrWrongEndpointType
|
return ErrWrongEndpointType
|
||||||
}
|
}
|
||||||
|
bind.mu.RLock()
|
||||||
|
defer bind.mu.RUnlock()
|
||||||
if !nend.isV6 {
|
if !nend.isV6 {
|
||||||
if bind.sock4 == -1 {
|
if bind.sock4 == -1 {
|
||||||
return net.ErrClosed
|
return net.ErrClosed
|
||||||
|
|
|
@ -8,6 +8,7 @@ package conn
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -16,6 +17,7 @@ import (
|
||||||
// It uses the Go's net package to implement networking.
|
// It uses the Go's net package to implement networking.
|
||||||
// See LinuxSocketBind for a proper implementation on the Linux platform.
|
// See LinuxSocketBind for a proper implementation on the Linux platform.
|
||||||
type StdNetBind struct {
|
type StdNetBind struct {
|
||||||
|
mu sync.Mutex // protects following fields
|
||||||
ipv4 *net.UDPConn
|
ipv4 *net.UDPConn
|
||||||
ipv6 *net.UDPConn
|
ipv6 *net.UDPConn
|
||||||
blackhole4 bool
|
blackhole4 bool
|
||||||
|
@ -81,12 +83,15 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
||||||
return conn, uaddr.Port, nil
|
return conn, uaddr.Port, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *StdNetBind) Open(uport uint16) (uint16, error) {
|
func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
|
||||||
|
bind.mu.Lock()
|
||||||
|
defer bind.mu.Unlock()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
var tries int
|
var tries int
|
||||||
|
|
||||||
if bind.ipv4 != nil || bind.ipv6 != nil {
|
if bind.ipv4 != nil || bind.ipv6 != nil {
|
||||||
return 0, ErrBindAlreadyOpen
|
return nil, 0, ErrBindAlreadyOpen
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attempt to open ipv4 and ipv6 listeners on the same port.
|
// Attempt to open ipv4 and ipv6 listeners on the same port.
|
||||||
|
@ -97,7 +102,7 @@ again:
|
||||||
|
|
||||||
ipv4, port, err = listenNet("udp4", port)
|
ipv4, port, err = listenNet("udp4", port)
|
||||||
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
return 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Listen on the same port as we're using for ipv4.
|
// Listen on the same port as we're using for ipv4.
|
||||||
|
@ -109,17 +114,27 @@ again:
|
||||||
}
|
}
|
||||||
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
ipv4.Close()
|
ipv4.Close()
|
||||||
return 0, err
|
return nil, 0, err
|
||||||
}
|
|
||||||
if ipv4 == nil && ipv6 == nil {
|
|
||||||
return 0, syscall.EAFNOSUPPORT
|
|
||||||
}
|
}
|
||||||
|
var fns []ReceiveFunc
|
||||||
|
if ipv4 != nil {
|
||||||
|
fns = append(fns, makeReceiveFunc(ipv4, true))
|
||||||
bind.ipv4 = ipv4
|
bind.ipv4 = ipv4
|
||||||
|
}
|
||||||
|
if ipv6 != nil {
|
||||||
|
fns = append(fns, makeReceiveFunc(ipv6, false))
|
||||||
bind.ipv6 = ipv6
|
bind.ipv6 = ipv6
|
||||||
return uint16(port), nil
|
}
|
||||||
|
if len(fns) == 0 {
|
||||||
|
return nil, 0, syscall.EAFNOSUPPORT
|
||||||
|
}
|
||||||
|
return fns, uint16(port), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *StdNetBind) Close() error {
|
func (bind *StdNetBind) Close() error {
|
||||||
|
bind.mu.Lock()
|
||||||
|
defer bind.mu.Unlock()
|
||||||
|
|
||||||
var err1, err2 error
|
var err1, err2 error
|
||||||
if bind.ipv4 != nil {
|
if bind.ipv4 != nil {
|
||||||
err1 = bind.ipv4.Close()
|
err1 = bind.ipv4.Close()
|
||||||
|
@ -137,23 +152,14 @@ func (bind *StdNetBind) Close() error {
|
||||||
return err2
|
return err2
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *StdNetBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
func makeReceiveFunc(conn *net.UDPConn, isIPv4 bool) ReceiveFunc {
|
||||||
if bind.ipv4 == nil {
|
return func(buff []byte) (int, Endpoint, error) {
|
||||||
return 0, nil, syscall.EAFNOSUPPORT
|
n, endpoint, err := conn.ReadFromUDP(buff)
|
||||||
}
|
if isIPv4 && endpoint != nil {
|
||||||
n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
|
|
||||||
if endpoint != nil {
|
|
||||||
endpoint.IP = endpoint.IP.To4()
|
endpoint.IP = endpoint.IP.To4()
|
||||||
}
|
}
|
||||||
return n, (*StdNetEndpoint)(endpoint), err
|
return n, (*StdNetEndpoint)(endpoint), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *StdNetBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
|
||||||
if bind.ipv6 == nil {
|
|
||||||
return 0, nil, syscall.EAFNOSUPPORT
|
|
||||||
}
|
|
||||||
n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
|
|
||||||
return n, (*StdNetEndpoint)(endpoint), err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
|
func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
|
||||||
|
@ -162,15 +168,16 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
|
||||||
if !ok {
|
if !ok {
|
||||||
return ErrWrongEndpointType
|
return ErrWrongEndpointType
|
||||||
}
|
}
|
||||||
var conn *net.UDPConn
|
|
||||||
var blackhole bool
|
bind.mu.Lock()
|
||||||
if nend.IP.To4() != nil {
|
blackhole := bind.blackhole4
|
||||||
blackhole = bind.blackhole4
|
conn := bind.ipv4
|
||||||
conn = bind.ipv4
|
if nend.IP.To4() == nil {
|
||||||
} else {
|
|
||||||
blackhole = bind.blackhole6
|
blackhole = bind.blackhole6
|
||||||
conn = bind.ipv6
|
conn = bind.ipv6
|
||||||
}
|
}
|
||||||
|
bind.mu.Unlock()
|
||||||
|
|
||||||
if blackhole {
|
if blackhole {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -266,7 +266,7 @@ func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sock
|
||||||
return sa, nil
|
return sa, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *WinRingBind) Open(port uint16) (selectedPort uint16, err error) {
|
func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) {
|
||||||
bind.mu.Lock()
|
bind.mu.Lock()
|
||||||
defer bind.mu.Unlock()
|
defer bind.mu.Unlock()
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -275,30 +275,30 @@ func (bind *WinRingBind) Open(port uint16) (selectedPort uint16, err error) {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
if atomic.LoadUint32(&bind.isOpen) != 0 {
|
if atomic.LoadUint32(&bind.isOpen) != 0 {
|
||||||
return 0, ErrBindAlreadyOpen
|
return nil, 0, ErrBindAlreadyOpen
|
||||||
}
|
}
|
||||||
var sa windows.Sockaddr
|
var sa windows.Sockaddr
|
||||||
sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
|
sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
|
sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
|
selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
|
||||||
for i := 0; i < packetsPerRing; i++ {
|
for i := 0; i < packetsPerRing; i++ {
|
||||||
err = bind.v4.InsertReceiveRequest()
|
err = bind.v4.InsertReceiveRequest()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
err = bind.v6.InsertReceiveRequest()
|
err = bind.v6.InsertReceiveRequest()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
atomic.StoreUint32(&bind.isOpen, 1)
|
atomic.StoreUint32(&bind.isOpen, 1)
|
||||||
return
|
return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *WinRingBind) Close() error {
|
func (bind *WinRingBind) Close() error {
|
||||||
|
@ -395,13 +395,13 @@ func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, e
|
||||||
return n, &ep, nil
|
return n, &ep, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *WinRingBind) ReceiveIPv4(buf []byte) (int, Endpoint, error) {
|
func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
|
||||||
bind.mu.RLock()
|
bind.mu.RLock()
|
||||||
defer bind.mu.RUnlock()
|
defer bind.mu.RUnlock()
|
||||||
return bind.v4.Receive(buf, &bind.isOpen)
|
return bind.v4.Receive(buf, &bind.isOpen)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *WinRingBind) ReceiveIPv6(buf []byte) (int, Endpoint, error) {
|
func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
|
||||||
bind.mu.RLock()
|
bind.mu.RLock()
|
||||||
defer bind.mu.RUnlock()
|
defer bind.mu.RUnlock()
|
||||||
return bind.v6.Receive(buf, &bind.isOpen)
|
return bind.v6.Receive(buf, &bind.isOpen)
|
||||||
|
@ -482,6 +482,8 @@ func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
||||||
|
bind.mu.Lock()
|
||||||
|
defer bind.mu.Unlock()
|
||||||
sysconn, err := bind.ipv4.SyscallConn()
|
sysconn, err := bind.ipv4.SyscallConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -500,6 +502,8 @@ func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
||||||
|
bind.mu.Lock()
|
||||||
|
defer bind.mu.Unlock()
|
||||||
sysconn, err := bind.ipv6.SyscallConn()
|
sysconn, err := bind.ipv6.SyscallConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -65,12 +65,14 @@ func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) }
|
||||||
|
|
||||||
func (c ChannelEndpoint) SrcIP() net.IP { return nil }
|
func (c ChannelEndpoint) SrcIP() net.IP { return nil }
|
||||||
|
|
||||||
func (c *ChannelBind) Open(port uint16) (actualPort uint16, err error) {
|
func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
|
||||||
c.closeSignal = make(chan bool)
|
c.closeSignal = make(chan bool)
|
||||||
|
fns = append(fns, c.makeReceiveFunc(*c.rx4))
|
||||||
|
fns = append(fns, c.makeReceiveFunc(*c.rx6))
|
||||||
if rand.Uint32()&1 == 0 {
|
if rand.Uint32()&1 == 0 {
|
||||||
return uint16(c.source4), nil
|
return fns, uint16(c.source4), nil
|
||||||
} else {
|
} else {
|
||||||
return uint16(c.source6), nil
|
return fns, uint16(c.source6), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,22 +89,15 @@ func (c *ChannelBind) Close() error {
|
||||||
|
|
||||||
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
|
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
|
||||||
|
|
||||||
func (c *ChannelBind) ReceiveIPv6(b []byte) (n int, ep conn.Endpoint, err error) {
|
func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
|
||||||
|
return func(b []byte) (n int, ep conn.Endpoint, err error) {
|
||||||
select {
|
select {
|
||||||
case <-c.closeSignal:
|
case <-c.closeSignal:
|
||||||
return 0, nil, net.ErrClosed
|
return 0, nil, net.ErrClosed
|
||||||
case rx := <-*c.rx6:
|
case rx := <-ch:
|
||||||
return copy(b, rx), c.target6, nil
|
return copy(b, rx), c.target6, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ChannelBind) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) {
|
|
||||||
select {
|
|
||||||
case <-c.closeSignal:
|
|
||||||
return 0, nil, net.ErrClosed
|
|
||||||
case rx := <-*c.rx4:
|
|
||||||
return copy(b, rx), c.target4, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
|
func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
|
||||||
|
|
17
conn/conn.go
17
conn/conn.go
|
@ -12,6 +12,11 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// A ReceiveFunc receives a single inbound packet from the network.
|
||||||
|
// It writes the data into b. n is the length of the packet.
|
||||||
|
// ep is the remote endpoint.
|
||||||
|
type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error)
|
||||||
|
|
||||||
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
|
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
|
||||||
//
|
//
|
||||||
// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
|
// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
|
||||||
|
@ -19,23 +24,17 @@ import (
|
||||||
type Bind interface {
|
type Bind interface {
|
||||||
// Open puts the Bind into a listening state on a given port and reports the actual
|
// Open puts the Bind into a listening state on a given port and reports the actual
|
||||||
// port that it bound to. Passing zero results in a random selection.
|
// port that it bound to. Passing zero results in a random selection.
|
||||||
Open(port uint16) (actualPort uint16, err error)
|
// fns is the set of functions that will be called to receive packets.
|
||||||
|
Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error)
|
||||||
|
|
||||||
// Close closes the Bind listener.
|
// Close closes the Bind listener.
|
||||||
|
// All fns returned by Open must return net.ErrClosed after a call to Close.
|
||||||
Close() error
|
Close() error
|
||||||
|
|
||||||
// SetMark sets the mark for each packet sent through this Bind.
|
// SetMark sets the mark for each packet sent through this Bind.
|
||||||
// This mark is passed to the kernel as the socket option SO_MARK.
|
// This mark is passed to the kernel as the socket option SO_MARK.
|
||||||
SetMark(mark uint32) error
|
SetMark(mark uint32) error
|
||||||
|
|
||||||
// ReceiveIPv6 reads an IPv6 UDP packet into b. It reports the number of bytes read,
|
|
||||||
// n, the packet source address ep, and any error.
|
|
||||||
ReceiveIPv6(b []byte) (n int, ep Endpoint, err error)
|
|
||||||
|
|
||||||
// ReceiveIPv4 reads an IPv4 UDP packet into b. It reports the number of bytes read,
|
|
||||||
// n, the packet source address ep, and any error.
|
|
||||||
ReceiveIPv4(b []byte) (n int, ep Endpoint, err error)
|
|
||||||
|
|
||||||
// Send writes a packet b to address ep.
|
// Send writes a packet b to address ep.
|
||||||
Send(b []byte, ep Endpoint) error
|
Send(b []byte, ep Endpoint) error
|
||||||
|
|
||||||
|
|
|
@ -11,9 +11,6 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/ipv4"
|
|
||||||
"golang.org/x/net/ipv6"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
"golang.zx2c4.com/wireguard/ratelimiter"
|
"golang.zx2c4.com/wireguard/ratelimiter"
|
||||||
"golang.zx2c4.com/wireguard/rwcancel"
|
"golang.zx2c4.com/wireguard/rwcancel"
|
||||||
|
@ -468,8 +465,9 @@ func (device *Device) BindUpdate() error {
|
||||||
|
|
||||||
// bind to new port
|
// bind to new port
|
||||||
var err error
|
var err error
|
||||||
|
var recvFns []conn.ReceiveFunc
|
||||||
netc := &device.net
|
netc := &device.net
|
||||||
netc.port, err = netc.bind.Open(netc.port)
|
recvFns, netc.port, err = netc.bind.Open(netc.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
netc.port = 0
|
netc.port = 0
|
||||||
return err
|
return err
|
||||||
|
@ -501,11 +499,12 @@ func (device *Device) BindUpdate() error {
|
||||||
device.peers.RUnlock()
|
device.peers.RUnlock()
|
||||||
|
|
||||||
// start receiving routines
|
// start receiving routines
|
||||||
device.net.stopping.Add(2)
|
device.net.stopping.Add(len(recvFns))
|
||||||
device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
|
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
|
||||||
device.queue.handshake.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
|
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
|
||||||
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
|
for _, fn := range recvFns {
|
||||||
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
|
go device.RoutineReceiveIncoming(fn)
|
||||||
|
}
|
||||||
|
|
||||||
device.log.Verbosef("UDP bind has been updated")
|
device.log.Verbosef("UDP bind has been updated")
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -68,15 +68,15 @@ func (peer *Peer) keepKeyFreshReceiving() {
|
||||||
* Every time the bind is updated a new routine is started for
|
* Every time the bind is updated a new routine is started for
|
||||||
* IPv4 and IPv6 (separately)
|
* IPv4 and IPv6 (separately)
|
||||||
*/
|
*/
|
||||||
func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
|
func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
|
||||||
defer func() {
|
defer func() {
|
||||||
device.log.Verbosef("Routine: receive incoming IPv%d - stopped", IP)
|
device.log.Verbosef("Routine: receive incoming %p - stopped", recv)
|
||||||
device.queue.decryption.wg.Done()
|
device.queue.decryption.wg.Done()
|
||||||
device.queue.handshake.wg.Done()
|
device.queue.handshake.wg.Done()
|
||||||
device.net.stopping.Done()
|
device.net.stopping.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
device.log.Verbosef("Routine: receive incoming IPv%d - started", IP)
|
device.log.Verbosef("Routine: receive incoming %p - started", recv)
|
||||||
|
|
||||||
// receive datagrams until conn is closed
|
// receive datagrams until conn is closed
|
||||||
|
|
||||||
|
@ -90,14 +90,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
|
||||||
)
|
)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
switch IP {
|
size, endpoint, err = recv(buffer[:])
|
||||||
case ipv4.Version:
|
|
||||||
size, endpoint, err = bind.ReceiveIPv4(buffer[:])
|
|
||||||
case ipv6.Version:
|
|
||||||
size, endpoint, err = bind.ReceiveIPv6(buffer[:])
|
|
||||||
default:
|
|
||||||
panic("invalid IP version")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
device.PutMessageBuffer(buffer)
|
device.PutMessageBuffer(buffer)
|
||||||
|
|
Loading…
Reference in a new issue