Merge branch 'source-caching'

This commit is contained in:
Mathias Hall-Andersen 2017-11-19 13:19:07 +01:00
commit b5ae42349c
20 changed files with 1200 additions and 510 deletions

View file

@ -2,10 +2,35 @@ package main
import ( import (
"errors" "errors"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"net" "net"
"time"
) )
/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
*/
type Bind interface {
SetMark(value uint32) error
ReceiveIPv6(buff []byte) (int, Endpoint, error)
ReceiveIPv4(buff []byte) (int, Endpoint, error)
Send(buff []byte, end Endpoint) error
Close() error
}
/* An Endpoint maintains the source/destination caching for a peer
*
* dst : the remote address of a peer ("endpoint" in uapi terminology)
* src : the local address from which datagrams originate going to the peer
*/
type Endpoint interface {
ClearSrc() // clears the source address
SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations
DstIP() net.IP
SrcIP() net.IP
}
func parseEndpoint(s string) (*net.UDPAddr, error) { func parseEndpoint(s string) (*net.UDPAddr, error) {
// ensure that the host is an IP address // ensure that the host is an IP address
@ -27,63 +52,83 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
return addr, err return addr, err
} }
func updateUDPConn(device *Device) error { /* Must hold device and net lock
*/
func unsafeCloseUDPListener(device *Device) error {
var err error
netc := &device.net
if netc.bind != nil {
err = netc.bind.Close()
netc.bind = nil
}
return err
}
// must inform all listeners
func UpdateUDPListener(device *Device) error {
device.mutex.Lock()
defer device.mutex.Unlock()
netc := &device.net netc := &device.net
netc.mutex.Lock() netc.mutex.Lock()
defer netc.mutex.Unlock() defer netc.mutex.Unlock()
// close existing connection // close existing sockets
if netc.conn != nil { if err := unsafeCloseUDPListener(device); err != nil {
netc.conn.Close() return err
netc.conn = nil
// We need for that fd to be closed in all other go routines, which
// means we have to wait. TODO: find less horrible way of doing this.
time.Sleep(time.Second / 2)
} }
// open new connection // assumption: netc.update WaitGroup should be exactly 1
// open new sockets
if device.tun.isUp.Get() { if device.tun.isUp.Get() {
// listen on new address device.log.Debug.Println("UDP bind updating")
conn, err := net.ListenUDP("udp", netc.addr) // bind to new port
var err error
netc.bind, netc.port, err = CreateBind(netc.port)
if err != nil {
netc.bind = nil
return err
}
// set mark
err = netc.bind.SetMark(netc.fwmark)
if err != nil { if err != nil {
return err return err
} }
// set fwmark // clear cached source addresses
err = setMark(netc.conn, netc.fwmark) for _, peer := range device.peers {
if err != nil { peer.mutex.Lock()
return err if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.mutex.Unlock()
} }
// retrieve port (may have been chosen by kernel) // decrease waitgroup to 0
addr := conn.LocalAddr() go device.RoutineReceiveIncomming(ipv4.Version, netc.bind)
netc.conn = conn go device.RoutineReceiveIncomming(ipv6.Version, netc.bind)
netc.addr, _ = net.ResolveUDPAddr(
addr.Network(),
addr.String(),
)
// notify goroutines device.log.Debug.Println("UDP bind has been updated")
signalSend(device.signal.newUDPConn)
} }
return nil return nil
} }
func closeUDPConn(device *Device) { func CloseUDPListener(device *Device) error {
netc := &device.net device.mutex.Lock()
netc.mutex.Lock() device.net.mutex.Lock()
if netc.conn != nil { err := unsafeCloseUDPListener(device)
netc.conn.Close() device.net.mutex.Unlock()
} device.mutex.Unlock()
netc.mutex.Unlock() return err
signalSend(device.signal.newUDPConn)
} }

View file

@ -6,6 +6,126 @@ import (
"net" "net"
) )
func setMark(conn *net.UDPConn, value uint32) error { /* This code is meant to be a temporary solution
* on platforms for which the sticky socket / source caching behavior
* has not yet been implemented.
*
* See conn_linux.go for an implementation on the linux platform.
*/
type NativeBind struct {
ipv4 *net.UDPConn
ipv6 *net.UDPConn
}
type NativeEndpoint net.UDPAddr
var _ Bind = (*NativeBind)(nil)
var _ Endpoint = (*NativeEndpoint)(nil)
func CreateEndpoint(s string) (Endpoint, error) {
addr, err := parseEndpoint(s)
return (*NativeEndpoint)(addr), err
}
func (_ *NativeEndpoint) ClearSrc() {}
func (e *NativeEndpoint) DstIP() net.IP {
return (*net.UDPAddr)(e).IP
}
func (e *NativeEndpoint) SrcIP() net.IP {
return nil // not supported
}
func (e *NativeEndpoint) DstToBytes() []byte {
addr := (*net.UDPAddr)(e)
out := addr.IP
out = append(out, byte(addr.Port&0xff))
out = append(out, byte((addr.Port>>8)&0xff))
return out
}
func (e *NativeEndpoint) DstToString() string {
return (*net.UDPAddr)(e).String()
}
func (e *NativeEndpoint) SrcToString() string {
return ""
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
// listen
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
if err != nil {
return nil, 0, err
}
// retrieve port
laddr := conn.LocalAddr()
uaddr, err := net.ResolveUDPAddr(
laddr.Network(),
laddr.String(),
)
if err != nil {
return nil, 0, err
}
return conn, uaddr.Port, nil
}
func CreateBind(uport uint16) (Bind, uint16, error) {
var err error
var bind NativeBind
port := int(uport)
bind.ipv4, port, err = listenNet("udp4", port)
if err != nil {
return nil, 0, err
}
bind.ipv6, port, err = listenNet("udp6", port)
if err != nil {
bind.ipv4.Close()
return nil, 0, err
}
return &bind, uint16(port), nil
}
func (bind *NativeBind) Close() error {
err1 := bind.ipv4.Close()
err2 := bind.ipv6.Close()
if err1 != nil {
return err1
}
return err2
}
func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
return n, (*NativeEndpoint)(endpoint), err
}
func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
return n, (*NativeEndpoint)(endpoint), err
}
func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error {
var err error
nend := endpoint.(*NativeEndpoint)
if nend.IP.To16() != nil {
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
} else {
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
}
return err
}
func (bind *NativeBind) SetMark(_ uint32) error {
return nil return nil
} }

View file

