Begin work on source address caching (linux)
This commit is contained in:
		
							parent
							
								
									c545d63bb9
								
							
						
					
					
						commit
						eefa47b0f9
					
				
					 5 changed files with 273 additions and 11 deletions
				
			
		
							
								
								
									
										22
									
								
								src/conn.go
									
									
									
									
									
								
							
							
						
						
									
										22
									
								
								src/conn.go
									
									
									
									
									
								
							|  | @ -1,9 +1,31 @@ | |||
| package main | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"net" | ||||
| ) | ||||
| 
 | ||||
| func parseEndpoint(s string) (*net.UDPAddr, error) { | ||||
| 
 | ||||
| 	// ensure that the host is an IP address
 | ||||
| 
 | ||||
| 	host, _, err := net.SplitHostPort(s) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if ip := net.ParseIP(host); ip == nil { | ||||
| 		return nil, errors.New("Failed to parse IP address: " + host) | ||||
| 	} | ||||
| 
 | ||||
| 	// parse address and port
 | ||||
| 
 | ||||
| 	addr, err := net.ResolveUDPAddr("udp", s) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return addr, err | ||||
| } | ||||
| 
 | ||||
| func updateUDPConn(device *Device) error { | ||||
| 	netc := &device.net | ||||
| 	netc.mutex.Lock() | ||||
|  |  | |||
|  | @ -1,10 +1,253 @@ | |||
| /* Copyright 2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. | ||||
|  * | ||||
|  * This implements userspace semantics of "sticky sockets", modeled after | ||||
|  * WireGuard's kernelspace implementation. | ||||
|  */ | ||||
| 
 | ||||
| package main | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"golang.org/x/sys/unix" | ||||
| 	"net" | ||||
| 	"strconv" | ||||
| 	"unsafe" | ||||
| ) | ||||
| 
 | ||||
| /* Supports source address caching | ||||
|  * | ||||
|  * It is important that the endpoint is only updated after the packet content has been authenticated. | ||||
|  * | ||||
|  * Currently there is no way to achieve this within the net package: | ||||
|  * See e.g. https://github.com/golang/go/issues/17930
 | ||||
|  */ | ||||
| type Endpoint struct { | ||||
| 	// source (selected based on dst type)
 | ||||
| 	// (could use RawSockaddrAny and unsafe)
 | ||||
| 	srcIPv6 unix.RawSockaddrInet6 | ||||
| 	srcIPv4 unix.RawSockaddrInet4 | ||||
| 	srcIf4  int32 | ||||
| 
 | ||||
| 	dst unix.RawSockaddrAny | ||||
| } | ||||
| 
 | ||||
| func zoneToUint32(zone string) (uint32, error) { | ||||
| 	if zone == "" { | ||||
| 		return 0, nil | ||||
| 	} | ||||
| 	if intr, err := net.InterfaceByName(zone); err == nil { | ||||
| 		return uint32(intr.Index), nil | ||||
| 	} | ||||
| 	n, err := strconv.ParseUint(zone, 10, 32) | ||||
| 	return uint32(n), err | ||||
| } | ||||
| 
 | ||||
| func (end *Endpoint) ClearSrc() { | ||||
| 	end.srcIf4 = 0 | ||||
| 	end.srcIPv4 = unix.RawSockaddrInet4{} | ||||
| 	end.srcIPv6 = unix.RawSockaddrInet6{} | ||||
| } | ||||
| 
 | ||||
