Merge branch 'source-caching'

This commit is contained in:
Mathias Hall-Andersen 2017-11-19 13:19:07 +01:00
commit b5ae42349c
20 changed files with 1200 additions and 510 deletions

View file

@ -2,10 +2,35 @@ package main
import (
"errors"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"net"
"time"
)
/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
*/
type Bind interface {
SetMark(value uint32) error
ReceiveIPv6(buff []byte) (int, Endpoint, error)
ReceiveIPv4(buff []byte) (int, Endpoint, error)
Send(buff []byte, end Endpoint) error
Close() error
}
/* An Endpoint maintains the source/destination caching for a peer
*
* dst : the remote address of a peer ("endpoint" in uapi terminology)
* src : the local address from which datagrams originate going to the peer
*/
type Endpoint interface {
ClearSrc() // clears the source address
SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations
DstIP() net.IP
SrcIP() net.IP
}
func parseEndpoint(s string) (*net.UDPAddr, error) {
// ensure that the host is an IP address
@ -27,63 +52,83 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
return addr, err
}
func updateUDPConn(device *Device) error {
/* Must hold device and net lock
*/
func unsafeCloseUDPListener(device *Device) error {
var err error
netc := &device.net
if netc.bind != nil {
err = netc.bind.Close()
netc.bind = nil
}
return err
}
// must inform all listeners
func UpdateUDPListener(device *Device) error {
device.mutex.Lock()
defer device.mutex.Unlock()
netc := &device.net
netc.mutex.Lock()
defer netc.mutex.Unlock()
// close existing connection
// close existing sockets
if netc.conn != nil {
netc.conn.Close()
netc.conn = nil
// We need for that fd to be closed in all other go routines, which
// means we have to wait. TODO: find less horrible way of doing this.
time.Sleep(time.Second / 2)
if err := unsafeCloseUDPListener(device); err != nil {
return err
}
// open new connection
// assumption: netc.update WaitGroup should be exactly 1
// open new sockets
if device.tun.isUp.Get() {
// listen on new address
device.log.Debug.Println("UDP bind updating")
conn, err := net.ListenUDP("udp", netc.addr)
// bind to new port
var err error
netc.bind, netc.port, err = CreateBind(netc.port)
if err != nil {
netc.bind = nil
return err
}
// set mark
err = netc.bind.SetMark(netc.fwmark)
if err != nil {
return err
}
// set fwmark
// clear cached source addresses
err = setMark(netc.conn, netc.fwmark)
if err != nil {
return err
for _, peer := range device.peers {
peer.mutex.Lock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.mutex.Unlock()
}
// retrieve port (may have been chosen by kernel)
// decrease waitgroup to 0
addr := conn.LocalAddr()
netc.conn = conn
netc.addr, _ = net.ResolveUDPAddr(
addr.Network(),
addr.String(),
)
go device.RoutineReceiveIncomming(ipv4.Version, netc.bind)
go device.RoutineReceiveIncomming(ipv6.Version, netc.bind)
// notify goroutines
signalSend(device.signal.newUDPConn)
device.log.Debug.Println("UDP bind has been updated")
}
return nil
}
func closeUDPConn(device *Device) {
netc := &device.net
netc.mutex.Lock()
if netc.conn != nil {
netc.conn.Close()
}
netc.mutex.Unlock()
signalSend(device.signal.newUDPConn)
func CloseUDPListener(device *Device) error {
device.mutex.Lock()
device.net.mutex.Lock()
err := unsafeCloseUDPListener(device)
device.net.mutex.Unlock()
device.mutex.Unlock()
return err
}

View file

@ -6,6 +6,126 @@ import (
"net"
)
func setMark(conn *net.UDPConn, value uint32) error {
/* This code is meant to be a temporary solution
* on platforms for which the sticky socket / source caching behavior
* has not yet been implemented.
*
* See conn_linux.go for an implementation on the linux platform.
*/
type NativeBind struct {
ipv4 *net.UDPConn
ipv6 *net.UDPConn
}
type NativeEndpoint net.UDPAddr
var _ Bind = (*NativeBind)(nil)
var _ Endpoint = (*NativeEndpoint)(nil)
func CreateEndpoint(s string) (Endpoint, error) {
addr, err := parseEndpoint(s)
return (*NativeEndpoint)(addr), err
}
func (_ *NativeEndpoint) ClearSrc() {}
func (e *NativeEndpoint) DstIP() net.IP {
return (*net.UDPAddr)(e).IP
}
func (e *NativeEndpoint) SrcIP() net.IP {
return nil // not supported
}
func (e *NativeEndpoint) DstToBytes() []byte {
addr := (*net.UDPAddr)(e)
out := addr.IP
out = append(out, byte(addr.Port&0xff))
out = append(out, byte((addr.Port>>8)&0xff))
return out
}
func (e *NativeEndpoint) DstToString() string {
return (*net.UDPAddr)(e).String()
}
func (e *NativeEndpoint) SrcToString() string {
return ""
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
// listen
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
if err != nil {
return nil, 0, err
}
// retrieve port
laddr := conn.LocalAddr()
uaddr, err := net.ResolveUDPAddr(
laddr.Network(),
laddr.String(),
)
if err != nil {
return nil, 0, err
}
return conn, uaddr.Port, nil
}
func CreateBind(uport uint16) (Bind, uint16, error) {
var err error
var bind NativeBind
port := int(uport)
bind.ipv4, port, err = listenNet("udp4", port)
if err != nil {
return nil, 0, err
}
bind.ipv6, port, err = listenNet("udp6", port)
if err != nil {
bind.ipv4.Close()
return nil, 0, err
}
return &bind, uint16(port), nil
}
func (bind *NativeBind) Close() error {
err1 := bind.ipv4.Close()
err2 := bind.ipv6.Close()
if err1 != nil {
return err1
}
return err2
}
func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
return n, (*NativeEndpoint)(endpoint), err
}
func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
return n, (*NativeEndpoint)(endpoint), err
}
func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error {
var err error
nend := endpoint.(*NativeEndpoint)
if nend.IP.To16() != nil {
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
} else {
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
}
return err
}
func (bind *NativeBind) SetMark(_ uint32) error {
return nil
}

View file

@ -7,6 +7,7 @@
package main
import (
"encoding/binary"
"errors"
"golang.org/x/sys/unix"
"net"
@ -15,20 +16,230 @@ import (
)
/* Supports source address caching
*
* It is important that the endpoint is only updated after the packet content has been authenticated.
*
* Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930
* So this code is remains platform dependent.
*/
type Endpoint struct {
// source (selected based on dst type)
// (could use RawSockaddrAny and unsafe)
srcIPv6 unix.RawSockaddrInet6
srcIPv4 unix.RawSockaddrInet4
srcIf4 int32
type NativeEndpoint struct {
src unix.RawSockaddrInet6
dst unix.RawSockaddrInet6
}
dst unix.RawSockaddrAny
type NativeBind struct {
sock4 int
sock6 int
}
var _ Endpoint = (*NativeEndpoint)(nil)
var _ Bind = NativeBind{}
type IPv4Source struct {
src unix.RawSockaddrInet4
Ifindex int32
}
func htons(val uint16) uint16 {
var out [unsafe.Sizeof(val)]byte
binary.BigEndian.PutUint16(out[:], val)
return *((*uint16)(unsafe.Pointer(&out[0])))
}
func ntohs(val uint16) uint16 {
tmp := ((*[unsafe.Sizeof(val)]byte)(unsafe.Pointer(&val)))
return binary.BigEndian.Uint16((*tmp)[:])
}
func CreateEndpoint(s string) (Endpoint, error) {
var end NativeEndpoint
addr, err := parseEndpoint(s)
if err != nil {
return nil, err
}
ipv4 := addr.IP.To4()
if ipv4 != nil {
dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
dst.Family = unix.AF_INET
dst.Port = htons(uint16(addr.Port))
dst.Zero = [8]byte{}
copy(dst.Addr[:], ipv4)
end.ClearSrc()
return &end, nil
}
ipv6 := addr.IP.To16()
if ipv6 != nil {
zone, err := zoneToUint32(addr.Zone)
if err != nil {
return nil, err
}
dst := &end.dst
dst.Family = unix.AF_INET6
dst.Port = htons(uint16(addr.Port))
dst.Flowinfo = 0
dst.Scope_id = zone
copy(dst.Addr[:], ipv6[:])
end.ClearSrc()
return &end, nil
}
return nil, errors.New("Failed to recognize IP address format")
}
func CreateBind(port uint16) (Bind, uint16, error) {
var err error
var bind NativeBind
bind.sock6, port, err = create6(port)
if err != nil {
return nil, port, err
}
bind.sock4, port, err = create4(port)
if err != nil {
unix.Close(bind.sock6)
}
return bind, port, err
}
func (bind NativeBind) SetMark(value uint32) error {
err := unix.SetsockoptInt(
bind.sock6,
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),
)
if err != nil {
return err
}
return unix.SetsockoptInt(
bind.sock4,
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),
)
}
func closeUnblock(fd int) error {
// shutdown to unblock readers
unix.Shutdown(fd, unix.SHUT_RD)
return unix.Close(fd)
}
func (bind NativeBind) Close() error {
err1 := closeUnblock(bind.sock6)
err2 := closeUnblock(bind.sock4)
if err1 != nil {
return err1
}
return err2
}
func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
var end NativeEndpoint
n, err := receive6(
bind.sock6,
buff,
&end,
)
return n, &end, err
}
func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
var end NativeEndpoint
n, err := receive4(
bind.sock4,
buff,
&end,
)
return n, &end, err
}
func (bind NativeBind) Send(buff []byte, end Endpoint) error {
nend := end.(*NativeEndpoint)
switch nend.dst.Family {
case unix.AF_INET6:
return send6(bind.sock6, nend, buff)
case unix.AF_INET:
return send4(bind.sock4, nend, buff)
default:
return errors.New("Unknown address family of destination")
}
}
func sockaddrToString(addr unix.RawSockaddrInet6) string {
var udpAddr net.UDPAddr
switch addr.Family {
case unix.AF_INET6:
udpAddr.Port = int(ntohs(addr.Port))
udpAddr.IP = addr.Addr[:]
return udpAddr.String()
case unix.AF_INET:
ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
udpAddr.Port = int(ntohs(ptr.Port))
udpAddr.IP = net.IPv4(
ptr.Addr[0],
ptr.Addr[1],
ptr.Addr[2],
ptr.Addr[3],
)
return udpAddr.String()
default:
return "<unknown address family>"
}
}
func rawAddrToIP(addr unix.RawSockaddrInet6) net.IP {
switch addr.Family {
case unix.AF_INET6:
return addr.Addr[:]
case unix.AF_INET:
ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
return net.IPv4(
ptr.Addr[0],
ptr.Addr[1],
ptr.Addr[2],
ptr.Addr[3],
)
default:
return nil
}
}
func (end *NativeEndpoint) SrcIP() net.IP {
return rawAddrToIP(end.src)
}
func (end *NativeEndpoint) DstIP() net.IP {
return rawAddrToIP(end.dst)
}
func (end *NativeEndpoint) DstToBytes() []byte {
ptr := unsafe.Pointer(&end.src)
arr := (*[unix.SizeofSockaddrInet6]byte)(ptr)
return arr[:]
}
func (end *NativeEndpoint) SrcToString() string {
return sockaddrToString(end.src)
}
func (end *NativeEndpoint) DstToString() string {
return sockaddrToString(end.dst)
}
func (end *NativeEndpoint) ClearDst() {
end.dst = unix.RawSockaddrInet6{}
}
func (end *NativeEndpoint) ClearSrc() {
end.src = unix.RawSockaddrInet6{}
}
func zoneToUint32(zone string) (uint32, error) {
@ -42,51 +253,116 @@ func zoneToUint32(zone string) (uint32, error) {
return uint32(n), err
}
func (end *Endpoint) ClearSrc() {
end.srcIf4 = 0
end.srcIPv4 = unix.RawSockaddrInet4{}
end.srcIPv6 = unix.RawSockaddrInet6{}
}
func create4(port uint16) (int, uint16, error) {
// create socket
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
func (end *Endpoint) Set(s string) error {
addr, err := parseEndpoint(s)
if err != nil {
return err
return -1, 0, err
}
ipv6 := addr.IP.To16()
if ipv6 != nil {
zone, err := zoneToUint32(addr.Zone)
if err != nil {
addr := unix.SockaddrInet4{
Port: int(port),
}
// set sockopts and bind
if err := func() error {
if err := unix.SetsockoptInt(
fd,
unix.SOL_SOCKET,
unix.SO_REUSEADDR,
1,
); err != nil {
return err
}
ptr := (*unix.RawSockaddrInet6)(unsafe.Pointer(&end.dst))
ptr.Family = unix.AF_INET6
ptr.Port = uint16(addr.Port)
ptr.Flowinfo = 0
ptr.Scope_id = zone
copy(ptr.Addr[:], ipv6[:])
end.ClearSrc()
return nil
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IP,
unix.IP_PKTINFO,
1,
); err != nil {
return err
}
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
}
ipv4 := addr.IP.To4()
if ipv4 != nil {
ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
ptr.Family = unix.AF_INET
ptr.Port = uint16(addr.Port)
ptr.Zero = [8]byte{}
copy(ptr.Addr[:], ipv4)
end.ClearSrc()
return nil
}
return errors.New("Failed to recognize IP address format")
return fd, uint16(addr.Port), err
}
func send6(sock uintptr, end *Endpoint, buff []byte) error {
var iovec unix.Iovec
func create6(port uint16) (int, uint16, error) {
// create socket
fd, err := unix.Socket(
unix.AF_INET6,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return -1, 0, err
}
// set sockopts and bind
addr := unix.SockaddrInet6{
Port: int(port),
}
if err := func() error {
if err := unix.SetsockoptInt(
fd,
unix.SOL_SOCKET,
unix.SO_REUSEADDR,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IPV6,
unix.IPV6_RECVPKTINFO,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IPV6,
unix.IPV6_V6ONLY,
1,
); err != nil {
return err
}
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
}
return fd, uint16(addr.Port), err
}
func send6(sock int, end *NativeEndpoint, buff []byte) error {
// construct message header
var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff))
@ -97,11 +373,11 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
unix.Cmsghdr{
Level: unix.IPPROTO_IPV6,
Type: unix.IPV6_PKTINFO,
Len: unix.SizeofInet6Pktinfo,
Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
},
unix.Inet6Pktinfo{
Addr: end.srcIPv6.Addr,
Ifindex: end.srcIPv6.Scope_id,
Addr: end.src.Addr,
Ifindex: end.src.Scope_id,
},
}
@ -119,22 +395,41 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
_, _, errno := unix.Syscall(
unix.SYS_SENDMSG,
sock,
uintptr(sock),
uintptr(unsafe.Pointer(&msghdr)),
0,
)
if errno == 0 {
return nil
}
// clear src and retry
if errno == unix.EINVAL {
end.ClearSrc()
cmsg.pktinfo = unix.Inet6Pktinfo{}
_, _, errno = unix.Syscall(
unix.SYS_SENDMSG,
uintptr(sock),
uintptr(unsafe.Pointer(&msghdr)),
0,
)
}
return errno
}
func send4(sock uintptr, end *Endpoint, buff []byte) error {
var iovec unix.Iovec
func send4(sock int, end *NativeEndpoint, buff []byte) error {
// construct message header
var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff))
src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
cmsg := struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet4Pktinfo
@ -142,11 +437,11 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
unix.Cmsghdr{
Level: unix.IPPROTO_IP,
Type: unix.IP_PKTINFO,
Len: unix.SizeofInet6Pktinfo,
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
},
unix.Inet4Pktinfo{
Spec_dst: end.srcIPv4.Addr,
Ifindex: end.srcIf4,
Spec_dst: src4.src.Addr,
Ifindex: src4.Ifindex,
},
}
@ -156,51 +451,44 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
Name: (*byte)(unsafe.Pointer(&end.dst)),
Namelen: unix.SizeofSockaddrInet4,
Control: (*byte)(unsafe.Pointer(&cmsg)),
Flags: 0,
}
msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
// sendmsg(sock, &msghdr, 0)
_, _, errno := unix.Syscall(
unix.SYS_SENDMSG,
sock,
uintptr(sock),
uintptr(unsafe.Pointer(&msghdr)),
0,
)
// clear source and try again
if errno == unix.EINVAL {
end.ClearSrc()
cmsg.pktinfo = unix.Inet4Pktinfo{}
_, _, errno = unix.Syscall(
unix.SYS_SENDMSG,
uintptr(sock),
uintptr(unsafe.Pointer(&msghdr)),
0,
)
}
// errno = 0 is still an error instance
if errno == 0 {
return nil
}
return errno
}
func send(c *net.UDPConn, end *Endpoint, buff []byte) error {
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// extract underlying file descriptor
file, err := c.File()
if err != nil {
return err
}
sock := file.Fd()
// send depending on address family of dst
family := *((*uint16)(unsafe.Pointer(&end.dst)))
if family == unix.AF_INET {
return send4(sock, end, buff)
} else if family == unix.AF_INET6 {
return send6(sock, end, buff)
}
return errors.New("Unknown address family of source")
}
func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAddr, *net.UDPAddr) {
file, err := c.File()
if err != nil {
return err, nil, nil
}
// contruct message header
var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
@ -208,60 +496,87 @@ func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAdd
var cmsg struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo // big enough
pktinfo unix.Inet4Pktinfo
}
var msghdr unix.Msghdr
msghdr.Iov = &iovec
msghdr.Iovlen = 1
msghdr.Name = (*byte)(unsafe.Pointer(&end.dst))
msghdr.Namelen = unix.SizeofSockaddrInet4
msghdr.Control = (*byte)(unsafe.Pointer(&cmsg))
msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
// recvmsg(sock, &mskhdr, 0)
size, _, errno := unix.Syscall(
unix.SYS_RECVMSG,
uintptr(sock),
uintptr(unsafe.Pointer(&msghdr)),
0,
)
if errno != 0 {
return 0, errno
}
// update source cache
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
src4.src.Family = unix.AF_INET
src4.src.Addr = cmsg.pktinfo.Spec_dst
src4.Ifindex = cmsg.pktinfo.Ifindex
}
return int(size), nil
}
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header
var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff))
var cmsg struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo
}
var msg unix.Msghdr
msg.Iov = &iovec
msg.Iovlen = 1
msg.Name = (*byte)(unsafe.Pointer(&end.dst))
msg.Namelen = uint32(unix.SizeofSockaddrAny)
msg.Namelen = uint32(unix.SizeofSockaddrInet6)
msg.Control = (*byte)(unsafe.Pointer(&cmsg))
msg.SetControllen(int(unsafe.Sizeof(cmsg)))
_, _, errno := unix.Syscall(
// recvmsg(sock, &mskhdr, 0)
size, _, errno := unix.Syscall(
unix.SYS_RECVMSG,
file.Fd(),
uintptr(sock),
uintptr(unsafe.Pointer(&msg)),
0,
)
if errno != 0 {
return errno, nil, nil
return 0, errno
}
// update source cache
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
end.src.Family = unix.AF_INET6
end.src.Addr = cmsg.pktinfo.Addr
end.src.Scope_id = cmsg.pktinfo.Ifindex
}
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&cmsg.pktinfo))
println(info)
}
return nil, nil, nil
}
func setMark(conn *net.UDPConn, value uint32) error {
if conn == nil {
return nil
}
file, err := conn.File()
if err != nil {
return err
}
return unix.SetsockoptInt(
int(file.Fd()),
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),
)
return int(size), nil
}

View file

@ -5,10 +5,8 @@ import (
"crypto/rand"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
"net"
"sync"
"time"
"unsafe"
)
type CookieChecker struct {
@ -76,7 +74,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool {
return hmac.Equal(mac1[:], msg[smac1:smac2])
}
func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
st.mutex.RLock()
defer st.mutex.RUnlock()
@ -89,8 +87,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
var cookie [blake2s.Size128]byte
func() {
mac, _ := blake2s.New128(st.mac2.secret[:])
mac.Write(src.IP)
mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:])
mac.Write(src)
mac.Sum(cookie[:0])
}()
@ -111,7 +108,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
func (st *CookieChecker) CreateReply(
msg []byte,
recv uint32,
src *net.UDPAddr,
src []byte,
) (*MessageCookieReply, error) {
st.mutex.RLock()
@ -136,8 +133,7 @@ func (st *CookieChecker) CreateReply(
var cookie [blake2s.Size128]byte
func() {
mac, _ := blake2s.New128(st.mac2.secret[:])
mac.Write(src.IP)
mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:])
mac.Write(src)
mac.Sum(cookie[:0])
}()

View file

@ -1,7 +1,6 @@
package main
import (
"net"
"testing"
)
@ -25,7 +24,7 @@ func TestCookieMAC1(t *testing.T) {
// check mac1
src, _ := net.ResolveUDPAddr("udp", "192.168.13.37:4000")
src := []byte{192, 168, 13, 37, 10, 10, 10}
checkMAC1 := func(msg []byte) {
generator.AddMacs(msg)
@ -128,12 +127,12 @@ func TestCookieMAC1(t *testing.T) {
msg[5] ^= 0x20
srcBad1, _ := net.ResolveUDPAddr("udp", "192.168.13.37:4001")
srcBad1 := []byte{192, 168, 13, 37, 40, 01}
if checker.CheckMAC2(msg, srcBad1) {
t.Fatal("MAC2 generation/verification failed")
}
srcBad2, _ := net.ResolveUDPAddr("udp", "192.168.13.38:4000")
srcBad2 := []byte{192, 168, 13, 38, 40, 01}
if checker.CheckMAC2(msg, srcBad2) {
t.Fatal("MAC2 generation/verification failed")
}

View file

@ -2,29 +2,25 @@ package main
import (
"os"
"os/exec"
)
/* Daemonizes the process on linux
*
* This is done by spawning and releasing a copy with the --foreground flag
*
* TODO: Use env variable to spawn in background
*/
func Daemonize(attr *os.ProcAttr) error {
// I would like to use os.Executable,
// however this means dropping support for Go <1.8
path, err := exec.LookPath(os.Args[0])
if err != nil {
return err
}
func Daemonize() error {
argv := []string{os.Args[0], "--foreground"}
argv = append(argv, os.Args[1:]...)
attr := &os.ProcAttr{
Dir: ".",
Env: os.Environ(),
Files: []*os.File{
os.Stdin,
nil,
nil,
},
}
process, err := os.StartProcess(
argv[0],
path,
argv,
attr,
)

View file

@ -1,7 +1,6 @@
package main
import (
"net"
"runtime"
"sync"
"sync/atomic"
@ -9,8 +8,9 @@ import (
)
type Device struct {
log *Logger // collection of loggers for levels
idCounter uint // for assigning debug ids to peers
closed AtomicBool // device is closed? (acting as guard)
log *Logger // collection of loggers for levels
idCounter uint // for assigning debug ids to peers
fwMark uint32
tun struct {
device TUNDevice
@ -22,9 +22,9 @@ type Device struct {
}
net struct {
mutex sync.RWMutex
addr *net.UDPAddr // UDP source address
conn *net.UDPConn // UDP "connection"
fwmark uint32
bind Bind // bind interface
port uint16 // listening port
fwmark uint32 // mark value (0 = disabled)
}
mutex sync.RWMutex
privateKey NoisePrivateKey
@ -37,8 +37,7 @@ type Device struct {
handshake chan QueueHandshakeElement
}
signal struct {
stop chan struct{} // halts all go routines
newUDPConn chan struct{} // a net.conn was set (consumed by the receiver routine)
stop chan struct{}
}
underLoadUntil atomic.Value
ratelimiter Ratelimiter
@ -128,21 +127,23 @@ func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
device.pool.messageBuffers.Put(msg)
}
func NewDevice(tun TUNDevice, logLevel int) *Device {
func NewDevice(tun TUNDevice, logger *Logger) *Device {
device := new(Device)
device.mutex.Lock()
defer device.mutex.Unlock()
device.log = NewLogger(logLevel, "("+tun.Name()+") ")
device.log = logger
device.peers = make(map[NoisePublicKey]*Peer)
device.tun.device = tun
device.indices.Init()
device.ratelimiter.Init()
device.routingTable.Reset()
device.underLoadUntil.Store(time.Time{})
// setup pools
// setup buffer pool
device.pool.messageBuffers = sync.Pool{
New: func() interface{} {
@ -159,7 +160,11 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
// prepare signals
device.signal.stop = make(chan struct{})
device.signal.newUDPConn = make(chan struct{}, 1)
// prepare net
device.net.port = 0
device.net.bind = nil
// start workers
@ -168,12 +173,9 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
go device.RoutineDecryption()
go device.RoutineHandshake()
}
go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader()
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
go device.RoutineReadFromTUN()
go device.RoutineReceiveIncomming()
return device
}
@ -202,9 +204,13 @@ func (device *Device) RemoveAllPeers() {
}
func (device *Device) Close() {
if device.closed.Swap(true) {
return
}
device.log.Info.Println("Closing device")
device.RemoveAllPeers()
close(device.signal.stop)
closeUDPConn(device)
CloseUDPListener(device)
device.tun.device.Close()
}

View file

@ -2,6 +2,7 @@ package main
import (
"bytes"
"os"
"testing"
)
@ -15,6 +16,10 @@ type DummyTUN struct {
events chan TUNEvent
}
func (tun *DummyTUN) File() *os.File {
return nil
}
func (tun *DummyTUN) Name() string {
return tun.name
}
@ -67,7 +72,8 @@ func randDevice(t *testing.T) *Device {
t.Fatal(err)
}
tun, _ := CreateDummyTUN("dummy")
device := NewDevice(tun, LogLevelError)
logger := NewLogger(LogLevelError, "")
device := NewDevice(tun, logger)
device.SetPrivateKey(sk)
return device
}

View file

@ -2,10 +2,15 @@ package main
import (
"fmt"
"log"
"os"
"os/signal"
"runtime"
"strconv"
)
const (
ENV_WG_TUN_FD = "WG_TUN_FD"
ENV_WG_UAPI_FD = "WG_UAPI_FD"
)
func printUsage() {
@ -43,28 +48,6 @@ func main() {
interfaceName = os.Args[1]
}
// daemonize the process
if !foreground {
err := Daemonize()
if err != nil {
log.Println("Failed to daemonize:", err)
}
return
}
// increase number of go workers (for Go <1.5)
runtime.GOMAXPROCS(runtime.NumCPU())
// open TUN device
tun, err := CreateTUN(interfaceName)
if err != nil {
log.Println("Failed to create tun device:", err)
return
}
// get log level (default: info)
logLevel := func() int {
@ -79,25 +62,103 @@ func main() {
return LogLevelInfo
}()
logger := NewLogger(
logLevel,
fmt.Sprintf("(%s) ", interfaceName),
)
logger.Debug.Println("Debug log enabled")
// open TUN device (or use supplied fd)
tun, err := func() (TUNDevice, error) {
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
if tunFdStr == "" {
return CreateTUN(interfaceName)
}
// construct tun device from supplied fd
fd, err := strconv.ParseUint(tunFdStr, 10, 32)
if err != nil {
return nil, err
}
file := os.NewFile(uintptr(fd), "")
return CreateTUNFromFile(interfaceName, file)
}()
if err != nil {
logger.Error.Println("Failed to create TUN device:", err)
os.Exit(ExitSetupFailed)
}
// open UAPI file (or use supplied fd)
fileUAPI, err := func() (*os.File, error) {
uapiFdStr := os.Getenv(ENV_WG_UAPI_FD)
if uapiFdStr == "" {
return UAPIOpen(interfaceName)
}
// use supplied fd
fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
if err != nil {
return nil, err
}
return os.NewFile(uintptr(fd), ""), nil
}()
if err != nil {
logger.Error.Println("UAPI listen error:", err)
os.Exit(ExitSetupFailed)
return
}
// daemonize the process
if !foreground {
env := os.Environ()
env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD))
env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD))
attr := &os.ProcAttr{
Files: []*os.File{
nil, // stdin
nil, // stdout
nil, // stderr
tun.File(),
fileUAPI,
},
Dir: ".",
Env: env,
}
err = Daemonize(attr)
if err != nil {
logger.Error.Println("Failed to daemonize:", err)
os.Exit(ExitSetupFailed)
}
return
}
// increase number of go workers (for Go <1.5)
runtime.GOMAXPROCS(runtime.NumCPU())
// create wireguard device
device := NewDevice(tun, logLevel)
device := NewDevice(tun, logger)
logInfo := device.log.Info
logError := device.log.Error
logInfo.Println("Starting device")
logger.Info.Println("Device started")
// start configuration lister
uapi, err := NewUAPIListener(interfaceName)
if err != nil {
logError.Fatal("UAPI listen error:", err)
}
// start uapi listener
errs := make(chan error)
term := make(chan os.Signal)
wait := device.WaitChannel()
uapi, err := UAPIListen(interfaceName, fileUAPI)
go func() {
for {
conn, err := uapi.Accept()
@ -109,7 +170,7 @@ func main() {
}
}()
logInfo.Println("UAPI listener started")
logger.Info.Println("UAPI listener started")
// wait for program to terminate
@ -122,9 +183,10 @@ func main() {
case <-errs:
}
// clean up UAPI bind
// clean up
uapi.Close()
device.Close()
logInfo.Println("Closing")
logger.Info.Println("Shutting down")
}

View file

@ -21,6 +21,14 @@ func (a *AtomicBool) Get() bool {
return atomic.LoadInt32(&a.flag) == AtomicTrue
}
func (a *AtomicBool) Swap(val bool) bool {
flag := AtomicFalse
if val {
flag = AtomicTrue
}
return atomic.SwapInt32(&a.flag, flag) == AtomicTrue
}
func (a *AtomicBool) Set(val bool) {
flag := AtomicFalse
if val {

View file

@ -117,8 +117,8 @@ func TestNoiseHandshake(t *testing.T) {
var err error
var out []byte
var nonce [12]byte
out = key1.send.aead.Seal(out, nonce[:], testMsg, nil)
out, err = key2.receive.aead.Open(out[:0], nonce[:], out, nil)
out = key1.send.Seal(out, nonce[:], testMsg, nil)
out, err = key2.receive.Open(out[:0], nonce[:], out, nil)
assertNil(t, err)
assertEqual(t, out, testMsg)
}()
@ -128,8 +128,8 @@ func TestNoiseHandshake(t *testing.T) {
var err error
var out []byte
var nonce [12]byte
out = key2.send.aead.Seal(out, nonce[:], testMsg, nil)
out, err = key1.receive.aead.Open(out[:0], nonce[:], out, nil)
out = key2.send.Seal(out, nonce[:], testMsg, nil)
out, err = key1.receive.Open(out[:0], nonce[:], out, nil)
assertNil(t, err)
assertEqual(t, out, testMsg)
}()

View file

@ -4,7 +4,6 @@ import (
"encoding/base64"
"errors"
"fmt"
"net"
"sync"
"time"
)
@ -16,7 +15,7 @@ type Peer struct {
keyPairs KeyPairs
handshake Handshake
device *Device
endpoint *net.UDPAddr
endpoint Endpoint
stats struct {
txBytes uint64 // bytes send to peer (endpoint)
rxBytes uint64 // bytes received from peer
@ -106,6 +105,10 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic)
handshake.mutex.Unlock()
// reset endpoint
peer.endpoint = nil
// prepare queuing
peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
@ -130,11 +133,31 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
return peer, nil
}
func (peer *Peer) SendBuffer(buffer []byte) error {
peer.device.net.mutex.RLock()
defer peer.device.net.mutex.RUnlock()
peer.mutex.RLock()
defer peer.mutex.RUnlock()
if peer.endpoint == nil {
return errors.New("No known endpoint for peer")
}
return peer.device.net.bind.Send(buffer, peer.endpoint)
}
/* Returns a short string identification for logging
*/
func (peer *Peer) String() string {
if peer.endpoint == nil {
return fmt.Sprintf(
"peer(%d unknown %s)",
peer.id,
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
)
}
return fmt.Sprintf(
"peer(%d %s %s)",
peer.id,
peer.endpoint.String(),
peer.endpoint.DstToString(),
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
)
}

View file

@ -13,19 +13,20 @@ import (
)
type QueueHandshakeElement struct {
msgType uint32
packet []byte
buffer *[MaxMessageSize]byte
source *net.UDPAddr
msgType uint32
packet []byte
endpoint Endpoint
buffer *[MaxMessageSize]byte
}
type QueueInboundElement struct {
dropped int32
mutex sync.Mutex
buffer *[MaxMessageSize]byte
packet []byte
counter uint64
keyPair *KeyPair
dropped int32
mutex sync.Mutex
buffer *[MaxMessageSize]byte
packet []byte
counter uint64
keyPair *KeyPair
endpoint Endpoint
}
func (elem *QueueInboundElement) Drop() {
@ -92,130 +93,122 @@ func (device *Device) addToHandshakeQueue(
}
}
func (device *Device) RoutineReceiveIncomming() {
func (device *Device) RoutineReceiveIncomming(IP int, bind Bind) {
logDebug := device.log.Debug
logDebug.Println("Routine, receive incomming, started")
logDebug.Println("Routine, receive incomming, IP version:", IP)
for {
// wait for new conn
// receive datagrams until conn is closed
logDebug.Println("Waiting for udp socket")
buffer := device.GetMessageBuffer()
select {
case <-device.signal.stop:
return
var (
err error
size int
endpoint Endpoint
)
case <-device.signal.newUDPConn:
for {
// fetch connection
// read next datagram
device.net.mutex.RLock()
conn := device.net.conn
device.net.mutex.RUnlock()
if conn == nil {
switch IP {
case ipv4.Version:
size, endpoint, err = bind.ReceiveIPv4(buffer[:])
case ipv6.Version:
size, endpoint, err = bind.ReceiveIPv6(buffer[:])
default:
return
}
if err != nil {
break
}
if size < MinMessageSize {
continue
}
logDebug.Println("Listening for inbound packets")
// check size of packet
// receive datagrams until conn is closed
packet := buffer[:size]
msgType := binary.LittleEndian.Uint32(packet[:4])
buffer := device.GetMessageBuffer()
var okay bool
for {
switch msgType {
// read next datagram
// check if transport
size, raddr, err := conn.ReadFromUDP(buffer[:])
case MessageTransportType:
if err != nil {
break
}
// check size
if size < MinMessageSize {
if len(packet) < MessageTransportType {
continue
}
// check size of packet
// lookup key pair
packet := buffer[:size]
msgType := binary.LittleEndian.Uint32(packet[:4])
var okay bool
switch msgType {
// check if transport
case MessageTransportType:
// check size
if len(packet) < MessageTransportType {
continue
}
// lookup key pair
receiver := binary.LittleEndian.Uint32(
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
)
value := device.indices.Lookup(receiver)
keyPair := value.keyPair
if keyPair == nil {
continue
}
// check key-pair expiry
if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
continue
}
// create work element
peer := value.peer
elem := &QueueInboundElement{
packet: packet,
buffer: buffer,
keyPair: keyPair,
dropped: AtomicFalse,
}
elem.mutex.Lock()
// add to decryption queues
device.addToDecryptionQueue(device.queue.decryption, elem)
device.addToInboundQueue(peer.queue.inbound, elem)
buffer = device.GetMessageBuffer()
receiver := binary.LittleEndian.Uint32(
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
)
value := device.indices.Lookup(receiver)
keyPair := value.keyPair
if keyPair == nil {
continue
// otherwise it is a handshake related packet
case MessageInitiationType:
okay = len(packet) == MessageInitiationSize
case MessageResponseType:
okay = len(packet) == MessageResponseSize
case MessageCookieReplyType:
okay = len(packet) == MessageCookieReplySize
}
if okay {
device.addToHandshakeQueue(
device.queue.handshake,
QueueHandshakeElement{
msgType: msgType,
buffer: buffer,
packet: packet,
source: raddr,
},
)
buffer = device.GetMessageBuffer()
// check key-pair expiry
if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
continue
}
// create work element
peer := value.peer
elem := &QueueInboundElement{
packet: packet,
buffer: buffer,
keyPair: keyPair,
dropped: AtomicFalse,
endpoint: endpoint,
}
elem.mutex.Lock()
// add to decryption queues
device.addToDecryptionQueue(device.queue.decryption, elem)
device.addToInboundQueue(peer.queue.inbound, elem)
buffer = device.GetMessageBuffer()
continue
// otherwise it is a fixed size & handshake related packet
case MessageInitiationType:
okay = len(packet) == MessageInitiationSize
case MessageResponseType:
okay = len(packet) == MessageResponseSize
case MessageCookieReplyType:
okay = len(packet) == MessageCookieReplySize
}
if okay {
device.addToHandshakeQueue(
device.queue.handshake,
QueueHandshakeElement{
msgType: msgType,
buffer: buffer,
packet: packet,
endpoint: endpoint,
},
)
buffer = device.GetMessageBuffer()
}
}
}
@ -293,8 +286,6 @@ func (device *Device) RoutineHandshake() {
// unmarshal packet
logDebug.Println("Process cookie reply from:", elem.source.String())
var reply MessageCookieReply
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &reply)
@ -321,15 +312,25 @@ func (device *Device) RoutineHandshake() {
return
}
// endpoints destination address is the source of the datagram
srcBytes := elem.endpoint.DstToBytes()
if device.IsUnderLoad() {
if !device.mac.CheckMAC2(elem.packet, elem.source) {
// verify MAC2 field
if !device.mac.CheckMAC2(elem.packet, srcBytes) {
// construct cookie reply
logDebug.Println("Sending cookie reply to:", elem.source.String())
logDebug.Println(
"Sending cookie reply to:",
elem.endpoint.DstToString(),
)
sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
reply, err := device.mac.CreateReply(elem.packet, sender, elem.source)
sender := binary.LittleEndian.Uint32(elem.packet[4:8])
reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes)
if err != nil {
logError.Println("Failed to create cookie reply:", err)
return
@ -339,17 +340,16 @@ func (device *Device) RoutineHandshake() {
writer := bytes.NewBuffer(temp[:0])
binary.Write(writer, binary.LittleEndian, reply)
_, err = device.net.conn.WriteToUDP(
writer.Bytes(),
elem.source,
)
device.net.bind.Send(writer.Bytes(), elem.endpoint)
if err != nil {
logDebug.Println("Failed to send cookie reply:", err)
}
continue
}
if !device.ratelimiter.Allow(elem.source.IP) {
// check ratelimiter
if !device.ratelimiter.Allow(elem.endpoint.DstIP()) {
continue
}
}
@ -380,8 +380,7 @@ func (device *Device) RoutineHandshake() {
if peer == nil {
logInfo.Println(
"Recieved invalid initiation message from",
elem.source.IP.String(),
elem.source.Port,
elem.endpoint.DstToString(),
)
continue
}
@ -392,10 +391,9 @@ func (device *Device) RoutineHandshake() {
peer.TimerAnyAuthenticatedPacketReceived()
// update endpoint
// TODO: Discover destination address also, only update on change
peer.mutex.Lock()
peer.endpoint = elem.source
peer.endpoint = elem.endpoint
peer.mutex.Unlock()
// create response
@ -418,9 +416,11 @@ func (device *Device) RoutineHandshake() {
// send response
_, err = peer.SendBuffer(packet)
err = peer.SendBuffer(packet)
if err == nil {
peer.TimerAnyAuthenticatedPacketTraversal()
} else {
logError.Println("Failed to send response to:", peer.String(), err)
}
case MessageResponseType:
@ -441,12 +441,17 @@ func (device *Device) RoutineHandshake() {
if peer == nil {
logInfo.Println(
"Recieved invalid response message from",
elem.source.IP.String(),
elem.source.Port,
elem.endpoint.DstToString(),
)
continue
}
// update endpoint
peer.mutex.Lock()
peer.endpoint = elem.endpoint
peer.mutex.Unlock()
logDebug.Println("Received handshake initation from", peer)
peer.TimerEphemeralKeyCreated()
@ -515,6 +520,12 @@ func (peer *Peer) RoutineSequentialReceiver() {
}
kp.mutex.Unlock()
// update endpoint
peer.mutex.Lock()
peer.endpoint = elem.endpoint
peer.mutex.Unlock()
// check for keep-alive
if len(elem.packet) == 0 {
@ -546,7 +557,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.routingTable.LookupIPv4(src) != peer {
logInfo.Println("Packet with unallowed source IP from", peer.String())
logInfo.Println(
"IPv4 packet with unallowed source address from",
peer.String(),
)
continue
}
@ -571,7 +585,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.routingTable.LookupIPv6(src) != peer {
logInfo.Println("Packet with unallowed source IP from", peer.String())
logInfo.Println(
"IPv6 packet with unallowed source address from",
peer.String(),
)
continue
}
@ -580,7 +597,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
continue
}
// write to tun
// write to tun device
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
_, err := device.tun.device.Write(elem.packet)

View file

@ -2,7 +2,6 @@ package main
import (
"encoding/binary"
"errors"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
@ -105,26 +104,6 @@ func addToEncryptionQueue(
}
}
func (peer *Peer) SendBuffer(buffer []byte) (int, error) {
peer.device.net.mutex.RLock()
defer peer.device.net.mutex.RUnlock()
peer.mutex.RLock()
defer peer.mutex.RUnlock()
endpoint := peer.endpoint
if endpoint == nil {
return 0, errors.New("No known endpoint for peer")
}
conn := peer.device.net.conn
if conn == nil {
return 0, errors.New("No UDP socket for device")
}
return conn.WriteToUDP(buffer, endpoint)
}
/* Reads packets from the TUN and inserts
* into nonce queue for peer
*
@ -343,7 +322,7 @@ func (peer *Peer) RoutineSequentialSender() {
// send message and return buffer to pool
length := uint64(len(elem.packet))
_, err := peer.SendBuffer(elem.packet)
err := peer.SendBuffer(elem.packet)
device.PutMessageBuffer(elem.buffer)
if err != nil {
logDebug.Println("Failed to send authenticated packet to peer", peer.String())

View file

@ -20,6 +20,14 @@
# wireguard peers in $ns1 and $ns2. Note that $ns0 is the endpoint for the wg1
# interfaces in $ns1 and $ns2. See https://www.wireguard.com/netns/ for further
# details on how this is accomplished.
# This code is ported to the WireGuard-Go directly from the kernel project.
#
# Please ensure that you have installed the newest version of the WireGuard
# tools from the WireGuard project and before running these tests as:
#
# ./netns.sh <path to wireguard-go>
set -e
exec 3>&1
@ -27,8 +35,8 @@ export WG_HIDE_KEYS=never
netns0="wg-test-$$-0"
netns1="wg-test-$$-1"
netns2="wg-test-$$-2"
program="../wireguard-go"
export LOG_LEVEL="error"
program=$1
export LOG_LEVEL="info"
pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; }
pp() { pretty "" "$*"; "$@"; }
@ -72,13 +80,11 @@ pp ip netns add $netns2
ip0 link set up dev lo
# ip0 link add dev wg1 type wireguard
n0 $program -f wg1 &
sleep 1
n0 $program wg1
ip0 link set wg1 netns $netns1
# ip0 link add dev wg1 type wireguard
n0 $program -f wg2 &
sleep 1
n0 $program wg2
ip0 link set wg2 netns $netns2
key1="$(pp wg genkey)"
@ -185,14 +191,14 @@ ip0 -4 addr del 127.0.0.1/8 dev lo
ip0 -4 addr add 127.212.121.99/8 dev lo
n0 wg set wg1 listen-port 9999
n0 wg set wg1 peer "$pub2" endpoint 127.0.0.1:20000
n1 ping6 -W 1 -c 1 fd00::20000
[[ $(n2 wg show wg2 endpoints) == "$pub1 127.212.121.99:9999" ]]
n1 ping6 -W 1 -c 1 fd00::2
[[ $(n2 wg show wg2 endpoints) == "$pub1 127.212.121.99:9999" ]]
# Test using IPv6 that roaming works
n1 wg set wg1 listen-port 9998
n1 wg set wg1 peer "$pub2" endpoint [::1]:20000
n1 ping -W 1 -c 1 192.168.241.2
[[ $(n2 wg show wg2 endpoints) == "$pub1 [::1]:9998" ]]
[[ $(n2 wg show wg2 endpoints) == "$pub1 [::1]:9998" ]]
# Test that crypto-RP filter works
n1 wg set wg1 peer "$pub2" allowed-ips 192.168.241.0/24
@ -212,7 +218,7 @@ n2 ncat -u 192.168.241.1 1111 <<<"X"
! read -r -N 1 -t 1 out <&4
kill $nmap_pid
n0 wg set wg1 peer "$more_specific_key" remove
[[ $(n1 wg show wg1 endpoints) == "$pub2 [::1]:9997" ]]
[[ $(n1 wg show wg1 endpoints) == "$pub2 [::1]:9997" ]]
ip1 link del wg1
ip2 link del wg2
@ -263,7 +269,7 @@ n0 iptables -t nat -A POSTROUTING -s 192.168.1.0/24 -d 10.0.0.0/24 -j SNAT --to
n0 wg set wg1 peer "$pub2" endpoint 10.0.0.100:20000 persistent-keepalive 1
n1 ping -W 1 -c 1 192.168.241.2
n2 ping -W 1 -c 1 192.168.241.1
[[ $(n2 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]]
[[ $(n2 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]]
# Demonstrate n2 can still send packets to n1, since persistent-keepalive will prevent connection tracking entry from expiring (to see entries: `n0 conntrack -L`).
pp sleep 3
n2 ping -W 1 -c 1 192.168.241.1
@ -289,7 +295,7 @@ ip2 link del wg2
# ip1 link add dev wg1 type wireguard
# ip2 link add dev wg1 type wireguard
n1 $program wg1
n2 $program wg1
n2 $program wg2
configure_peers
@ -336,17 +342,83 @@ waitiface $netns1 veth1
waitiface $netns2 veth2
n0 wg set wg2 peer "$pub1" endpoint 10.0.0.1:10000
n2 ping -W 1 -c 1 192.168.241.1
[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]]
[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]]
n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::1]:10000
n2 ping -W 1 -c 1 192.168.241.1
[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::1]:10000" ]]
[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::1]:10000" ]]
n0 wg set wg2 peer "$pub1" endpoint 10.0.0.2:10000
n2 ping -W 1 -c 1 192.168.241.1
[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.2:10000" ]]
[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.2:10000" ]]
n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::2]:10000
n2 ping -W 1 -c 1 192.168.241.1
[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::2]:10000" ]]
[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::2]:10000" ]]
ip1 link del veth1
ip1 link del wg1
ip2 link del wg2
# Test that Netlink/IPC is working properly by doing things that usually cause split responses
n0 $program wg0
sleep 5
config=( "[Interface]" "PrivateKey=$(wg genkey)" "[Peer]" "PublicKey=$(wg genkey)" )
for a in {1..255}; do
for b in {0..255}; do
config+=( "AllowedIPs=$a.$b.0.0/16,$a::$b/128" )
done
done
n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
i=0
for ip in $(n0 wg show wg0 allowed-ips); do
((++i))
done
((i == 255*256*2+1))
ip0 link del wg0
n0 $program wg0
config=( "[Interface]" "PrivateKey=$(wg genkey)" )
for a in {1..40}; do
config+=( "[Peer]" "PublicKey=$(wg genkey)" )
for b in {1..52}; do
config+=( "AllowedIPs=$a.$b.0.0/16" )
done
done
n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
i=0
while read -r line; do
j=0
for ip in $line; do
((++j))
done
((j == 53))
((++i))
done < <(n0 wg show wg0 allowed-ips)
((i == 40))
ip0 link del wg0
n0 $program wg0
config=( )
for i in {1..29}; do
config+=( "[Peer]" "PublicKey=$(wg genkey)" )
done
config+=( "[Peer]" "PublicKey=$(wg genkey)" "AllowedIPs=255.2.3.4/32,abcd::255/128" )
n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
n0 wg showconf wg0 > /dev/null
ip0 link del wg0
! n0 wg show doesnotexist || false
declare -A objects
while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do
[[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ [0-9]+)\ .*(created|destroyed).* ]] || continue
objects["${BASH_REMATCH[1]}"]+="${BASH_REMATCH[2]}"
done < /dev/kmsg
alldeleted=1
for object in "${!objects[@]}"; do
if [[ ${objects["$object"]} != *createddestroyed ]]; then
echo "Error: $object: merely ${objects["$object"]}" >&3
alldeleted=0
fi
done
[[ $alldeleted -eq 1 ]]
pretty "" "Objects that were created were also destroyed."

View file

@ -279,34 +279,31 @@ func (peer *Peer) RoutineHandshakeInitiator() {
break AttemptHandshakes
}
jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
// marshal and send
// marshal handshake message
writer := bytes.NewBuffer(temp[:0])
binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes()
peer.mac.AddMacs(packet)
_, err = peer.SendBuffer(packet)
if err != nil {
// send to endpoint
err = peer.SendBuffer(packet)
jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
timeout := time.NewTimer(RekeyTimeout + jitter)
if err == nil {
peer.TimerAnyAuthenticatedPacketTraversal()
logDebug.Println(
"Handshake initiation attempt",
attempts, "sent to", peer.String(),
)
} else {
logError.Println(
"Failed to send handshake initiation message to",
peer.String(), ":", err,
)
continue
}
peer.TimerAnyAuthenticatedPacketTraversal()
// set handshake timeout
timeout := time.NewTimer(RekeyTimeout + jitter)
logDebug.Println(
"Handshake initiation attempt",
attempts, "sent to", peer.String(),
)
// wait for handshake or timeout
select {

View file

@ -1,6 +1,7 @@
package main
import (
"os"
"sync/atomic"
)
@ -15,6 +16,7 @@ const (
)
type TUNDevice interface {
File() *os.File // returns the file descriptor of the device
Read([]byte) (int, error) // read a packet from the device (without any additional headers)
Write([]byte) (int, error) // writes a packet to the device (without any additional headers)
MTU() (int, error) // returns the MTU of the device
@ -47,7 +49,7 @@ func (device *Device) RoutineTUNEventReader() {
if !device.tun.isUp.Get() {
logInfo.Println("Interface set up")
device.tun.isUp.Set(true)
updateUDPConn(device)
UpdateUDPListener(device)
}
}
@ -55,7 +57,7 @@ func (device *Device) RoutineTUNEventReader() {
if device.tun.isUp.Get() {
logInfo.Println("Interface set down")
device.tun.isUp.Set(false)
closeUDPConn(device)
CloseUDPListener(device)
}
}
}

View file

@ -56,6 +56,10 @@ type NativeTun struct {
events chan TUNEvent // device related events
}
func (tun *NativeTun) File() *os.File {
return tun.fd
}
func (tun *NativeTun) RoutineNetlinkListener() {
sock := int(C.bind_rtmgrp())
if sock < 0 {
@ -222,7 +226,7 @@ func (tun *NativeTun) MTU() (int, error) {
val := binary.LittleEndian.Uint32(ifr[16:20])
if val >= (1 << 31) {
return int(val-(1<<31)) - (1 << 31), nil
return int(toInt32(val)), nil
}
return int(val), nil
}
@ -248,6 +252,29 @@ func (tun *NativeTun) Close() error {
return nil
}
func CreateTUNFromFile(name string, fd *os.File) (TUNDevice, error) {
device := &NativeTun{
fd: fd,
name: name,
events: make(chan TUNEvent, 5),
errors: make(chan error, 5),
}
// start event listener
var err error
device.index, err = getIFIndex(device.name)
if err != nil {
return nil, err
}
go device.RoutineNetlinkListener()
// set default MTU
return device, device.setMTU(DefaultMTU)
}
func CreateTUN(name string) (TUNDevice, error) {
// open clone device

View file

@ -39,9 +39,10 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
send("private_key=" + device.privateKey.ToHex())
}
if device.net.addr != nil {
send(fmt.Sprintf("listen_port=%d", device.net.addr.Port))
if device.net.port != 0 {
send(fmt.Sprintf("listen_port=%d", device.net.port))
}
if device.net.fwmark != 0 {
send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
}
@ -53,7 +54,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
send("public_key=" + peer.handshake.remoteStatic.ToHex())
send("preshared_key=" + peer.handshake.presharedKey.ToHex())
if peer.endpoint != nil {
send("endpoint=" + peer.endpoint.String())
send("endpoint=" + peer.endpoint.DstToString())
}
nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
@ -134,56 +135,38 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "listen_port":
port, err := strconv.ParseUint(value, 10, 16)
if err != nil {
logError.Println("Failed to set listen_port:", err)
logError.Println("Failed to parse listen_port:", err)
return &IPCError{Code: ipcErrorInvalid}
}
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port))
if err != nil {
logError.Println("Failed to set listen_port:", err)
return &IPCError{Code: ipcErrorInvalid}
}
device.net.mutex.Lock()
device.net.addr = addr
device.net.mutex.Unlock()
err = updateUDPConn(device)
if err != nil {
device.net.port = uint16(port)
if err := UpdateUDPListener(device); err != nil {
logError.Println("Failed to set listen_port:", err)
return &IPCError{Code: ipcErrorPortInUse}
}
// TODO: Clear source address of all peers
case "fwmark":
fwmark, err := strconv.ParseUint(value, 10, 32)
// parse fwmark field
fwmark, err := func() (uint32, error) {
if value == "" {
return 0, nil
}
mark, err := strconv.ParseUint(value, 10, 32)
return uint32(mark), err
}()
if err != nil {
logError.Println("Invalid fwmark", err)
return &IPCError{Code: ipcErrorInvalid}
}
device.net.mutex.Lock()
if fwmark > 0 || device.net.fwmark > 0 {
device.net.fwmark = uint32(fwmark)
err := setMark(
device.net.conn,
device.net.fwmark,
)
if err != nil {
logError.Println("Failed to set fwmark:", err)
device.net.mutex.Unlock()
return &IPCError{Code: ipcErrorIO}
}
// TODO: Clear source address of all peers
}
device.net.fwmark = uint32(fwmark)
device.net.mutex.Unlock()
case "public_key":
// switch to peer configuration
deviceConfig = false
case "replace_peers":
@ -218,7 +201,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
device.mutex.RLock()
if device.publicKey.Equals(pubKey) {
// create dummy instance
// create dummy instance (not added to device)
peer = &Peer{}
dummy = true
@ -244,6 +227,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
}
case "remove":
// remove currently selected peer from device
if value != "true" {
logError.Println("Failed to set remove, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid}
@ -256,6 +242,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
dummy = true
case "preshared_key":
// update PSK
peer.mutex.Lock()
err := peer.handshake.presharedKey.FromHex(value)
peer.mutex.Unlock()
@ -265,15 +254,25 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
}
case "endpoint":
addr, err := parseEndpoint(value)
// set endpoint destination
err := func() error {
peer.mutex.Lock()
defer peer.mutex.Unlock()
endpoint, err := CreateEndpoint(value)
if err != nil {
return err
}
peer.endpoint = endpoint
signalSend(peer.signal.handshakeReset)
return nil
}()
if err != nil {
logError.Println("Failed to set endpoint:", value)
return &IPCError{Code: ipcErrorInvalid}
}
peer.mutex.Lock()
peer.endpoint = addr
peer.mutex.Unlock()
signalSend(peer.signal.handshakeReset)
case "persistent_keepalive_interval":

View file

@ -10,12 +10,12 @@ import (
)
const (
ipcErrorIO = -int64(unix.EIO)
ipcErrorProtocol = -int64(unix.EPROTO)
ipcErrorInvalid = -int64(unix.EINVAL)
ipcErrorPortInUse = -int64(unix.EADDRINUSE)
socketDirectory = "/var/run/wireguard"
socketName = "%s.sock"
ipcErrorIO = -int64(unix.EIO)
ipcErrorProtocol = -int64(unix.EPROTO)
ipcErrorInvalid = -int64(unix.EINVAL)
ipcErrorPortInUse = -int64(unix.EADDRINUSE)
socketDirectory = "/var/run/wireguard"
socketName = "%s.sock"
)
type UAPIListener struct {
@ -50,49 +50,11 @@ func (l *UAPIListener) Addr() net.Addr {
return nil
}
func connectUnixSocket(path string) (net.Listener, error) {
func UAPIListen(name string, file *os.File) (net.Listener, error) {
// attempt inital connection
// wrap file in listener
listener, err := net.Listen("unix", path)
if err == nil {
return listener, nil
}
// check if active
_, err = net.Dial("unix", path)
if err == nil {
return nil, errors.New("Unix socket in use")
}
// attempt cleanup
err = os.Remove(path)
if err != nil {
return nil, err
}
return net.Listen("unix", path)
}
func NewUAPIListener(name string) (net.Listener, error) {
// check if path exist
err := os.MkdirAll(socketDirectory, 077)
if err != nil && !os.IsExist(err) {
return nil, err
}
// open UNIX socket
socketPath := path.Join(
socketDirectory,
fmt.Sprintf(socketName, name),
)
listener, err := connectUnixSocket(socketPath)
listener, err := net.FileListener(file)
if err != nil {
return nil, err
}
@ -105,6 +67,11 @@ func NewUAPIListener(name string) (net.Listener, error) {
// watch for deletion of socket
socketPath := path.Join(
socketDirectory,
fmt.Sprintf(socketName, name),
)
uapi.inotifyFd, err = unix.InotifyInit()
if err != nil {
return nil, err
@ -125,11 +92,12 @@ func NewUAPIListener(name string) (net.Listener, error) {
go func(l *UAPIListener) {
var buff [4096]byte
for {
unix.Read(uapi.inotifyFd, buff[:])
// start with lstat to avoid race condition
if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
l.connErr <- err
return
}
unix.Read(uapi.inotifyFd, buff[:])
}
}(uapi)
@ -148,3 +116,56 @@ func NewUAPIListener(name string) (net.Listener, error) {
return uapi, nil
}
func UAPIOpen(name string) (*os.File, error) {
// check if path exist
err := os.MkdirAll(socketDirectory, 0600)
if err != nil && !os.IsExist(err) {
return nil, err
}
// open UNIX socket
socketPath := path.Join(
socketDirectory,
fmt.Sprintf(socketName, name),
)
addr, err := net.ResolveUnixAddr("unix", socketPath)
if err != nil {
return nil, err
}
listener, err := func() (*net.UnixListener, error) {
// initial connection attempt
listener, err := net.ListenUnix("unix", addr)
if err == nil {
return listener, nil
}
// check if socket already active
_, err = net.Dial("unix", socketPath)
if err == nil {
return nil, errors.New("unix socket in use")
}
// cleanup & attempt again
err = os.Remove(socketPath)
if err != nil {
return nil, err
}
return net.ListenUnix("unix", addr)
}()
if err != nil {
return nil, err
}
return listener.File()
}