@ -7,6 +7,7 @@
package main package main
import ( import (
"encoding/binary"
"errors" "errors"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"net" "net"
@ -15,20 +16,230 @@ import (
) )
/* Supports source address caching /* Supports source address caching
*
* It is important that the endpoint is only updated after the packet content has been authenticated.
* *
* Currently there is no way to achieve this within the net package: * Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930 * See e.g. https://github.com/golang/go/issues/17930
* So this code is remains platform dependent.
*/ */
type Endpoint struct { type NativeEndpoint struct {
// source (selected based on dst type) src unix.RawSockaddrInet6
// (could use RawSockaddrAny and unsafe) dst unix.RawSockaddrInet6
srcIPv6 unix.RawSockaddrInet6 }
srcIPv4 unix.RawSockaddrInet4
srcIf4 int32
dst unix.RawSockaddrAny type NativeBind struct {
sock4 int
sock6 int
}
var _ Endpoint = (*NativeEndpoint)(nil)
var _ Bind = NativeBind{}
type IPv4Source struct {
src unix.RawSockaddrInet4
Ifindex int32
}
func htons(val uint16) uint16 {
var out [unsafe.Sizeof(val)]byte
binary.BigEndian.PutUint16(out[:], val)
return *((*uint16)(unsafe.Pointer(&out[0])))
}
func ntohs(val uint16) uint16 {
tmp := ((*[unsafe.Sizeof(val)]byte)(unsafe.Pointer(&val)))
return binary.BigEndian.Uint16((*tmp)[:])
}
func CreateEndpoint(s string) (Endpoint, error) {
var end NativeEndpoint
addr, err := parseEndpoint(s)
if err != nil {
return nil, err
}
ipv4 := addr.IP.To4()
if ipv4 != nil {
dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
dst.Family = unix.AF_INET
dst.Port = htons(uint16(addr.Port))
dst.Zero = [8]byte{}
copy(dst.Addr[:], ipv4)
end.ClearSrc()
return &end, nil
}
ipv6 := addr.IP.To16()
if ipv6 != nil {
zone, err := zoneToUint32(addr.Zone)
if err != nil {
return nil, err
}
dst := &end.dst
dst.Family = unix.AF_INET6
dst.Port = htons(uint16(addr.Port))
dst.Flowinfo = 0
dst.Scope_id = zone
copy(dst.Addr[:], ipv6[:])
end.ClearSrc()
return &end, nil
}
return nil, errors.New("Failed to recognize IP address format")
}
func CreateBind(port uint16) (Bind, uint16, error) {
var err error
var bind NativeBind
bind.sock6, port, err = create6(port)
if err != nil {
return nil, port, err
}
bind.sock4, port, err = create4(port)
if err != nil {
unix.Close(bind.sock6)
}
return bind, port, err
}
func (bind NativeBind) SetMark(value uint32) error {
err := unix.SetsockoptInt(
bind.sock6,
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),
)
if err != nil {
return err
}
return unix.SetsockoptInt(
bind.sock4,
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),
)
}
func closeUnblock(fd int) error {
// shutdown to unblock readers
unix.Shutdown(fd, unix.SHUT_RD)
return unix.Close(fd)
}
func (bind NativeBind) Close() error {
err1 := closeUnblock(bind.sock6)
err2 := closeUnblock(bind.sock4)
if err1 != nil {
return err1
}
return err2
}
func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
var end NativeEndpoint
n, err := receive6(
bind.sock6,
buff,
&end,
)
return n, &end, err
}
func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
var end NativeEndpoint
n, err := receive4(
bind.sock4,
buff,
&end,
)
return n, &end, err
}
func (bind NativeBind) Send(buff []byte, end Endpoint) error {
nend := end.(*NativeEndpoint)
switch nend.dst.Family {
case unix.AF_INET6:
return send6(bind.sock6, nend, buff)
case unix.AF_INET:
return send4(bind.sock4, nend, buff)
default:
return errors.New("Unknown address family of destination")
}
}
func sockaddrToString(addr unix.RawSockaddrInet6) string {
var udpAddr net.UDPAddr
switch addr.Family {
case unix.AF_INET6:
udpAddr.Port = int(ntohs(addr.Port))
udpAddr.IP = addr.Addr[:]
return udpAddr.String()
case unix.AF_INET:
ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
udpAddr.Port = int(ntohs(ptr.Port))
udpAddr.IP = net.IPv4(
ptr.Addr[0],
ptr.Addr[1],
ptr.Addr[2],
ptr.Addr[3],
)
return udpAddr.String()
default:
return "<unknown address family>"
}
}
func rawAddrToIP(addr unix.RawSockaddrInet6) net.IP {
switch addr.Family {
case unix.AF_INET6:
return addr.Addr[:]
case unix.AF_INET:
ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
return net.IPv4(
ptr.Addr[0],
ptr.Addr[1],
ptr.Addr[2],
ptr.Addr[3],
)
default:
return nil
}
}
func (end *NativeEndpoint) SrcIP() net.IP {
return rawAddrToIP(end.src)
}
func (end *NativeEndpoint) DstIP() net.IP {
return rawAddrToIP(end.dst)
}
func (end *NativeEndpoint) DstToBytes() []byte {
ptr := unsafe.Pointer(&end.src)
arr := (*[unix.SizeofSockaddrInet6]byte)(ptr)
return arr[:]
}
func (end *NativeEndpoint) SrcToString() string {
return sockaddrToString(end.src)
}
func (end *NativeEndpoint) DstToString() string {
return sockaddrToString(end.dst)
}
func (end *NativeEndpoint) ClearDst() {
end.dst = unix.RawSockaddrInet6{}
}
func (end *NativeEndpoint) ClearSrc() {
end.src = unix.RawSockaddrInet6{}
} }
func zoneToUint32(zone string) (uint32, error) { func zoneToUint32(zone string) (uint32, error) {
@ -42,51 +253,116 @@ func zoneToUint32(zone string) (uint32, error) {
return uint32(n), err return uint32(n), err
} }
func (end *Endpoint) ClearSrc() { func create4(port uint16) (int, uint16, error) {
end.srcIf4 = 0
end.srcIPv4 = unix.RawSockaddrInet4{} // create socket
end.srcIPv6 = unix.RawSockaddrInet6{}
} fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
func (end *Endpoint) Set(s string) error {
addr, err := parseEndpoint(s)
if err != nil { if err != nil {
return err return -1, 0, err
} }
ipv6 := addr.IP.To16() addr := unix.SockaddrInet4{
if ipv6 != nil { Port: int(port),
zone, err := zoneToUint32(addr.Zone) }
if err != nil {
// set sockopts and bind
if err := func() error {
if err := unix.SetsockoptInt(
fd,
unix.SOL_SOCKET,
unix.SO_REUSEADDR,
1,
); err != nil {
return err return err
} }
ptr := (*unix.RawSockaddrInet6)(unsafe.Pointer(&end.dst))
ptr.Family = unix.AF_INET6 if err := unix.SetsockoptInt(
ptr.Port = uint16(addr.Port) fd,
ptr.Flowinfo = 0 unix.IPPROTO_IP,
ptr.Scope_id = zone unix.IP_PKTINFO,
copy(ptr.Addr[:], ipv6[:]) 1,
end.ClearSrc() ); err != nil {
return nil return err
}
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
} }
ipv4 := addr.IP.To4() return fd, uint16(addr.Port), err
if ipv4 != nil {
ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
ptr.Family = unix.AF_INET
ptr.Port = uint16(addr.Port)
ptr.Zero = [8]byte{}
copy(ptr.Addr[:], ipv4)
end.ClearSrc()
return nil
}
return errors.New("Failed to recognize IP address format")
} }
func send6(sock uintptr, end *Endpoint, buff []byte) error { func create6(port uint16) (int, uint16, error) {
var iovec unix.Iovec
// create socket
fd, err := unix.Socket(
unix.AF_INET6,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return -1, 0, err
}
// set sockopts and bind
addr := unix.SockaddrInet6{
Port: int(port),
}
if err := func() error {
if err := unix.SetsockoptInt(
fd,
unix.SOL_SOCKET,
unix.SO_REUSEADDR,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IPV6,
unix.IPV6_RECVPKTINFO,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IPV6,
unix.IPV6_V6ONLY,
1,
); err != nil {
return err
}
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
}
return fd, uint16(addr.Port), err
}
func send6(sock int, end *NativeEndpoint, buff []byte) error {
// construct message header
var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff)) iovec.SetLen(len(buff))
@ -97,11 +373,11 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
unix.Cmsghdr{ unix.Cmsghdr{
Level: unix.IPPROTO_IPV6, Level: unix.IPPROTO_IPV6,
Type: unix.IPV6_PKTINFO, Type: unix.IPV6_PKTINFO,
Len: unix.SizeofInet6Pktinfo, Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
}, },
unix.Inet6Pktinfo{ unix.Inet6Pktinfo{
Addr: end.srcIPv6.Addr, Addr: end.src.Addr,
Ifindex: end.srcIPv6.Scope_id, Ifindex: end.src.Scope_id,
}, },
} }
@ -119,22 +395,41 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
_, _, errno := unix.Syscall( _, _, errno := unix.Syscall(
unix.SYS_SENDMSG, unix.SYS_SENDMSG,
sock, uintptr(sock),
uintptr(unsafe.Pointer(&msghdr)), uintptr(unsafe.Pointer(&msghdr)),
0, 0,
) )
if errno == 0 {
return nil
}
// clear src and retry
if errno == unix.EINVAL { if errno == unix.EINVAL {
end.ClearSrc() end.ClearSrc()
cmsg.pktinfo = unix.Inet6Pktinfo{}
_, _, errno = unix.Syscall(
unix.SYS_SENDMSG,
uintptr(sock),
uintptr(unsafe.Pointer(&msghdr)),
0,
)
} }
return errno return errno
} }
func send4(sock uintptr, end *Endpoint, buff []byte) error { func send4(sock int, end *NativeEndpoint, buff []byte) error {
var iovec unix.Iovec
// construct message header
var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff)) iovec.SetLen(len(buff))
src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
cmsg := struct { cmsg := struct {
cmsghdr unix.Cmsghdr cmsghdr unix.Cmsghdr
pktinfo unix.Inet4Pktinfo pktinfo unix.Inet4Pktinfo
@ -142,11 +437,11 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
unix.Cmsghdr{ unix.Cmsghdr{
Level: unix.IPPROTO_IP, Level: unix.IPPROTO_IP,
Type: unix.IP_PKTINFO, Type: unix.IP_PKTINFO,
Len: unix.SizeofInet6Pktinfo, Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
}, },
unix.Inet4Pktinfo{ unix.Inet4Pktinfo{
Spec_dst: end.srcIPv4.Addr, Spec_dst: src4.src.Addr,
Ifindex: end.srcIf4, Ifindex: src4.Ifindex,
}, },
} }
@ -156,51 +451,44 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
Name: (*byte)(unsafe.Pointer(&end.dst)), Name: (*byte)(unsafe.Pointer(&end.dst)),
Namelen: unix.SizeofSockaddrInet4, Namelen: unix.SizeofSockaddrInet4,
Control: (*byte)(unsafe.Pointer(&cmsg)), Control: (*byte)(unsafe.Pointer(&cmsg)),
Flags: 0,
} }
msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
// sendmsg(sock, &msghdr, 0) // sendmsg(sock, &msghdr, 0)
_, _, errno := unix.Syscall( _, _, errno := unix.Syscall(
unix.SYS_SENDMSG, unix.SYS_SENDMSG,
sock, uintptr(sock),
uintptr(unsafe.Pointer(&msghdr)), uintptr(unsafe.Pointer(&msghdr)),
0, 0,
) )
// clear source and try again
if errno == unix.EINVAL { if errno == unix.EINVAL {
end.ClearSrc() end.ClearSrc()
cmsg.pktinfo = unix.Inet4Pktinfo{}
_, _, errno = unix.Syscall(
unix.SYS_SENDMSG,
uintptr(sock),
uintptr(unsafe.Pointer(&msghdr)),
0,
)
} }
// errno = 0 is still an error instance
if errno == 0 {
return nil
}
return errno return errno
} }
func send(c *net.UDPConn, end *Endpoint, buff []byte) error { func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// extract underlying file descriptor // contruct message header
file, err := c.File()
if err != nil {
return err
}
sock := file.Fd()
// send depending on address family of dst
family := *((*uint16)(unsafe.Pointer(&end.dst)))
if family == unix.AF_INET {
return send4(sock, end, buff)
} else if family == unix.AF_INET6 {
return send6(sock, end, buff)
}
return errors.New("Unknown address family of source")
}
func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAddr, *net.UDPAddr) {
file, err := c.File()
if err != nil {
return err, nil, nil
}
var iovec unix.Iovec var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
@ -208,60 +496,87 @@ func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAdd
var cmsg struct { var cmsg struct {
cmsghdr unix.Cmsghdr cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo // big enough pktinfo unix.Inet4Pktinfo
}
var msghdr unix.Msghdr
msghdr.Iov = &iovec
msghdr.Iovlen = 1
msghdr.Name = (*byte)(unsafe.Pointer(&end.dst))
msghdr.Namelen = unix.SizeofSockaddrInet4
msghdr.Control = (*byte)(unsafe.Pointer(&cmsg))
msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
// recvmsg(sock, &mskhdr, 0)
size, _, errno := unix.Syscall(
unix.SYS_RECVMSG,
uintptr(sock),
uintptr(unsafe.Pointer(&msghdr)),
0,
)
if errno != 0 {
return 0, errno
}
// update source cache
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
src4.src.Family = unix.AF_INET
src4.src.Addr = cmsg.pktinfo.Spec_dst
src4.Ifindex = cmsg.pktinfo.Ifindex
}
return int(size), nil
}
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header
var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff))
var cmsg struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo
} }
var msg unix.Msghdr var msg unix.Msghdr
msg.Iov = &iovec msg.Iov = &iovec
msg.Iovlen = 1 msg.Iovlen = 1
msg.Name = (*byte)(unsafe.Pointer(&end.dst)) msg.Name = (*byte)(unsafe.Pointer(&end.dst))
msg.Namelen = uint32(unix.SizeofSockaddrAny) msg.Namelen = uint32(unix.SizeofSockaddrInet6)
msg.Control = (*byte)(unsafe.Pointer(&cmsg)) msg.Control = (*byte)(unsafe.Pointer(&cmsg))
msg.SetControllen(int(unsafe.Sizeof(cmsg))) msg.SetControllen(int(unsafe.Sizeof(cmsg)))
_, _, errno := unix.Syscall( // recvmsg(sock, &mskhdr, 0)
size, _, errno := unix.Syscall(
unix.SYS_RECVMSG, unix.SYS_RECVMSG,
file.Fd(), uintptr(sock),
uintptr(unsafe.Pointer(&msg)), uintptr(unsafe.Pointer(&msg)),
0, 0,
) )
if errno != 0 { if errno != 0 {
return errno, nil, nil return 0, errno
} }
// update source cache
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
end.src.Family = unix.AF_INET6
end.src.Addr = cmsg.pktinfo.Addr
end.src.Scope_id = cmsg.pktinfo.Ifindex
} }
if cmsg.cmsghdr.Level == unix.IPPROTO_IP && return int(size), nil
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&cmsg.pktinfo))
println(info)
}
return nil, nil, nil
}
func setMark(conn *net.UDPConn, value uint32) error {
if conn == nil {
return nil
}
file, err := conn.File()
if err != nil {
return err
}
return unix.SetsockoptInt(
int(file.Fd()),
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),
)
} }

