boundif: introduce API for socket binding

This commit is contained in:
Jason A. Donenfeld 2019-03-03 05:01:06 +01:00
parent 69f0fe67b6
commit b8e85267cf
8 changed files with 155 additions and 21 deletions

34
device/boundif_android.go Normal file
View file

@ -0,0 +1,34 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
func (device *Device) PeekLookAtSocketFd4() (fd int, err error) {
sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
if err != nil {
return
}
err = sysconn.Control(func(f uintptr) {
fd = int(f)
})
if err != nil {
return
}
return
}
func (device *Device) PeekLookAtSocketFd6() (fd int, err error) {
sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
if err != nil {
return
}
err = sysconn.Control(func(f uintptr) {
fd = int(f)
})
if err != nil {
return
}
return
}

44
device/boundif_darwin.go Normal file
View file

@ -0,0 +1,44 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"golang.org/x/sys/unix"
)
func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error {
sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
if err != nil {
return nil
}
err2 := sysconn.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, int(interfaceIndex))
})
if err2 != nil {
return err2
}
if err != nil {
return err
}
return nil
}
func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
if err != nil {
return nil
}
err2 := sysconn.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, int(interfaceIndex))
})
if err2 != nil {
return err2
}
if err != nil {
return err
}
return nil
}

56
device/boundif_windows.go Normal file
View file

@ -0,0 +1,56 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"encoding/binary"
"golang.org/x/sys/windows"
"unsafe"
)
const (
sockoptIP_UNICAST_IF = 31
sockoptIPV6_UNICAST_IF = 31
)
func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error {
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
bytes := make([]byte, 4)
binary.BigEndian.PutUint32(bytes, interfaceIndex)
interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
if err != nil {
return err
}
err2 := sysconn.Control(func(fd uintptr) {
err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, sockoptIP_UNICAST_IF, int(interfaceIndex))
})
if err2 != nil {
return err2
}
if err != nil {
return err
}
return nil
}
func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
if err != nil {
return err
}
err2 := sysconn.Control(func(fd uintptr) {
err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, sockoptIPV6_UNICAST_IF, int(interfaceIndex))
})
if err2 != nil {
return err2
}
if err != nil {
return err
}
return nil
}

View file