| func (end *Endpoint) Set(s string) error { | ||||
| 	addr, err := parseEndpoint(s) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	ipv6 := addr.IP.To16() | ||||
| 	if ipv6 != nil { | ||||
| 		zone, err := zoneToUint32(addr.Zone) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		ptr := (*unix.RawSockaddrInet6)(unsafe.Pointer(&end.dst)) | ||||
| 		ptr.Family = unix.AF_INET6 | ||||
| 		ptr.Port = uint16(addr.Port) | ||||
| 		ptr.Flowinfo = 0 | ||||
| 		ptr.Scope_id = zone | ||||
| 		copy(ptr.Addr[:], ipv6[:]) | ||||
| 		end.ClearSrc() | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	ipv4 := addr.IP.To4() | ||||
| 	if ipv4 != nil { | ||||
| 		ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) | ||||
| 		ptr.Family = unix.AF_INET | ||||
| 		ptr.Port = uint16(addr.Port) | ||||
| 		ptr.Zero = [8]byte{} | ||||
| 		copy(ptr.Addr[:], ipv4) | ||||
| 		end.ClearSrc() | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	return errors.New("Failed to recognize IP address format") | ||||
| } | ||||
| 
 | ||||
| func send6(sock uintptr, end *Endpoint, buff []byte) error { | ||||
| 	var iovec unix.Iovec | ||||
| 
 | ||||
| 	iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) | ||||
| 	iovec.SetLen(len(buff)) | ||||
| 
 | ||||
| 	cmsg := struct { | ||||
| 		cmsghdr unix.Cmsghdr | ||||
| 		pktinfo unix.Inet6Pktinfo | ||||
| 	}{ | ||||
| 		unix.Cmsghdr{ | ||||
| 			Level: unix.IPPROTO_IPV6, | ||||
| 			Type:  unix.IPV6_PKTINFO, | ||||
| 			Len:   unix.SizeofInet6Pktinfo, | ||||
| 		}, | ||||
| 		unix.Inet6Pktinfo{ | ||||
| 			Addr:    end.srcIPv6.Addr, | ||||
| 			Ifindex: end.srcIPv6.Scope_id, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	msghdr := unix.Msghdr{ | ||||
| 		Iov:     &iovec, | ||||
| 		Iovlen:  1, | ||||
| 		Name:    (*byte)(unsafe.Pointer(&end.dst)), | ||||
| 		Namelen: unix.SizeofSockaddrInet6, | ||||
| 		Control: (*byte)(unsafe.Pointer(&cmsg)), | ||||
| 	} | ||||
| 
 | ||||
| 	msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) | ||||
| 
 | ||||
| 	// sendmsg(sock, &msghdr, 0)
 | ||||
| 
 | ||||
| 	_, _, errno := unix.Syscall( | ||||
| 		unix.SYS_SENDMSG, | ||||
| 		sock, | ||||
| 		uintptr(unsafe.Pointer(&msghdr)), | ||||
| 		0, | ||||
| 	) | ||||
| 	if errno == unix.EINVAL { | ||||
| 		end.ClearSrc() | ||||
| 	} | ||||
| 	return errno | ||||
| } | ||||
| 
 | ||||
| func send4(sock uintptr, end *Endpoint, buff []byte) error { | ||||
| 	var iovec unix.Iovec | ||||
| 
 | ||||
| 	iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) | ||||
| 	iovec.SetLen(len(buff)) | ||||
| 
 | ||||
| 	cmsg := struct { | ||||
| 		cmsghdr unix.Cmsghdr | ||||
| 		pktinfo unix.Inet4Pktinfo | ||||
| 	}{ | ||||
| 		unix.Cmsghdr{ | ||||
| 			Level: unix.IPPROTO_IP, | ||||
| 			Type:  unix.IP_PKTINFO, | ||||
| 			Len:   unix.SizeofInet6Pktinfo, | ||||
| 		}, | ||||
| 		unix.Inet4Pktinfo{ | ||||
| 			Spec_dst: end.srcIPv4.Addr, | ||||
| 			Ifindex:  end.srcIf4, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	msghdr := unix.Msghdr{ | ||||
| 		Iov:     &iovec, | ||||
| 		Iovlen:  1, | ||||
| 		Name:    (*byte)(unsafe.Pointer(&end.dst)), | ||||
| 		Namelen: unix.SizeofSockaddrInet4, | ||||
| 		Control: (*byte)(unsafe.Pointer(&cmsg)), | ||||
| 	} | ||||
| 
 | ||||
| 	msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) | ||||
| 
 | ||||
| 	// sendmsg(sock, &msghdr, 0)
 | ||||
| 
 | ||||
| 	_, _, errno := unix.Syscall( | ||||
| 		unix.SYS_SENDMSG, | ||||
| 		sock, | ||||
| 		uintptr(unsafe.Pointer(&msghdr)), | ||||
| 		0, | ||||
| 	) | ||||
| 	if errno == unix.EINVAL { | ||||
| 		end.ClearSrc() | ||||
| 	} | ||||
| 	return errno | ||||
| } | ||||
| 
 | ||||
| func send(c *net.UDPConn, end *Endpoint, buff []byte) error { | ||||
| 
 | ||||
| 	// extract underlying file descriptor
 | ||||
| 
 | ||||
| 	file, err := c.File() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	sock := file.Fd() | ||||
| 
 | ||||
| 	// send depending on address family of dst
 | ||||
| 
 | ||||
| 	family := *((*uint16)(unsafe.Pointer(&end.dst))) | ||||
| 	if family == unix.AF_INET { | ||||
| 		return send4(sock, end, buff) | ||||
| 	} else if family == unix.AF_INET6 { | ||||
| 		return send6(sock, end, buff) | ||||
| 	} | ||||
| 	return errors.New("Unknown address family of source") | ||||
| } | ||||
| 
 | ||||
| func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAddr, *net.UDPAddr) { | ||||
| 
 | ||||
| 	file, err := c.File() | ||||
| 	if err != nil { | ||||
| 		return err, nil, nil | ||||
| 	} | ||||
| 
 | ||||
| 	var iovec unix.Iovec | ||||
| 	iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) | ||||
| 	iovec.SetLen(len(buff)) | ||||
| 
 | ||||