View file

@ -5,10 +5,8 @@ import (
"crypto/rand" "crypto/rand"
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"net"
"sync" "sync"
"time" "time"
"unsafe"
) )
type CookieChecker struct { type CookieChecker struct {
@ -76,7 +74,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool {
return hmac.Equal(mac1[:], msg[smac1:smac2]) return hmac.Equal(mac1[:], msg[smac1:smac2])
} }
func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool { func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
st.mutex.RLock() st.mutex.RLock()
defer st.mutex.RUnlock() defer st.mutex.RUnlock()
@ -89,8 +87,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
var cookie [blake2s.Size128]byte var cookie [blake2s.Size128]byte
func() { func() {
mac, _ := blake2s.New128(st.mac2.secret[:]) mac, _ := blake2s.New128(st.mac2.secret[:])
mac.Write(src.IP) mac.Write(src)
mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:])
mac.Sum(cookie[:0]) mac.Sum(cookie[:0])
}() }()
@ -111,7 +108,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
func (st *CookieChecker) CreateReply( func (st *CookieChecker) CreateReply(
msg []byte, msg []byte,
recv uint32, recv uint32,
src *net.UDPAddr, src []byte,
) (*MessageCookieReply, error) { ) (*MessageCookieReply, error) {
st.mutex.RLock() st.mutex.RLock()
@ -136,8 +133,7 @@ func (st *CookieChecker) CreateReply(
var cookie [blake2s.Size128]byte var cookie [blake2s.Size128]byte
func() { func() {
mac, _ := blake2s.New128(st.mac2.secret[:]) mac, _ := blake2s.New128(st.mac2.secret[:])
mac.Write(src.IP) mac.Write(src)
mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:])
mac.Sum(cookie[:0]) mac.Sum(cookie[:0])
}() }()

View file

@ -1,7 +1,6 @@
package main package main
import ( import (
"net"
"testing" "testing"
) )
@ -25,7 +24,7 @@ func TestCookieMAC1(t *testing.T) {
// check mac1 // check mac1
src, _ := net.ResolveUDPAddr("udp", "192.168.13.37:4000") src := []byte{192, 168, 13, 37, 10, 10, 10}
checkMAC1 := func(msg []byte) { checkMAC1 := func(msg []byte) {
generator.AddMacs(msg) generator.AddMacs(msg)
@ -128,12 +127,12 @@ func TestCookieMAC1(t *testing.T) {
msg[5] ^= 0x20 msg[5] ^= 0x20
srcBad1, _ := net.ResolveUDPAddr("udp", "192.168.13.37:4001") srcBad1 := []byte{192, 168, 13, 37, 40, 01}
if checker.CheckMAC2(msg, srcBad1) { if checker.CheckMAC2(msg, srcBad1) {
t.Fatal("MAC2 generation/verification failed") t.Fatal("MAC2 generation/verification failed")
} }
srcBad2, _ := net.ResolveUDPAddr("udp", "192.168.13.38:4000") srcBad2 := []byte{192, 168, 13, 38, 40, 01}
if checker.CheckMAC2(msg, srcBad2) { if checker.CheckMAC2(msg, srcBad2) {
t.Fatal("MAC2 generation/verification failed") t.Fatal("MAC2 generation/verification failed")
} }

View file

@ -2,29 +2,25 @@ package main
import ( import (
"os" "os"
"os/exec"
) )
/* Daemonizes the process on linux /* Daemonizes the process on linux
* *
* This is done by spawning and releasing a copy with the --foreground flag * This is done by spawning and releasing a copy with the --foreground flag
*
* TODO: Use env variable to spawn in background
*/ */
func Daemonize(attr *os.ProcAttr) error {
// I would like to use os.Executable,
// however this means dropping support for Go <1.8
path, err := exec.LookPath(os.Args[0])
if err != nil {
return err
}
func Daemonize() error {
argv := []string{os.Args[0], "--foreground"} argv := []string{os.Args[0], "--foreground"}
argv = append(argv, os.Args[1:]...) argv = append(argv, os.Args[1:]...)
attr := &os.ProcAttr{
Dir: ".",
Env: os.Environ(),
Files: []*os.File{
os.Stdin,
nil,
nil,
},
}
process, err := os.StartProcess( process, err := os.StartProcess(
argv[0], path,
argv, argv,
attr, attr,
) )

View file

@ -1,7 +1,6 @@
package main package main
import ( import (
"net"
"runtime" "runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -9,8 +8,9 @@ import (
) )
type Device struct { type Device struct {
log *Logger // collection of loggers for levels closed AtomicBool // device is closed? (acting as guard)
idCounter uint // for assigning debug ids to peers log *Logger // collection of loggers for levels
idCounter uint // for assigning debug ids to peers
fwMark uint32 fwMark uint32
tun struct { tun struct {
device TUNDevice device TUNDevice
@ -22,9 +22,9 @@ type Device struct {
} }
net struct { net struct {
mutex sync.RWMutex mutex sync.RWMutex
addr *net.UDPAddr // UDP source address bind Bind // bind interface
conn *net.UDPConn // UDP "connection" port uint16 // listening port
fwmark uint32 fwmark uint32 // mark value (0 = disabled)
} }
mutex sync.RWMutex mutex sync.RWMutex
privateKey NoisePrivateKey privateKey NoisePrivateKey
@ -37,8 +37,7 @@ type Device struct {
handshake chan QueueHandshakeElement handshake chan QueueHandshakeElement
} }
signal struct { signal struct {
stop chan struct{} // halts all go routines stop chan struct{}
newUDPConn chan struct{} // a net.conn was set (consumed by the receiver routine)
} }
underLoadUntil atomic.Value underLoadUntil atomic.Value
ratelimiter Ratelimiter ratelimiter Ratelimiter
@ -128,21 +127,23 @@ func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
device.pool.messageBuffers.Put(msg) device.pool.messageBuffers.Put(msg)
} }
func NewDevice(tun TUNDevice, logLevel int) *Device { func NewDevice(tun TUNDevice, logger *Logger) *Device {
device := new(Device) device := new(Device)
device.mutex.Lock() device.mutex.Lock()
defer device.mutex.Unlock() defer device.mutex.Unlock()
device.log = NewLogger(logLevel, "("+tun.Name()+") ") device.log = logger
device.peers = make(map[NoisePublicKey]*Peer) device.peers = make(map[NoisePublicKey]*Peer)
device.tun.device = tun device.tun.device = tun
device.indices.Init() device.indices.Init()
device.ratelimiter.Init() device.ratelimiter.Init()
device.routingTable.Reset() device.routingTable.Reset()
device.underLoadUntil.Store(time.Time{}) device.underLoadUntil.Store(time.Time{})
// setup pools // setup buffer pool
device.pool.messageBuffers = sync.Pool{ device.pool.messageBuffers = sync.Pool{
New: func() interface{} { New: func() interface{} {
@ -159,7 +160,11 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
// prepare signals // prepare signals
device.signal.stop = make(chan struct{}) device.signal.stop = make(chan struct{})
device.signal.newUDPConn = make(chan struct{}, 1)
// prepare net
device.net.port = 0
device.net.bind = nil
// start workers // start workers
@ -168,12 +173,9 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
go device.RoutineDecryption() go device.RoutineDecryption()
go device.RoutineHandshake() go device.RoutineHandshake()
} }
go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader() go device.RoutineTUNEventReader()
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop) go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
go device.RoutineReadFromTUN()
go device.RoutineReceiveIncomming()
return device return device
} }
@ -202,9 +204,13 @@ func (device *Device) RemoveAllPeers() {
} }
func (device *Device) Close() { func (device *Device) Close() {
if device.closed.Swap(true) {
return
}
device.log.Info.Println("Closing device")
device.RemoveAllPeers() device.RemoveAllPeers()
close(device.signal.stop) close(device.signal.stop)
closeUDPConn(device) CloseUDPListener(device)
device.tun.device.Close() device.tun.device.Close()
} }

View file

@ -2,6 +2,7 @@ package main
import ( import (
"bytes" "bytes"
"os"
"testing" "testing"
) )
@ -15,6 +16,10 @@ type DummyTUN struct {
events chan TUNEvent events chan TUNEvent
} }
func (tun *DummyTUN) File() *os.File {
return nil
}
func (tun *DummyTUN) Name() string { func (tun *DummyTUN) Name() string {
return tun.name return tun.name
} }
@ -67,7 +72,8 @@ func randDevice(t *testing.T) *Device {
t.Fatal(err) t.Fatal(err)
} }
tun, _ := CreateDummyTUN("dummy") tun, _ := CreateDummyTUN("dummy")
device := NewDevice(tun, LogLevelError) logger := NewLogger(LogLevelError, "")
device := NewDevice(tun, logger)
device.SetPrivateKey(sk) device.SetPrivateKey(sk)
return device return device
} }

View file

@ -2,10 +2,15 @@ package main
import ( import (
"fmt" "fmt"
"log"
"os" "os"
"os/signal" "os/signal"
"runtime" "runtime"
"strconv"
)
const (
ENV_WG_TUN_FD = "WG_TUN_FD"
ENV_WG_UAPI_FD = "WG_UAPI_FD"
) )
func printUsage() { func printUsage() {
@ -43,28 +48,6 @@ func main() {
interfaceName = os.Args[1] interfaceName = os.Args[1]
} }
// daemonize the process
if !foreground {
err := Daemonize()
if err != nil {
log.Println("Failed to daemonize:", err)
}
return
}
// increase number of go workers (for Go <1.5)
runtime.GOMAXPROCS(runtime.NumCPU())
// open TUN device
tun, err := CreateTUN(interfaceName)
if err != nil {
log.Println("Failed to create tun device:", err)
return
}
// get log level (default: info) // get log level (default: info)
logLevel := func() int { logLevel := func() int {
@ -79,25 +62,103 @@ func main() {
return LogLevelInfo return LogLevelInfo
}() }()
logger := NewLogger(
logLevel,
fmt.Sprintf("(%s) ", interfaceName),
)
logger.Debug.Println("Debug log enabled")
// open TUN device (or use supplied fd)
tun, err := func() (TUNDevice, error) {
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
if tunFdStr == "" {
return CreateTUN(interfaceName)
}
// construct tun device from supplied fd
fd, err := strconv.ParseUint(tunFdStr, 10, 32)
if err != nil {
return nil, err
}
file := os.NewFile(uintptr(fd), "")
return CreateTUNFromFile(interfaceName, file)
}()
if err != nil {
logger.Error.Println("Failed to create TUN device:", err)
os.Exit(ExitSetupFailed)
}
// open UAPI file (or use supplied fd)
fileUAPI, err := func() (*os.File, error) {
uapiFdStr := os.Getenv(ENV_WG_UAPI_FD)
if uapiFdStr == "" {
return UAPIOpen(interfaceName)
}
// use supplied fd
fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
if err != nil {
return nil, err
}
return os.NewFile(uintptr(fd), ""), nil
}()
if err != nil {
logger.Error.Println("UAPI listen error:", err)
os.Exit(ExitSetupFailed)
return
}
// daemonize the process
if !foreground {
env := os.Environ()
env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD))
env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD))
attr := &os.ProcAttr{
Files: []*os.File{
nil, // stdin
nil, // stdout
nil, // stderr
tun.File(),
fileUAPI,
},
Dir: ".",
Env: env,
}
err = Daemonize(attr)
if err != nil {
logger.Error.Println("Failed to daemonize:", err)
os.Exit(ExitSetupFailed)
}
return
}
// increase number of go workers (for Go <1.5)
runtime.GOMAXPROCS(runtime.NumCPU())
// create wireguard device // create wireguard device
device := NewDevice(tun, logLevel) device := NewDevice(tun, logger)
logInfo := device.log.Info logger.Info.Println("Device started")
logError := device.log.Error
logInfo.Println("Starting device")
// start configuration lister // start uapi listener
uapi, err := NewUAPIListener(interfaceName)
if err != nil {
logError.Fatal("UAPI listen error:", err)
}
errs := make(chan error) errs := make(chan error)
term := make(chan os.Signal) term := make(chan os.Signal)
wait := device.WaitChannel() wait := device.WaitChannel()
uapi, err := UAPIListen(interfaceName, fileUAPI)
go func() { go func() {
for { for {
conn, err := uapi.Accept() conn, err := uapi.Accept()
@ -109,7 +170,7 @@ func main() {
} }
}() }()
logInfo.Println("UAPI listener started") logger.Info.Println("UAPI listener started")
// wait for program to terminate // wait for program to terminate
@ -122,9 +183,10 @@ func main() {
case <-errs: case <-errs:
} }
// clean up UAPI bind // clean up
uapi.Close() uapi.Close()
device.Close()
logInfo.Println("Closing") logger.Info.Println("Shutting down")
} }

View file

@ -21,6 +21,14 @@ func (a *AtomicBool) Get() bool {
return atomic.LoadInt32(&a.flag) == AtomicTrue return atomic.LoadInt32(&a.flag) == AtomicTrue
} }
func (a *AtomicBool) Swap(val bool) bool {
flag := AtomicFalse
if val {
flag = AtomicTrue
}
return atomic.SwapInt32(&a.flag, flag) == AtomicTrue
}
func (a *AtomicBool) Set(val bool) { func (a *AtomicBool) Set(val bool) {
flag := AtomicFalse flag := AtomicFalse
if val { if val {

View file

@ -117,8 +117,8 @@ func TestNoiseHandshake(t *testing.T) {
var err error var err error
var out []byte var out []byte
var nonce [12]byte var nonce [12]byte
out = key1.send.aead.Seal(out, nonce[:], testMsg, nil) out = key1.send.Seal(out, nonce[:], testMsg, nil)
out, err = key2.receive.aead.Open(out[:0], nonce[:], out, nil) out, err = key2.receive.Open(out[:0], nonce[:], out, nil)
assertNil(t, err) assertNil(t, err)
assertEqual(t, out, testMsg) assertEqual(t, out, testMsg)
}() }()
@ -128,8 +128,8 @@ func TestNoiseHandshake(t *testing.T) {
var err error var err error
var out []byte var out []byte
var nonce [12]byte var nonce [12]byte
out = key2.send.aead.Seal(out, nonce[:], testMsg, nil) out = key2.send.Seal(out, nonce[:], testMsg, nil)
out, err = key1.receive.aead.Open(out[:0], nonce[:], out, nil) out, err = key1.receive.Open(out[:0], nonce[:], out, nil)
assertNil(t, err) assertNil(t, err)
assertEqual(t, out, testMsg) assertEqual(t, out, testMsg)
}() }()

View file

@ -4,7 +4,6 @@ import (
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
"net"
"sync" "sync"
"time" "time"
) )
@ -16,7 +15,7 @@ type Peer struct {
keyPairs KeyPairs keyPairs KeyPairs
handshake Handshake handshake Handshake
device *Device device *Device
endpoint *net.UDPAddr endpoint Endpoint
stats struct { stats struct {
txBytes uint64 // bytes send to peer (endpoint) txBytes uint64 // bytes send to peer (endpoint)
rxBytes uint64 // bytes received from peer rxBytes uint64 // bytes received from peer
@ -106,6 +105,10 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic) handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic)
handshake.mutex.Unlock() handshake.mutex.Unlock()
// reset endpoint
peer.endpoint = nil
// prepare queuing // prepare queuing
peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize) peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
@ -130,11 +133,31 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
return peer, nil return peer, nil
} }
func (peer *Peer) SendBuffer(buffer []byte) error {
peer.device.net.mutex.RLock()
defer peer.device.net.mutex.RUnlock()
peer.mutex.RLock()
defer peer.mutex.RUnlock()
if peer.endpoint == nil {
return errors.New("No known endpoint for peer")
}
return peer.device.net.bind.Send(buffer, peer.endpoint)
}
/* Returns a short string identification for logging
*/
func (peer *Peer) String() string { func (peer *Peer) String() string {
if peer.endpoint == nil {
return fmt.Sprintf(
"peer(%d unknown %s)",
peer.id,
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
)
}
return fmt.Sprintf( return fmt.Sprintf(
"peer(%d %s %s)", "peer(%d %s %s)",
peer.id, peer.id,
peer.endpoint.String(), peer.endpoint.DstToString(),
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
) )
} }