@ -20,14 +20,14 @@ import (
* See conn_linux.go for an implementation on the linux platform. * See conn_linux.go for an implementation on the linux platform.
*/ */
type NativeBind struct { type nativeBind struct {
ipv4 *net.UDPConn ipv4 *net.UDPConn
ipv6 *net.UDPConn ipv6 *net.UDPConn
} }
type NativeEndpoint net.UDPAddr type NativeEndpoint net.UDPAddr
var _ Bind = (*NativeBind)(nil) var _ Bind = (*nativeBind)(nil)
var _ Endpoint = (*NativeEndpoint)(nil) var _ Endpoint = (*NativeEndpoint)(nil)
func CreateEndpoint(s string) (Endpoint, error) { func CreateEndpoint(s string) (Endpoint, error) {
@ -100,7 +100,7 @@ func extractErrno(err error) error {
func CreateBind(uport uint16, device *Device) (Bind, uint16, error) { func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
var err error var err error
var bind NativeBind var bind nativeBind
port := int(uport) port := int(uport)
@ -119,7 +119,7 @@ func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
return &bind, uint16(port), nil return &bind, uint16(port), nil
} }
func (bind *NativeBind) Close() error { func (bind *nativeBind) Close() error {
var err1, err2 error var err1, err2 error
if bind.ipv4 != nil { if bind.ipv4 != nil {
err1 = bind.ipv4.Close() err1 = bind.ipv4.Close()
@ -133,7 +133,7 @@ func (bind *NativeBind) Close() error {
return err2 return err2
} }
func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
if bind.ipv4 == nil { if bind.ipv4 == nil {
return 0, nil, syscall.EAFNOSUPPORT return 0, nil, syscall.EAFNOSUPPORT
} }
@ -144,7 +144,7 @@ func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
return n, (*NativeEndpoint)(endpoint), err return n, (*NativeEndpoint)(endpoint), err
} }
func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
if bind.ipv6 == nil { if bind.ipv6 == nil {
return 0, nil, syscall.EAFNOSUPPORT return 0, nil, syscall.EAFNOSUPPORT
} }
@ -152,7 +152,7 @@ func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
return n, (*NativeEndpoint)(endpoint), err return n, (*NativeEndpoint)(endpoint), err
} }
func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error { func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error {
var err error var err error
nend := endpoint.(*NativeEndpoint) nend := endpoint.(*NativeEndpoint)
if nend.IP.To4() != nil { if nend.IP.To4() != nil {

View file

@ -63,7 +63,7 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0])) return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
} }
type NativeBind struct { type nativeBind struct {
sock4 int sock4 int
sock6 int sock6 int
netlinkSock int netlinkSock int
@ -72,7 +72,7 @@ type NativeBind struct {
} }
var _ Endpoint = (*NativeEndpoint)(nil) var _ Endpoint = (*NativeEndpoint)(nil)
var _ Bind = (*NativeBind)(nil) var _ Bind = (*nativeBind)(nil)
func CreateEndpoint(s string) (Endpoint, error) { func CreateEndpoint(s string) (Endpoint, error) {
var end NativeEndpoint var end NativeEndpoint
@ -127,9 +127,9 @@ func createNetlinkRouteSocket() (int, error) {
} }
func CreateBind(port uint16, device *Device) (*NativeBind, uint16, error) { func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
var err error var err error
var bind NativeBind var bind nativeBind
var newPort uint16 var newPort uint16
bind.netlinkSock, err = createNetlinkRouteSocket() bind.netlinkSock, err = createNetlinkRouteSocket()
@ -176,7 +176,7 @@ func CreateBind(port uint16, device *Device) (*NativeBind, uint16, error) {
return &bind, port, nil return &bind, port, nil
} }
func (bind *NativeBind) SetMark(value uint32) error { func (bind *nativeBind) SetMark(value uint32) error {
if bind.sock6 != -1 { if bind.sock6 != -1 {
err := unix.SetsockoptInt( err := unix.SetsockoptInt(
bind.sock6, bind.sock6,
@ -213,7 +213,7 @@ func closeUnblock(fd int) error {
return unix.Close(fd) return unix.Close(fd)
} }
func (bind *NativeBind) Close() error { func (bind *nativeBind) Close() error {
var err1, err2, err3 error var err1, err2, err3 error
if bind.sock6 != -1 { if bind.sock6 != -1 {
err1 = closeUnblock(bind.sock6) err1 = closeUnblock(bind.sock6)
@ -232,7 +232,7 @@ func (bind *NativeBind) Close() error {
return err3 return err3
} }
func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
var end NativeEndpoint var end NativeEndpoint
if bind.sock6 == -1 { if bind.sock6 == -1 {
return 0, nil, syscall.EAFNOSUPPORT return 0, nil, syscall.EAFNOSUPPORT
@ -245,7 +245,7 @@ func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
return n, &end, err return n, &end, err
} }
func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
var end NativeEndpoint var end NativeEndpoint
if bind.sock4 == -1 { if bind.sock4 == -1 {
return 0, nil, syscall.EAFNOSUPPORT return 0, nil, syscall.EAFNOSUPPORT
@ -258,7 +258,7 @@ func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
return n, &end, err return n, &end, err
} }
func (bind *NativeBind) Send(buff []byte, end Endpoint) error { func (bind *nativeBind) Send(buff []byte, end Endpoint) error {
nend := end.(*NativeEndpoint) nend := end.(*NativeEndpoint)
if !nend.isV6 { if !nend.isV6 {
if bind.sock4 == -1 { if bind.sock4 == -1 {
@ -592,7 +592,7 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
return size, nil return size, nil
} }
func (bind *NativeBind) routineRouteListener(device *Device) { func (bind *nativeBind) routineRouteListener(device *Device) {
type peerEndpointPtr struct { type peerEndpointPtr struct {
peer *Peer peer *Peer
endpoint *Endpoint endpoint *Endpoint

View file

@ -7,6 +7,6 @@
package device package device
func (bind *NativeBind) SetMark(mark uint32) error { func (bind *nativeBind) SetMark(mark uint32) error {
return nil return nil
} }

View file

@ -25,7 +25,7 @@ func init() {
} }
} }
func (bind *NativeBind) SetMark(mark uint32) error { func (bind *nativeBind) SetMark(mark uint32) error {
var operr error var operr error
if fwmarkIoctl == 0 { if fwmarkIoctl == 0 {
return nil return nil

View file

@ -258,10 +258,10 @@ func (peer *Peer) Stop() {
peer.ZeroAndFlushAll() peer.ZeroAndFlushAll()
} }
var roamingDisabled bool var RoamingDisabled bool
func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) { func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) {
if roamingDisabled { if RoamingDisabled {
return return
} }
peer.Lock() peer.Lock()