diff --git a/conn_linux.go b/conn_linux.go index 88b9ef4..ff3c483 100644 --- a/conn_linux.go +++ b/conn_linux.go @@ -53,12 +53,15 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 { } type NativeBind struct { - sock4 int - sock6 int + sock4 int + sock6 int + netlinkSock int + lastEndpoint *NativeEndpoint + lastMark uint32 } var _ Endpoint = (*NativeEndpoint)(nil) -var _ Bind = NativeBind{} +var _ Bind = (*NativeBind)(nil) func CreateEndpoint(s string) (Endpoint, error) { var end NativeEndpoint @@ -95,23 +98,50 @@ func CreateEndpoint(s string) (Endpoint, error) { return nil, errors.New("Invalid IP address") } -func CreateBind(port uint16) (Bind, uint16, error) { +func createNetlinkRouteSocket() (int, error) { + sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) + if err != nil { + return -1, err + } + saddr := &unix.SockaddrNetlink{ + Family: unix.AF_NETLINK, + Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)), + } + err = unix.Bind(sock, saddr) + if err != nil { + unix.Close(sock) + return -1, err + } + return sock, nil + +} + +func CreateBind(port uint16) (*NativeBind, uint16, error) { var err error var bind NativeBind + bind.netlinkSock, err = createNetlinkRouteSocket() + if err != nil { + return nil, 0, err + } + + go bind.routineRouteListener() + bind.sock6, port, err = create6(port) if err != nil { + unix.Close(bind.netlinkSock) return nil, port, err } bind.sock4, port, err = create4(port) if err != nil { + unix.Close(bind.netlinkSock) unix.Close(bind.sock6) } - return bind, port, err + return &bind, port, err } -func (bind NativeBind) SetMark(value uint32) error { +func (bind *NativeBind) SetMark(value uint32) error { err := unix.SetsockoptInt( bind.sock6, unix.SOL_SOCKET, @@ -123,12 +153,19 @@ func (bind NativeBind) SetMark(value uint32) error { return err } - return unix.SetsockoptInt( + err = unix.SetsockoptInt( bind.sock4, unix.SOL_SOCKET, unix.SO_MARK, int(value), ) + + if err != nil { + return err + } + + bind.lastMark = value + return nil } func closeUnblock(fd int) error { @@ -137,16 +174,20 @@ func closeUnblock(fd int) error { return unix.Close(fd) } -func (bind NativeBind) Close() error { +func (bind *NativeBind) Close() error { err1 := closeUnblock(bind.sock6) err2 := closeUnblock(bind.sock4) + err3 := closeUnblock(bind.netlinkSock) if err1 != nil { return err1 } - return err2 + if err2 != nil { + return err2 + } + return err3 } -func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { +func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { var end NativeEndpoint n, err := receive6( bind.sock6, @@ -156,17 +197,18 @@ func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { return n, &end, err } -func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { +func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { var end NativeEndpoint n, err := receive4( bind.sock4, buff, &end, ) + bind.lastEndpoint = &end return n, &end, err } -func (bind NativeBind) Send(buff []byte, end Endpoint) error { +func (bind *NativeBind) Send(buff []byte, end Endpoint) error { nend := end.(*NativeEndpoint) if !nend.isV6 { return send4(bind.sock4, nend, buff) @@ -506,3 +548,97 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { return size, nil } + +func (bind *NativeBind) routineRouteListener() { + // TODO: this function doesn't lock the endpoint it modifies + + for msg := make([]byte, 1<<16); ; { + msgn, _, _, _, err := unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0) + if err != nil { + return + } + + for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { + + hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) + + if uint(hdr.Len) > uint(len(remain)) { + break + } + + switch hdr.Type { + case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: + + if bind.lastEndpoint == nil || bind.lastEndpoint.isV6 || bind.lastEndpoint.src4().ifindex == 0 { + break + } + + if hdr.Seq == 0xff { + if uint(len(remain)) < uint(hdr.Len) { + break + } + if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg { + attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:] + for { + if uint(len(attr)) < uint(unix.SizeofRtAttr) { + break + } + attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0])) + if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) { + break + } + if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { + ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) + if uint32(bind.lastEndpoint.src4().ifindex) != ifidx { + bind.lastEndpoint.ClearSrc() + } + } + attr = attr[attrhdr.Len:] + } + } + break + } + + nlmsg := struct { + hdr unix.NlMsghdr + msg unix.RtMsg + dsthdr unix.RtAttr + dst [4]byte + srchdr unix.RtAttr + src [4]byte + markhdr unix.RtAttr + mark uint32 + }{ + unix.NlMsghdr{ + Type: uint16(unix.RTM_GETROUTE), + Flags: unix.NLM_F_REQUEST, + Seq: 0xff, + }, + unix.RtMsg{ + Family: unix.AF_INET, + Dst_len: 32, + Src_len: 32, + }, + unix.RtAttr{ + Len: 8, + Type: unix.RTA_DST, + }, + bind.lastEndpoint.dst4().Addr, + unix.RtAttr{ + Len: 8, + Type: unix.RTA_SRC, + }, + bind.lastEndpoint.src4().src, + unix.RtAttr{ + Len: 8, + Type: 0x10, //unix.RTA_MARK TODO: add this to x/sys/unix + }, + uint32(bind.lastMark), + } + nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) + unix.Write(bind.netlinkSock, (*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) + } + remain = remain[hdr.Len:] + } + } +} diff --git a/tun_linux.go b/tun_linux.go index 0672b5e..b0ffa00 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -79,7 +79,6 @@ func (tun *NativeTun) RoutineNetlinkListener() { defer unix.Close(sock) saddr := &unix.SockaddrNetlink{ Family: unix.AF_NETLINK, - Pid: uint32(os.Getpid()), Groups: uint32(groups), } err = unix.Bind(sock, saddr) @@ -90,7 +89,9 @@ func (tun *NativeTun) RoutineNetlinkListener() { // TODO: This function never actually exits in response to anything, // a go routine that goes forever. We'll want to fix that if this is - // to ever be used as any sort of library. + // to ever be used as any sort of library. See what we've done with + // calling shutdown() on the netlink socket in conn_linux.go, and + // change this to be more like that. for msg := make([]byte, 1<<16); ; {