View file

@ -13,19 +13,20 @@ import (
) )
type QueueHandshakeElement struct { type QueueHandshakeElement struct {
msgType uint32 msgType uint32
packet []byte packet []byte
buffer *[MaxMessageSize]byte endpoint Endpoint
source *net.UDPAddr buffer *[MaxMessageSize]byte
} }
type QueueInboundElement struct { type QueueInboundElement struct {
dropped int32 dropped int32
mutex sync.Mutex mutex sync.Mutex
buffer *[MaxMessageSize]byte buffer *[MaxMessageSize]byte
packet []byte packet []byte
counter uint64 counter uint64
keyPair *KeyPair keyPair *KeyPair
endpoint Endpoint
} }
func (elem *QueueInboundElement) Drop() { func (elem *QueueInboundElement) Drop() {
@ -92,130 +93,122 @@ func (device *Device) addToHandshakeQueue(
} }
} }
func (device *Device) RoutineReceiveIncomming() { func (device *Device) RoutineReceiveIncomming(IP int, bind Bind) {
logDebug := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, receive incomming, started") logDebug.Println("Routine, receive incomming, IP version:", IP)
for { for {
// wait for new conn // receive datagrams until conn is closed
logDebug.Println("Waiting for udp socket") buffer := device.GetMessageBuffer()
select { var (
case <-device.signal.stop: err error
return size int
endpoint Endpoint
)
case <-device.signal.newUDPConn: for {
// fetch connection // read next datagram
device.net.mutex.RLock() switch IP {
conn := device.net.conn case ipv4.Version:
device.net.mutex.RUnlock() size, endpoint, err = bind.ReceiveIPv4(buffer[:])
if conn == nil { case ipv6.Version:
size, endpoint, err = bind.ReceiveIPv6(buffer[:])
default:
return
}
if err != nil {
break
}
if size < MinMessageSize {
continue continue
} }
logDebug.Println("Listening for inbound packets") // check size of packet
// receive datagrams until conn is closed packet := buffer[:size]
msgType := binary.LittleEndian.Uint32(packet[:4])
buffer := device.GetMessageBuffer() var okay bool
for { switch msgType {
// read next datagram // check if transport
size, raddr, err := conn.ReadFromUDP(buffer[:]) case MessageTransportType:
if err != nil { // check size
break
}
if size < MinMessageSize { if len(packet) < MessageTransportType {
continue continue
} }
// check size of packet // lookup key pair
packet := buffer[:size] receiver := binary.LittleEndian.Uint32(
msgType := binary.LittleEndian.Uint32(packet[:4]) packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
)
var okay bool value := device.indices.Lookup(receiver)
keyPair := value.keyPair
switch msgType { if keyPair == nil {
// check if transport
case MessageTransportType:
// check size
if len(packet) < MessageTransportType {
continue
}
// lookup key pair
receiver := binary.LittleEndian.Uint32(
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
)
value := device.indices.Lookup(receiver)
keyPair := value.keyPair
if keyPair == nil {
continue
}
// check key-pair expiry
if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
continue
}
// create work element
peer := value.peer
elem := &QueueInboundElement{
packet: packet,
buffer: buffer,
keyPair: keyPair,
dropped: AtomicFalse,
}
elem.mutex.Lock()
// add to decryption queues
device.addToDecryptionQueue(device.queue.decryption, elem)
device.addToInboundQueue(peer.queue.inbound, elem)
buffer = device.GetMessageBuffer()
continue continue
// otherwise it is a handshake related packet
case MessageInitiationType:
okay = len(packet) == MessageInitiationSize
case MessageResponseType:
okay = len(packet) == MessageResponseSize
case MessageCookieReplyType:
okay = len(packet) == MessageCookieReplySize
} }
if okay { // check key-pair expiry
device.addToHandshakeQueue(
device.queue.handshake, if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
QueueHandshakeElement{ continue
msgType: msgType,
buffer: buffer,
packet: packet,
source: raddr,
},
)
buffer = device.GetMessageBuffer()
} }
// create work element
peer := value.peer
elem := &QueueInboundElement{
packet: packet,
buffer: buffer,
keyPair: keyPair,
dropped: AtomicFalse,
endpoint: endpoint,
}
elem.mutex.Lock()
// add to decryption queues
device.addToDecryptionQueue(device.queue.decryption, elem)
device.addToInboundQueue(peer.queue.inbound, elem)
buffer = device.GetMessageBuffer()
continue
// otherwise it is a fixed size & handshake related packet
case MessageInitiationType:
okay = len(packet) == MessageInitiationSize
case MessageResponseType:
okay = len(packet) == MessageResponseSize
case MessageCookieReplyType:
okay = len(packet) == MessageCookieReplySize
}
if okay {
device.addToHandshakeQueue(
device.queue.handshake,
QueueHandshakeElement{
msgType: msgType,
buffer: buffer,
packet: packet,
endpoint: endpoint,
},
)
buffer = device.GetMessageBuffer()
} }
} }
} }
@ -293,8 +286,6 @@ func (device *Device) RoutineHandshake() {
// unmarshal packet // unmarshal packet
logDebug.Println("Process cookie reply from:", elem.source.String())
var reply MessageCookieReply var reply MessageCookieReply
reader := bytes.NewReader(elem.packet) reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &reply) err := binary.Read(reader, binary.LittleEndian, &reply)
@ -321,15 +312,25 @@ func (device *Device) RoutineHandshake() {
return return
} }
// endpoints destination address is the source of the datagram
srcBytes := elem.endpoint.DstToBytes()
if device.IsUnderLoad() { if device.IsUnderLoad() {
if !device.mac.CheckMAC2(elem.packet, elem.source) {
// verify MAC2 field
if !device.mac.CheckMAC2(elem.packet, srcBytes) {
// construct cookie reply // construct cookie reply
logDebug.Println("Sending cookie reply to:", elem.source.String()) logDebug.Println(
"Sending cookie reply to:",
elem.endpoint.DstToString(),
)
sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type" sender := binary.LittleEndian.Uint32(elem.packet[4:8])
reply, err := device.mac.CreateReply(elem.packet, sender, elem.source) reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes)
if err != nil { if err != nil {
logError.Println("Failed to create cookie reply:", err) logError.Println("Failed to create cookie reply:", err)
return return
@ -339,17 +340,16 @@ func (device *Device) RoutineHandshake() {
writer := bytes.NewBuffer(temp[:0]) writer := bytes.NewBuffer(temp[:0])
binary.Write(writer, binary.LittleEndian, reply) binary.Write(writer, binary.LittleEndian, reply)
_, err = device.net.conn.WriteToUDP( device.net.bind.Send(writer.Bytes(), elem.endpoint)
writer.Bytes(),
elem.source,
)
if err != nil { if err != nil {
logDebug.Println("Failed to send cookie reply:", err) logDebug.Println("Failed to send cookie reply:", err)
} }
continue continue
} }
if !device.ratelimiter.Allow(elem.source.IP) { // check ratelimiter
if !device.ratelimiter.Allow(elem.endpoint.DstIP()) {
continue continue
} }
} }
@ -380,8 +380,7 @@ func (device *Device) RoutineHandshake() {
if peer == nil { if peer == nil {
logInfo.Println( logInfo.Println(
"Recieved invalid initiation message from", "Recieved invalid initiation message from",
elem.source.IP.String(), elem.endpoint.DstToString(),
elem.source.Port,
) )
continue continue
} }
@ -392,10 +391,9 @@ func (device *Device) RoutineHandshake() {
peer.TimerAnyAuthenticatedPacketReceived() peer.TimerAnyAuthenticatedPacketReceived()
// update endpoint // update endpoint
// TODO: Discover destination address also, only update on change
peer.mutex.Lock() peer.mutex.Lock()
peer.endpoint = elem.source peer.endpoint = elem.endpoint
peer.mutex.Unlock() peer.mutex.Unlock()
// create response // create response
@ -418,9 +416,11 @@ func (device *Device) RoutineHandshake() {
// send response // send response
_, err = peer.SendBuffer(packet) err = peer.SendBuffer(packet)
if err == nil { if err == nil {
peer.TimerAnyAuthenticatedPacketTraversal() peer.TimerAnyAuthenticatedPacketTraversal()
} else {
logError.Println("Failed to send response to:", peer.String(), err)
} }
case MessageResponseType: case MessageResponseType:
@ -441,12 +441,17 @@ func (device *Device) RoutineHandshake() {
if peer == nil { if peer == nil {
logInfo.Println( logInfo.Println(
"Recieved invalid response message from", "Recieved invalid response message from",
elem.source.IP.String(), elem.endpoint.DstToString(),
elem.source.Port,
) )
continue continue
} }
// update endpoint
peer.mutex.Lock()
peer.endpoint = elem.endpoint
peer.mutex.Unlock()
logDebug.Println("Received handshake initation from", peer) logDebug.Println("Received handshake initation from", peer)
peer.TimerEphemeralKeyCreated() peer.TimerEphemeralKeyCreated()
@ -515,6 +520,12 @@ func (peer *Peer) RoutineSequentialReceiver() {
} }
kp.mutex.Unlock() kp.mutex.Unlock()
// update endpoint
peer.mutex.Lock()
peer.endpoint = elem.endpoint
peer.mutex.Unlock()
// check for keep-alive // check for keep-alive
if len(elem.packet) == 0 { if len(elem.packet) == 0 {
@ -546,7 +557,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.routingTable.LookupIPv4(src) != peer { if device.routingTable.LookupIPv4(src) != peer {
logInfo.Println("Packet with unallowed source IP from", peer.String()) logInfo.Println(
"IPv4 packet with unallowed source address from",
peer.String(),
)
continue continue
} }
@ -571,7 +585,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.routingTable.LookupIPv6(src) != peer { if device.routingTable.LookupIPv6(src) != peer {
logInfo.Println("Packet with unallowed source IP from", peer.String()) logInfo.Println(
"IPv6 packet with unallowed source address from",
peer.String(),
)
continue continue
} }
@ -580,7 +597,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
continue continue
} }
// write to tun // write to tun device
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
_, err := device.tun.device.Write(elem.packet) _, err := device.tun.device.Write(elem.packet)

View file

@ -2,7 +2,6 @@ package main
import ( import (
"encoding/binary" "encoding/binary"
"errors"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
@ -105,26 +104,6 @@ func addToEncryptionQueue(
} }
} }
func (peer *Peer) SendBuffer(buffer []byte) (int, error) {
peer.device.net.mutex.RLock()
defer peer.device.net.mutex.RUnlock()
peer.mutex.RLock()
defer peer.mutex.RUnlock()
endpoint := peer.endpoint
if endpoint == nil {
return 0, errors.New("No known endpoint for peer")
}
conn := peer.device.net.conn
if conn == nil {
return 0, errors.New("No UDP socket for device")
}
return conn.WriteToUDP(buffer, endpoint)
}
/* Reads packets from the TUN and inserts /* Reads packets from the TUN and inserts
* into nonce queue for peer * into nonce queue for peer
* *
@ -343,7 +322,7 @@ func (peer *Peer) RoutineSequentialSender() {
// send message and return buffer to pool // send message and return buffer to pool
length := uint64(len(elem.packet)) length := uint64(len(elem.packet))
_, err := peer.SendBuffer(elem.packet) err := peer.SendBuffer(elem.packet)
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
if err != nil { if err != nil {
logDebug.Println("Failed to send authenticated packet to peer", peer.String()) logDebug.Println("Failed to send authenticated packet to peer", peer.String())

View file

@ -20,6 +20,14 @@
# wireguard peers in $ns1 and $ns2. Note that $ns0 is the endpoint for the wg1 # wireguard peers in $ns1 and $ns2. Note that $ns0 is the endpoint for the wg1
# interfaces in $ns1 and $ns2. See https://www.wireguard.com/netns/ for further # interfaces in $ns1 and $ns2. See https://www.wireguard.com/netns/ for further
# details on how this is accomplished. # details on how this is accomplished.
# This code is ported to the WireGuard-Go directly from the kernel project.
#
# Please ensure that you have installed the newest version of the WireGuard
# tools from the WireGuard project and before running these tests as:
#
# ./netns.sh <path to wireguard-go>
set -e set -e
exec 3>&1 exec 3>&1
@ -27,8 +35,8 @@ export WG_HIDE_KEYS=never
netns0="wg-test-$$-0" netns0="wg-test-$$-0"
netns1="wg-test-$$-1" netns1="wg-test-$$-1"
netns2="wg-test-$$-2" netns2="wg-test-$$-2"
program="../wireguard-go" program=$1
export LOG_LEVEL="error" export LOG_LEVEL="info"
pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; } pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; }
pp() { pretty "" "$*"; "$@"; } pp() { pretty "" "$*"; "$@"; }
@ -72,13 +80,11 @@ pp ip netns add $netns2
ip0 link set up dev lo ip0 link set up dev lo
# ip0 link add dev wg1 type wireguard # ip0 link add dev wg1 type wireguard
n0 $program -f wg1 & n0 $program wg1
sleep 1
ip0 link set wg1 netns $netns1 ip0 link set wg1 netns $netns1
# ip0 link add dev wg1 type wireguard # ip0 link add dev wg1 type wireguard
n0 $program -f wg2 & n0 $program wg2
sleep 1
ip0 link set wg2 netns $netns2 ip0 link set wg2 netns $netns2
key1="$(pp wg genkey)" key1="$(pp wg genkey)"
@ -185,14 +191,14 @@ ip0 -4 addr del 127.0.0.1/8 dev lo
ip0 -4 addr add 127.212.121.99/8 dev lo ip0 -4 addr add 127.212.121.99/8 dev lo
n0 wg set wg1 listen-port 9999 n0 wg set wg1 listen-port 9999
n0 wg set wg1 peer "$pub2" endpoint 127.0.0.1:20000 n0 wg set wg1 peer "$pub2" endpoint 127.0.0.1:20000
n1 ping6 -W 1 -c 1 fd00::20000 n1 ping6 -W 1 -c 1 fd00::2
[[ $(n2 wg show wg2 endpoints) == "$pub1 127.212.121.99:9999" ]] [[ $(n2 wg show wg2 endpoints) == "$pub1 127.212.121.99:9999" ]]
# Test using IPv6 that roaming works # Test using IPv6 that roaming works
n1 wg set wg1 listen-port 9998 n1 wg set wg1 listen-port 9998
n1 wg set wg1 peer "$pub2" endpoint [::1]:20000 n1 wg set wg1 peer "$pub2" endpoint [::1]:20000
n1 ping -W 1 -c 1 192.168.241.2 n1 ping -W 1 -c 1 192.168.241.2
[[ $(n2 wg show wg2 endpoints) == "$pub1 [::1]:9998" ]] [[ $(n2 wg show wg2 endpoints) == "$pub1 [::1]:9998" ]]
# Test that crypto-RP filter works # Test that crypto-RP filter works
n1 wg set wg1 peer "$pub2" allowed-ips 192.168.241.0/24 n1 wg set wg1 peer "$pub2" allowed-ips 192.168.241.0/24
@ -212,7 +218,7 @@ n2 ncat -u 192.168.241.1 1111 <<<"X"
! read -r -N 1 -t 1 out <&4 ! read -r -N 1 -t 1 out <&4
kill $nmap_pid kill $nmap_pid
n0 wg set wg1 peer "$more_specific_key" remove n0 wg set wg1 peer "$more_specific_key" remove
[[ $(n1 wg show wg1 endpoints) == "$pub2 [::1]:9997" ]] [[ $(n1 wg show wg1 endpoints) == "$pub2 [::1]:9997" ]]
ip1 link del wg1 ip1 link del wg1
ip2 link del wg2 ip2 link del wg2
@ -263,7 +269,7 @@ n0 iptables -t nat -A POSTROUTING -s 192.168.1.0/24 -d 10.0.0.0/24 -j SNAT --to
n0 wg set wg1 peer "$pub2" endpoint 10.0.0.100:20000 persistent-keepalive 1 n0 wg set wg1 peer "$pub2" endpoint 10.0.0.100:20000 persistent-keepalive 1
n1 ping -W 1 -c 1 192.168.241.2 n1 ping -W 1 -c 1 192.168.241.2
n2 ping -W 1 -c 1 192.168.241.1 n2 ping -W 1 -c 1 192.168.241.1
[[ $(n2 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]] [[ $(n2 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]]
# Demonstrate n2 can still send packets to n1, since persistent-keepalive will prevent connection tracking entry from expiring (to see entries: `n0 conntrack -L`). # Demonstrate n2 can still send packets to n1, since persistent-keepalive will prevent connection tracking entry from expiring (to see entries: `n0 conntrack -L`).
pp sleep 3 pp sleep 3
n2 ping -W 1 -c 1 192.168.241.1 n2 ping -W 1 -c 1 192.168.241.1
@ -289,7 +295,7 @@ ip2 link del wg2
# ip1 link add dev wg1 type wireguard # ip1 link add dev wg1 type wireguard
# ip2 link add dev wg1 type wireguard # ip2 link add dev wg1 type wireguard
n1 $program wg1 n1 $program wg1
n2 $program wg1 n2 $program wg2
configure_peers configure_peers
@ -336,17 +342,83 @@ waitiface $netns1 veth1
waitiface $netns2 veth2 waitiface $netns2 veth2
n0 wg set wg2 peer "$pub1" endpoint 10.0.0.1:10000 n0 wg set wg2 peer "$pub1" endpoint 10.0.0.1:10000
n2 ping -W 1 -c 1 192.168.241.1 n2 ping -W 1 -c 1 192.168.241.1
[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]] [[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]]
n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::1]:10000 n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::1]:10000
n2 ping -W 1 -c 1 192.168.241.1 n2 ping -W 1 -c 1 192.168.241.1
[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::1]:10000" ]] [[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::1]:10000" ]]
n0 wg set wg2 peer "$pub1" endpoint 10.0.0.2:10000 n0 wg set wg2 peer "$pub1" endpoint 10.0.0.2:10000
n2 ping -W 1 -c 1 192.168.241.1 n2 ping -W 1 -c 1 192.168.241.1
[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.2:10000" ]] [[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.2:10000" ]]
n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::2]:10000 n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::2]:10000
n2 ping -W 1 -c 1 192.168.241.1 n2 ping -W 1 -c 1 192.168.241.1
[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::2]:10000" ]] [[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::2]:10000" ]]
ip1 link del veth1 ip1 link del veth1
ip1 link del wg1 ip1 link del wg1
ip2 link del wg2 ip2 link del wg2
# Test that Netlink/IPC is working properly by doing things that usually cause split responses
n0 $program wg0
sleep 5
config=( "[Interface]" "PrivateKey=$(wg genkey)" "[Peer]" "PublicKey=$(wg genkey)" )
for a in {1..255}; do
for b in {0..255}; do
config+=( "AllowedIPs=$a.$b.0.0/16,$a::$b/128" )
done
done
n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
i=0
for ip in $(n0 wg show wg0 allowed-ips); do
((++i))
done
((i == 255*256*2+1))
ip0 link del wg0
n0 $program wg0
config=( "[Interface]" "PrivateKey=$(wg genkey)" )
for a in {1..40}; do
config+=( "[Peer]" "PublicKey=$(wg genkey)" )
for b in {1..52}; do
config+=( "AllowedIPs=$a.$b.0.0/16" )
done
done
n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
i=0
while read -r line; do
j=0
for ip in $line; do
((++j))
done
((j == 53))
((++i))
done < <(n0 wg show wg0 allowed-ips)
((i == 40))
ip0 link del wg0
n0 $program wg0
config=( )
for i in {1..29}; do
config+=( "[Peer]" "PublicKey=$(wg genkey)" )
done
config+=( "[Peer]" "PublicKey=$(wg genkey)" "AllowedIPs=255.2.3.4/32,abcd::255/128" )
n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
n0 wg showconf wg0 > /dev/null
ip0 link del wg0
! n0 wg show doesnotexist || false
declare -A objects
while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do
[[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ [0-9]+)\ .*(created|destroyed).* ]] || continue
objects["${BASH_REMATCH[1]}"]+="${BASH_REMATCH[2]}"
done < /dev/kmsg
alldeleted=1
for object in "${!objects[@]}"; do
if [[ ${objects["$object"]} != *createddestroyed ]]; then
echo "Error: $object: merely ${objects["$object"]}" >&3
alldeleted=0
fi
done
[[ $alldeleted -eq 1 ]]
pretty "" "Objects that were created were also destroyed."

View file

@ -279,34 +279,31 @@ func (peer *Peer) RoutineHandshakeInitiator() {
break AttemptHandshakes break AttemptHandshakes
} }
jitter := time.Millisecond * time.Duration(rand.Uint32()%334) // marshal handshake message
// marshal and send
writer := bytes.NewBuffer(temp[:0]) writer := bytes.NewBuffer(temp[:0])
binary.Write(writer, binary.LittleEndian, msg) binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes() packet := writer.Bytes()
peer.mac.AddMacs(packet) peer.mac.AddMacs(packet)
_, err = peer.SendBuffer(packet) // send to endpoint
if err != nil {
err = peer.SendBuffer(packet)
jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
timeout := time.NewTimer(RekeyTimeout + jitter)
if err == nil {
peer.TimerAnyAuthenticatedPacketTraversal()
logDebug.Println(
"Handshake initiation attempt",
attempts, "sent to", peer.String(),
)
} else {
logError.Println( logError.Println(
"Failed to send handshake initiation message to", "Failed to send handshake initiation message to",
peer.String(), ":", err, peer.String(), ":", err,
) )
continue
} }
peer.TimerAnyAuthenticatedPacketTraversal()
// set handshake timeout
timeout := time.NewTimer(RekeyTimeout + jitter)
logDebug.Println(
"Handshake initiation attempt",
attempts, "sent to", peer.String(),
)
// wait for handshake or timeout // wait for handshake or timeout
select { select {

View file

@ -1,6 +1,7 @@
package main package main
import ( import (
"os"
"sync/atomic" "sync/atomic"
) )
@ -15,6 +16,7 @@ const (
) )
type TUNDevice interface { type TUNDevice interface {
File() *os.File // returns the file descriptor of the device
Read([]byte) (int, error) // read a packet from the device (without any additional headers) Read([]byte) (int, error) // read a packet from the device (without any additional headers)
Write([]byte) (int, error) // writes a packet to the device (without any additional headers) Write([]byte) (int, error) // writes a packet to the device (without any additional headers)
MTU() (int, error) // returns the MTU of the device MTU() (int, error) // returns the MTU of the device
@ -47,7 +49,7 @@ func (device *Device) RoutineTUNEventReader() {
if !device.tun.isUp.Get() { if !device.tun.isUp.Get() {
logInfo.Println("Interface set up") logInfo.Println("Interface set up")
device.tun.isUp.Set(true) device.tun.isUp.Set(true)
updateUDPConn(device) UpdateUDPListener(device)
} }
} }
@ -55,7 +57,7 @@ func (device *Device) RoutineTUNEventReader() {
if device.tun.isUp.Get() { if device.tun.isUp.Get() {
logInfo.Println("Interface set down") logInfo.Println("Interface set down")
device.tun.isUp.Set(false) device.tun.isUp.Set(false)
closeUDPConn(device) CloseUDPListener(device)
} }
} }
} }

View file

@ -56,6 +56,10 @@ type NativeTun struct {
events chan TUNEvent // device related events events chan TUNEvent // device related events
} }
func (tun *NativeTun) File() *os.File {
return tun.fd
}
func (tun *NativeTun) RoutineNetlinkListener() { func (tun *NativeTun) RoutineNetlinkListener() {
sock := int(C.bind_rtmgrp()) sock := int(C.bind_rtmgrp())
if sock < 0 { if sock < 0 {
@ -222,7 +226,7 @@ func (tun *NativeTun) MTU() (int, error) {
val := binary.LittleEndian.Uint32(ifr[16:20]) val := binary.LittleEndian.Uint32(ifr[16:20])
if val >= (1 << 31) { if val >= (1 << 31) {
return int(val-(1<<31)) - (1 << 31), nil return int(toInt32(val)), nil
} }
return int(val), nil return int(val), nil
} }
@ -248,6 +252,29 @@ func (tun *NativeTun) Close() error {
return nil return nil
} }
func CreateTUNFromFile(name string, fd *os.File) (TUNDevice, error) {
device := &NativeTun{
fd: fd,
name: name,
events: make(chan TUNEvent, 5),
errors: make(chan error, 5),
}
// start event listener
var err error
device.index, err = getIFIndex(device.name)
if err != nil {
return nil, err
}
go device.RoutineNetlinkListener()
// set default MTU
return device, device.setMTU(DefaultMTU)
}
func CreateTUN(name string) (TUNDevice, error) { func CreateTUN(name string) (TUNDevice, error) {
// open clone device // open clone device

View file

@ -39,9 +39,10 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
send("private_key=" + device.privateKey.ToHex()) send("private_key=" + device.privateKey.ToHex())
} }
if device.net.addr != nil { if device.net.port != 0 {
send(fmt.Sprintf("listen_port=%d", device.net.addr.Port)) send(fmt.Sprintf("listen_port=%d", device.net.port))
} }
if device.net.fwmark != 0 { if device.net.fwmark != 0 {
send(fmt.Sprintf("fwmark=%d", device.net.fwmark)) send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
} }
@ -53,7 +54,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
send("public_key=" + peer.handshake.remoteStatic.ToHex()) send("public_key=" + peer.handshake.remoteStatic.ToHex())
send("preshared_key=" + peer.handshake.presharedKey.ToHex()) send("preshared_key=" + peer.handshake.presharedKey.ToHex())
if peer.endpoint != nil { if peer.endpoint != nil {
send("endpoint=" + peer.endpoint.String()) send("endpoint=" + peer.endpoint.DstToString())
} }
nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano) nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
@ -134,56 +135,38 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "listen_port": case "listen_port":
port, err := strconv.ParseUint(value, 10, 16) port, err := strconv.ParseUint(value, 10, 16)
if err != nil { if err != nil {
logError.Println("Failed to set listen_port:", err) logError.Println("Failed to parse listen_port:", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
device.net.port = uint16(port)
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port)) if err := UpdateUDPListener(device); err != nil {
if err != nil {
logError.Println("Failed to set listen_port:", err)
return &IPCError{Code: ipcErrorInvalid}
}
device.net.mutex.Lock()
device.net.addr = addr
device.net.mutex.Unlock()
err = updateUDPConn(device)
if err != nil {
logError.Println("Failed to set listen_port:", err) logError.Println("Failed to set listen_port:", err)
return &IPCError{Code: ipcErrorPortInUse} return &IPCError{Code: ipcErrorPortInUse}
} }
// TODO: Clear source address of all peers
case "fwmark": case "fwmark":
fwmark, err := strconv.ParseUint(value, 10, 32)
// parse fwmark field
fwmark, err := func() (uint32, error) {
if value == "" {
return 0, nil
}
mark, err := strconv.ParseUint(value, 10, 32)
return uint32(mark), err
}()
if err != nil { if err != nil {
logError.Println("Invalid fwmark", err) logError.Println("Invalid fwmark", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
device.net.mutex.Lock() device.net.mutex.Lock()
if fwmark > 0 || device.net.fwmark > 0 { device.net.fwmark = uint32(fwmark)
device.net.fwmark = uint32(fwmark)
err := setMark(
device.net.conn,
device.net.fwmark,
)
if err != nil {
logError.Println("Failed to set fwmark:", err)
device.net.mutex.Unlock()
return &IPCError{Code: ipcErrorIO}
}
// TODO: Clear source address of all peers
}
device.net.mutex.Unlock() device.net.mutex.Unlock()
case "public_key": case "public_key":
// switch to peer configuration // switch to peer configuration
deviceConfig = false deviceConfig = false
case "replace_peers": case "replace_peers":
@ -218,7 +201,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
device.mutex.RLock() device.mutex.RLock()
if device.publicKey.Equals(pubKey) { if device.publicKey.Equals(pubKey) {
// create dummy instance // create dummy instance (not added to device)
peer = &Peer{} peer = &Peer{}
dummy = true dummy = true
@ -244,6 +227,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
} }
case "remove": case "remove":
// remove currently selected peer from device
if value != "true" { if value != "true" {
logError.Println("Failed to set remove, invalid value:", value) logError.Println("Failed to set remove, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
@ -256,6 +242,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
dummy = true dummy = true
case "preshared_key": case "preshared_key":
// update PSK
peer.mutex.Lock() peer.mutex.Lock()
err := peer.handshake.presharedKey.FromHex(value) err := peer.handshake.presharedKey.FromHex(value)
peer.mutex.Unlock() peer.mutex.Unlock()
@ -265,15 +254,25 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
} }
case "endpoint": case "endpoint":
addr, err := parseEndpoint(value)
// set endpoint destination
err := func() error {
peer.mutex.Lock()
defer peer.mutex.Unlock()
endpoint, err := CreateEndpoint(value)
if err != nil {
return err
}
peer.endpoint = endpoint
signalSend(peer.signal.handshakeReset)
return nil
}()
if err != nil { if err != nil {
logError.Println("Failed to set endpoint:", value) logError.Println("Failed to set endpoint:", value)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
peer.mutex.Lock()
peer.endpoint = addr
peer.mutex.Unlock()
signalSend(peer.signal.handshakeReset)
case "persistent_keepalive_interval": case "persistent_keepalive_interval":

View file

@ -10,12 +10,12 @@ import (
) )
const ( const (
ipcErrorIO = -int64(unix.EIO) ipcErrorIO = -int64(unix.EIO)
ipcErrorProtocol = -int64(unix.EPROTO) ipcErrorProtocol = -int64(unix.EPROTO)
ipcErrorInvalid = -int64(unix.EINVAL) ipcErrorInvalid = -int64(unix.EINVAL)
ipcErrorPortInUse = -int64(unix.EADDRINUSE) ipcErrorPortInUse = -int64(unix.EADDRINUSE)
socketDirectory = "/var/run/wireguard" socketDirectory = "/var/run/wireguard"
socketName = "%s.sock" socketName = "%s.sock"
) )
type UAPIListener struct { type UAPIListener struct {
@ -50,49 +50,11 @@ func (l *UAPIListener) Addr() net.Addr {
return nil return nil
} }
func connectUnixSocket(path string) (net.Listener, error) { func UAPIListen(name string, file *os.File) (net.Listener, error) {
// attempt inital connection // wrap file in listener
listener, err := net.Listen("unix", path) listener, err := net.FileListener(file)
if err == nil {
return listener, nil
}
// check if active
_, err = net.Dial("unix", path)
if err == nil {
return nil, errors.New("Unix socket in use")
}
// attempt cleanup
err = os.Remove(path)
if err != nil {
return nil, err
}
return net.Listen("unix", path)
}
func NewUAPIListener(name string) (net.Listener, error) {
// check if path exist
err := os.MkdirAll(socketDirectory, 077)
if err != nil && !os.IsExist(err) {
return nil, err
}
// open UNIX socket
socketPath := path.Join(
socketDirectory,
fmt.Sprintf(socketName, name),
)
listener, err := connectUnixSocket(socketPath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -105,6 +67,11 @@ func NewUAPIListener(name string) (net.Listener, error) {
// watch for deletion of socket // watch for deletion of socket
socketPath := path.Join(
socketDirectory,
fmt.Sprintf(socketName, name),
)
uapi.inotifyFd, err = unix.InotifyInit() uapi.inotifyFd, err = unix.InotifyInit()
if err != nil { if err != nil {
return nil, err return nil, err
@ -125,11 +92,12 @@ func NewUAPIListener(name string) (net.Listener, error) {
go func(l *UAPIListener) { go func(l *UAPIListener) {
var buff [4096]byte var buff [4096]byte
for { for {
unix.Read(uapi.inotifyFd, buff[:]) // start with lstat to avoid race condition
if _, err := os.Lstat(socketPath); os.IsNotExist(err) { if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
l.connErr <- err l.connErr <- err
return return
} }
unix.Read(uapi.inotifyFd, buff[:])
} }
}(uapi) }(uapi)
@ -148,3 +116,56 @@ func NewUAPIListener(name string) (net.Listener, error) {
return uapi, nil return uapi, nil
} }
func UAPIOpen(name string) (*os.File, error) {
// check if path exist
err := os.MkdirAll(socketDirectory, 0600)
if err != nil && !os.IsExist(err) {
return nil, err
}
// open UNIX socket
socketPath := path.Join(
socketDirectory,
fmt.Sprintf(socketName, name),
)
addr, err := net.ResolveUnixAddr("unix", socketPath)
if err != nil {
return nil, err
}
listener, err := func() (*net.UnixListener, error) {
// initial connection attempt
listener, err := net.ListenUnix("unix", addr)
if err == nil {
return listener, nil
}
// check if socket already active
_, err = net.Dial("unix", socketPath)
if err == nil {
return nil, errors.New("unix socket in use")
}
// cleanup & attempt again
err = os.Remove(socketPath)
if err != nil {
return nil, err
}
return net.ListenUnix("unix", addr)
}()
if err != nil {
return nil, err
}
return listener.File()
}