Begin work on source address caching (linux)
This commit is contained in:
		
							parent
							
								
									c545d63bb9
								
							
						
					
					
						commit
						eefa47b0f9
					
				
					 5 changed files with 273 additions and 11 deletions
				
			
		
							
								
								
									
										22
									
								
								src/conn.go
									
									
									
									
									
								
							
							
						
						
									
										22
									
								
								src/conn.go
									
									
									
									
									
								
							|  | @ -1,9 +1,31 @@ | ||||||
| package main | package main | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"errors" | ||||||
| 	"net" | 	"net" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | func parseEndpoint(s string) (*net.UDPAddr, error) { | ||||||
|  | 
 | ||||||
|  | 	// ensure that the host is an IP address
 | ||||||
|  | 
 | ||||||
|  | 	host, _, err := net.SplitHostPort(s) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if ip := net.ParseIP(host); ip == nil { | ||||||
|  | 		return nil, errors.New("Failed to parse IP address: " + host) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// parse address and port
 | ||||||
|  | 
 | ||||||
|  | 	addr, err := net.ResolveUDPAddr("udp", s) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return addr, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func updateUDPConn(device *Device) error { | func updateUDPConn(device *Device) error { | ||||||
| 	netc := &device.net | 	netc := &device.net | ||||||
| 	netc.mutex.Lock() | 	netc.mutex.Lock() | ||||||
|  |  | ||||||
|  | @ -1,10 +1,253 @@ | ||||||
|  | /* Copyright 2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. | ||||||
|  |  * | ||||||
|  |  * This implements userspace semantics of "sticky sockets", modeled after | ||||||
|  |  * WireGuard's kernelspace implementation. | ||||||
|  |  */ | ||||||
|  | 
 | ||||||
| package main | package main | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"errors" | ||||||
| 	"golang.org/x/sys/unix" | 	"golang.org/x/sys/unix" | ||||||
| 	"net" | 	"net" | ||||||
|  | 	"strconv" | ||||||
|  | 	"unsafe" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | /* Supports source address caching | ||||||
|  |  * | ||||||
|  |  * It is important that the endpoint is only updated after the packet content has been authenticated. | ||||||
|  |  * | ||||||
|  |  * Currently there is no way to achieve this within the net package: | ||||||
|  |  * See e.g. https://github.com/golang/go/issues/17930
 | ||||||
|  |  */ | ||||||
|  | type Endpoint struct { | ||||||
|  | 	// source (selected based on dst type)
 | ||||||
|  | 	// (could use RawSockaddrAny and unsafe)
 | ||||||
|  | 	srcIPv6 unix.RawSockaddrInet6 | ||||||
|  | 	srcIPv4 unix.RawSockaddrInet4 | ||||||
|  | 	srcIf4  int32 | ||||||
|  | 
 | ||||||
|  | 	dst unix.RawSockaddrAny | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func zoneToUint32(zone string) (uint32, error) { | ||||||
|  | 	if zone == "" { | ||||||
|  | 		return 0, nil | ||||||
|  | 	} | ||||||
|  | 	if intr, err := net.InterfaceByName(zone); err == nil { | ||||||
|  | 		return uint32(intr.Index), nil | ||||||
|  | 	} | ||||||
|  | 	n, err := strconv.ParseUint(zone, 10, 32) | ||||||
|  | 	return uint32(n), err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (end *Endpoint) ClearSrc() { | ||||||
|  | 	end.srcIf4 = 0 | ||||||
|  | 	end.srcIPv4 = unix.RawSockaddrInet4{} | ||||||
|  | 	end.srcIPv6 = unix.RawSockaddrInet6{} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (end *Endpoint) Set(s string) error { | ||||||
|  | 	addr, err := parseEndpoint(s) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	ipv6 := addr.IP.To16() | ||||||
|  | 	if ipv6 != nil { | ||||||
|  | 		zone, err := zoneToUint32(addr.Zone) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		ptr := (*unix.RawSockaddrInet6)(unsafe.Pointer(&end.dst)) | ||||||
|  | 		ptr.Family = unix.AF_INET6 | ||||||
|  | 		ptr.Port = uint16(addr.Port) | ||||||
|  | 		ptr.Flowinfo = 0 | ||||||
|  | 		ptr.Scope_id = zone | ||||||
|  | 		copy(ptr.Addr[:], ipv6[:]) | ||||||
|  | 		end.ClearSrc() | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	ipv4 := addr.IP.To4() | ||||||
|  | 	if ipv4 != nil { | ||||||
|  | 		ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) | ||||||
|  | 		ptr.Family = unix.AF_INET | ||||||
|  | 		ptr.Port = uint16(addr.Port) | ||||||
|  | 		ptr.Zero = [8]byte{} | ||||||
|  | 		copy(ptr.Addr[:], ipv4) | ||||||
|  | 		end.ClearSrc() | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return errors.New("Failed to recognize IP address format") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func send6(sock uintptr, end *Endpoint, buff []byte) error { | ||||||
|  | 	var iovec unix.Iovec | ||||||
|  | 
 | ||||||
|  | 	iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) | ||||||
|  | 	iovec.SetLen(len(buff)) | ||||||
|  | 
 | ||||||
|  | 	cmsg := struct { | ||||||
|  | 		cmsghdr unix.Cmsghdr | ||||||
|  | 		pktinfo unix.Inet6Pktinfo | ||||||
|  | 	}{ | ||||||
|  | 		unix.Cmsghdr{ | ||||||
|  | 			Level: unix.IPPROTO_IPV6, | ||||||
|  | 			Type:  unix.IPV6_PKTINFO, | ||||||
|  | 			Len:   unix.SizeofInet6Pktinfo, | ||||||
|  | 		}, | ||||||
|  | 		unix.Inet6Pktinfo{ | ||||||
|  | 			Addr:    end.srcIPv6.Addr, | ||||||
|  | 			Ifindex: end.srcIPv6.Scope_id, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	msghdr := unix.Msghdr{ | ||||||
|  | 		Iov:     &iovec, | ||||||
|  | 		Iovlen:  1, | ||||||
|  | 		Name:    (*byte)(unsafe.Pointer(&end.dst)), | ||||||
|  | 		Namelen: unix.SizeofSockaddrInet6, | ||||||
|  | 		Control: (*byte)(unsafe.Pointer(&cmsg)), | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) | ||||||
|  | 
 | ||||||
|  | 	// sendmsg(sock, &msghdr, 0)
 | ||||||
|  | 
 | ||||||
|  | 	_, _, errno := unix.Syscall( | ||||||
|  | 		unix.SYS_SENDMSG, | ||||||
|  | 		sock, | ||||||
|  | 		uintptr(unsafe.Pointer(&msghdr)), | ||||||
|  | 		0, | ||||||
|  | 	) | ||||||
|  | 	if errno == unix.EINVAL { | ||||||
|  | 		end.ClearSrc() | ||||||
|  | 	} | ||||||
|  | 	return errno | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func send4(sock uintptr, end *Endpoint, buff []byte) error { | ||||||
|  | 	var iovec unix.Iovec | ||||||
|  | 
 | ||||||
|  | 	iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) | ||||||
|  | 	iovec.SetLen(len(buff)) | ||||||
|  | 
 | ||||||
|  | 	cmsg := struct { | ||||||
|  | 		cmsghdr unix.Cmsghdr | ||||||
|  | 		pktinfo unix.Inet4Pktinfo | ||||||
|  | 	}{ | ||||||
|  | 		unix.Cmsghdr{ | ||||||
|  | 			Level: unix.IPPROTO_IP, | ||||||
|  | 			Type:  unix.IP_PKTINFO, | ||||||
|  | 			Len:   unix.SizeofInet6Pktinfo, | ||||||
|  | 		}, | ||||||
|  | 		unix.Inet4Pktinfo{ | ||||||
|  | 			Spec_dst: end.srcIPv4.Addr, | ||||||
|  | 			Ifindex:  end.srcIf4, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	msghdr := unix.Msghdr{ | ||||||
|  | 		Iov:     &iovec, | ||||||
|  | 		Iovlen:  1, | ||||||
|  | 		Name:    (*byte)(unsafe.Pointer(&end.dst)), | ||||||
|  | 		Namelen: unix.SizeofSockaddrInet4, | ||||||
|  | 		Control: (*byte)(unsafe.Pointer(&cmsg)), | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) | ||||||
|  | 
 | ||||||
|  | 	// sendmsg(sock, &msghdr, 0)
 | ||||||
|  | 
 | ||||||
|  | 	_, _, errno := unix.Syscall( | ||||||
|  | 		unix.SYS_SENDMSG, | ||||||
|  | 		sock, | ||||||
|  | 		uintptr(unsafe.Pointer(&msghdr)), | ||||||
|  | 		0, | ||||||
|  | 	) | ||||||
|  | 	if errno == unix.EINVAL { | ||||||
|  | 		end.ClearSrc() | ||||||
|  | 	} | ||||||
|  | 	return errno | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func send(c *net.UDPConn, end *Endpoint, buff []byte) error { | ||||||
|  | 
 | ||||||
|  | 	// extract underlying file descriptor
 | ||||||
|  | 
 | ||||||
|  | 	file, err := c.File() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	sock := file.Fd() | ||||||
|  | 
 | ||||||
|  | 	// send depending on address family of dst
 | ||||||
|  | 
 | ||||||
|  | 	family := *((*uint16)(unsafe.Pointer(&end.dst))) | ||||||
|  | 	if family == unix.AF_INET { | ||||||
|  | 		return send4(sock, end, buff) | ||||||
|  | 	} else if family == unix.AF_INET6 { | ||||||
|  | 		return send6(sock, end, buff) | ||||||
|  | 	} | ||||||
|  | 	return errors.New("Unknown address family of source") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAddr, *net.UDPAddr) { | ||||||
|  | 
 | ||||||
|  | 	file, err := c.File() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err, nil, nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var iovec unix.Iovec | ||||||
|  | 	iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) | ||||||
|  | 	iovec.SetLen(len(buff)) | ||||||
|  | 
 | ||||||
|  | 	var cmsg struct { | ||||||
|  | 		cmsghdr unix.Cmsghdr | ||||||
|  | 		pktinfo unix.Inet6Pktinfo // big enough
 | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var msg unix.Msghdr | ||||||
|  | 	msg.Iov = &iovec | ||||||
|  | 	msg.Iovlen = 1 | ||||||
|  | 	msg.Name = (*byte)(unsafe.Pointer(&end.dst)) | ||||||
|  | 	msg.Namelen = uint32(unix.SizeofSockaddrAny) | ||||||
|  | 	msg.Control = (*byte)(unsafe.Pointer(&cmsg)) | ||||||
|  | 	msg.SetControllen(int(unsafe.Sizeof(cmsg))) | ||||||
|  | 
 | ||||||
|  | 	_, _, errno := unix.Syscall( | ||||||
|  | 		unix.SYS_RECVMSG, | ||||||
|  | 		file.Fd(), | ||||||
|  | 		uintptr(unsafe.Pointer(&msg)), | ||||||
|  | 		0, | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	if errno != 0 { | ||||||
|  | 		return errno, nil, nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && | ||||||
|  | 		cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && | ||||||
|  | 		cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { | ||||||
|  | 
 | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if cmsg.cmsghdr.Level == unix.IPPROTO_IP && | ||||||
|  | 		cmsg.cmsghdr.Type == unix.IP_PKTINFO && | ||||||
|  | 		cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { | ||||||
|  | 
 | ||||||
|  | 		info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&cmsg.pktinfo)) | ||||||
|  | 		println(info) | ||||||
|  | 
 | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil, nil, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func setMark(conn *net.UDPConn, value uint32) error { | func setMark(conn *net.UDPConn, value uint32) error { | ||||||
| 	if conn == nil { | 	if conn == nil { | ||||||
| 		return nil | 		return nil | ||||||
|  |  | ||||||
|  | @ -29,6 +29,11 @@ func (a *AtomicBool) Set(val bool) { | ||||||
| 	atomic.StoreInt32(&a.flag, flag) | 	atomic.StoreInt32(&a.flag, flag) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func toInt32(n uint32) int32 { | ||||||
|  | 	mask := uint32(1 << 31) | ||||||
|  | 	return int32(-(n & mask) + (n & ^mask)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func min(a uint, b uint) uint { | func min(a uint, b uint) uint { | ||||||
| 	if a > b { | 	if a > b { | ||||||
| 		return b | 		return b | ||||||
|  |  | ||||||
|  | @ -120,14 +120,6 @@ func (tun *NativeTun) Name() string { | ||||||
| 	return tun.name | 	return tun.name | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func toInt32(val []byte) int32 { |  | ||||||
| 	n := binary.LittleEndian.Uint32(val[:4]) |  | ||||||
| 	if n >= (1 << 31) { |  | ||||||
| 		return -int32(^n) - 1 |  | ||||||
| 	} |  | ||||||
| 	return int32(n) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func getDummySock() (int, error) { | func getDummySock() (int, error) { | ||||||
| 	return unix.Socket( | 	return unix.Socket( | ||||||
| 		unix.AF_INET, | 		unix.AF_INET, | ||||||
|  | @ -157,7 +149,8 @@ func getIFIndex(name string) (int32, error) { | ||||||
| 		return 0, errno | 		return 0, errno | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return toInt32(ifr[unix.IFNAMSIZ:]), nil | 	index := binary.LittleEndian.Uint32(ifr[unix.IFNAMSIZ:]) | ||||||
|  | 	return toInt32(index), nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (tun *NativeTun) setMTU(n int) error { | func (tun *NativeTun) setMTU(n int) error { | ||||||
|  |  | ||||||
|  | @ -273,8 +273,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 			case "endpoint": | 			case "endpoint": | ||||||
| 				// TODO: Only IP and port
 | 				addr, err := parseEndpoint(value) | ||||||
| 				addr, err := net.ResolveUDPAddr("udp", value) |  | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					logError.Println("Failed to set endpoint:", value) | 					logError.Println("Failed to set endpoint:", value) | ||||||
| 					return &IPCError{Code: ipcErrorInvalid} | 					return &IPCError{Code: ipcErrorInvalid} | ||||||
|  |  | ||||||
		Loading…
	
		Reference in a new issue