| 	var cmsg struct { | ||||
| 		cmsghdr unix.Cmsghdr | ||||
| 		pktinfo unix.Inet6Pktinfo // big enough
 | ||||
| 	} | ||||
| 
 | ||||
| 	var msg unix.Msghdr | ||||
| 	msg.Iov = &iovec | ||||
| 	msg.Iovlen = 1 | ||||
| 	msg.Name = (*byte)(unsafe.Pointer(&end.dst)) | ||||
| 	msg.Namelen = uint32(unix.SizeofSockaddrAny) | ||||
| 	msg.Control = (*byte)(unsafe.Pointer(&cmsg)) | ||||
| 	msg.SetControllen(int(unsafe.Sizeof(cmsg))) | ||||
| 
 | ||||
| 	_, _, errno := unix.Syscall( | ||||
| 		unix.SYS_RECVMSG, | ||||
| 		file.Fd(), | ||||
| 		uintptr(unsafe.Pointer(&msg)), | ||||
| 		0, | ||||
| 	) | ||||
| 
 | ||||
| 	if errno != 0 { | ||||
| 		return errno, nil, nil | ||||
| 	} | ||||
| 
 | ||||
| 	if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && | ||||
| 		cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && | ||||
| 		cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { | ||||
| 
 | ||||
| 	} | ||||
| 
 | ||||
| 	if cmsg.cmsghdr.Level == unix.IPPROTO_IP && | ||||
| 		cmsg.cmsghdr.Type == unix.IP_PKTINFO && | ||||
| 		cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { | ||||
| 
 | ||||
| 		info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&cmsg.pktinfo)) | ||||
| 		println(info) | ||||
| 
 | ||||
| 	} | ||||
| 
 | ||||
| 	return nil, nil, nil | ||||
| } | ||||
| 
 | ||||
| func setMark(conn *net.UDPConn, value uint32) error { | ||||
| 	if conn == nil { | ||||
| 		return nil | ||||
|  |  | |||
|  | @ -29,6 +29,11 @@ func (a *AtomicBool) Set(val bool) { | |||
| 	atomic.StoreInt32(&a.flag, flag) | ||||
| } | ||||
| 
 | ||||
| func toInt32(n uint32) int32 { | ||||
| 	mask := uint32(1 << 31) | ||||
| 	return int32(-(n & mask) + (n & ^mask)) | ||||
| } | ||||
| 
 | ||||
| func min(a uint, b uint) uint { | ||||
| 	if a > b { | ||||
| 		return b | ||||
|  |  | |||
|  | @ -120,14 +120,6 @@ func (tun *NativeTun) Name() string { | |||
| 	return tun.name | ||||
| } | ||||
| 
 | ||||
| func toInt32(val []byte) int32 { | ||||
| 	n := binary.LittleEndian.Uint32(val[:4]) | ||||
| 	if n >= (1 << 31) { | ||||
| 		return -int32(^n) - 1 | ||||
| 	} | ||||
| 	return int32(n) | ||||
| } | ||||
| 
 | ||||
| func getDummySock() (int, error) { | ||||
| 	return unix.Socket( | ||||
| 		unix.AF_INET, | ||||
|  | @ -157,7 +149,8 @@ func getIFIndex(name string) (int32, error) { | |||
| 		return 0, errno | ||||
| 	} | ||||
| 
 | ||||
| 	return toInt32(ifr[unix.IFNAMSIZ:]), nil | ||||
| 	index := binary.LittleEndian.Uint32(ifr[unix.IFNAMSIZ:]) | ||||
| 	return toInt32(index), nil | ||||
| } | ||||
| 
 | ||||
| func (tun *NativeTun) setMTU(n int) error { | ||||
|  |  | |||
|  | @ -273,8 +273,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { | |||
| 				} | ||||
| 
 | ||||
| 			case "endpoint": | ||||
| 				// TODO: Only IP and port
 | ||||
| 				addr, err := net.ResolveUDPAddr("udp", value) | ||||
| 				addr, err := parseEndpoint(value) | ||||
| 				if err != nil { | ||||
| 					logError.Println("Failed to set endpoint:", value) | ||||
| 					return &IPCError{Code: ipcErrorInvalid} | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue