Clear src cache if route changes to new ifindex

This commit is contained in:
Jason A. Donenfeld 2018-04-27 05:21:45 +02:00
parent 92261b770f
commit b34604245e
2 changed files with 151 additions and 14 deletions

View file

@ -55,10 +55,13 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
type NativeBind struct { type NativeBind struct {
sock4 int sock4 int
sock6 int sock6 int
netlinkSock int
lastEndpoint *NativeEndpoint
lastMark uint32
} }
var _ Endpoint = (*NativeEndpoint)(nil) var _ Endpoint = (*NativeEndpoint)(nil)
var _ Bind = NativeBind{} var _ Bind = (*NativeBind)(nil)
func CreateEndpoint(s string) (Endpoint, error) { func CreateEndpoint(s string) (Endpoint, error) {
var end NativeEndpoint var end NativeEndpoint
@ -95,23 +98,50 @@ func CreateEndpoint(s string) (Endpoint, error) {
return nil, errors.New("Invalid IP address") return nil, errors.New("Invalid IP address")
} }
func CreateBind(port uint16) (Bind, uint16, error) { func createNetlinkRouteSocket() (int, error) {
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
if err != nil {
return -1, err
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
}
err = unix.Bind(sock, saddr)
if err != nil {
unix.Close(sock)
return -1, err
}
return sock, nil
}
func CreateBind(port uint16) (*NativeBind, uint16, error) {
var err error var err error
var bind NativeBind var bind NativeBind
bind.netlinkSock, err = createNetlinkRouteSocket()
if err != nil {
return nil, 0, err
}
go bind.routineRouteListener()
bind.sock6, port, err = create6(port) bind.sock6, port, err = create6(port)
if err != nil { if err != nil {
unix.Close(bind.netlinkSock)
return nil, port, err return nil, port, err
} }
bind.sock4, port, err = create4(port) bind.sock4, port, err = create4(port)
if err != nil { if err != nil {
unix.Close(bind.netlinkSock)
unix.Close(bind.sock6) unix.Close(bind.sock6)
} }
return bind, port, err return &bind, port, err
} }
func (bind NativeBind) SetMark(value uint32) error { func (bind *NativeBind) SetMark(value uint32) error {
err := unix.SetsockoptInt( err := unix.SetsockoptInt(
bind.sock6, bind.sock6,
unix.SOL_SOCKET, unix.SOL_SOCKET,
@ -123,12 +153,19 @@ func (bind NativeBind) SetMark(value uint32) error {
return err return err
} }
return unix.SetsockoptInt( err = unix.SetsockoptInt(
bind.sock4, bind.sock4,
unix.SOL_SOCKET, unix.SOL_SOCKET,
unix.SO_MARK, unix.SO_MARK,
int(value), int(value),
) )
if err != nil {
return err
}
bind.lastMark = value
return nil
} }
func closeUnblock(fd int) error { func closeUnblock(fd int) error {
@ -137,16 +174,20 @@ func closeUnblock(fd int) error {
return unix.Close(fd) return unix.Close(fd)
} }
func (bind NativeBind) Close() error { func (bind *NativeBind) Close() error {
err1 := closeUnblock(bind.sock6) err1 := closeUnblock(bind.sock6)
err2 := closeUnblock(bind.sock4) err2 := closeUnblock(bind.sock4)
err3 := closeUnblock(bind.netlinkSock)
if err1 != nil { if err1 != nil {
return err1 return err1
} }
if err2 != nil {
return err2 return err2
}
return err3
} }
func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
var end NativeEndpoint var end NativeEndpoint
n, err := receive6( n, err := receive6(
bind.sock6, bind.sock6,
@ -156,17 +197,18 @@ func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
return n, &end, err return n, &end, err
} }
func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
var end NativeEndpoint var end NativeEndpoint
n, err := receive4( n, err := receive4(
bind.sock4, bind.sock4,
buff, buff,
&end, &end,
) )
bind.lastEndpoint = &end
return n, &end, err return n, &end, err
} }
func (bind NativeBind) Send(buff []byte, end Endpoint) error { func (bind *NativeBind) Send(buff []byte, end Endpoint) error {
nend := end.(*NativeEndpoint) nend := end.(*NativeEndpoint)
if !nend.isV6 { if !nend.isV6 {
return send4(bind.sock4, nend, buff) return send4(bind.sock4, nend, buff)
@ -506,3 +548,97 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
return size, nil return size, nil
} }
func (bind *NativeBind) routineRouteListener() {
// TODO: this function doesn't lock the endpoint it modifies
for msg := make([]byte, 1<<16); ; {
msgn, _, _, _, err := unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
if err != nil {
return
}
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
if uint(hdr.Len) > uint(len(remain)) {
break
}
switch hdr.Type {
case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
if bind.lastEndpoint == nil || bind.lastEndpoint.isV6 || bind.lastEndpoint.src4().ifindex == 0 {
break
}
if hdr.Seq == 0xff {
if uint(len(remain)) < uint(hdr.Len) {
break
}
if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
for {
if uint(len(attr)) < uint(unix.SizeofRtAttr) {
break
}
attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
break
}
if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
if uint32(bind.lastEndpoint.src4().ifindex) != ifidx {
bind.lastEndpoint.ClearSrc()
}
}
attr = attr[attrhdr.Len:]
}
}
break
}
nlmsg := struct {
hdr unix.NlMsghdr
msg unix.RtMsg
dsthdr unix.RtAttr
dst [4]byte
srchdr unix.RtAttr
src [4]byte
markhdr unix.RtAttr
mark uint32
}{
unix.NlMsghdr{
Type: uint16(unix.RTM_GETROUTE),
Flags: unix.NLM_F_REQUEST,
Seq: 0xff,
},
unix.RtMsg{
Family: unix.AF_INET,
Dst_len: 32,
Src_len: 32,
},
unix.RtAttr{
Len: 8,
Type: unix.RTA_DST,
},
bind.lastEndpoint.dst4().Addr,
unix.RtAttr{
Len: 8,
Type: unix.RTA_SRC,
},
bind.lastEndpoint.src4().src,
unix.RtAttr{
Len: 8,
Type: 0x10, //unix.RTA_MARK TODO: add this to x/sys/unix
},
uint32(bind.lastMark),
}
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
unix.Write(bind.netlinkSock, (*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
}
remain = remain[hdr.Len:]
}
}
}

View file

@ -79,7 +79,6 @@ func (tun *NativeTun) RoutineNetlinkListener() {
defer unix.Close(sock) defer unix.Close(sock)
saddr := &unix.SockaddrNetlink{ saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK, Family: unix.AF_NETLINK,
Pid: uint32(os.Getpid()),
Groups: uint32(groups), Groups: uint32(groups),
} }
err = unix.Bind(sock, saddr) err = unix.Bind(sock, saddr)
@ -90,7 +89,9 @@ func (tun *NativeTun) RoutineNetlinkListener() {
// TODO: This function never actually exits in response to anything, // TODO: This function never actually exits in response to anything,
// a go routine that goes forever. We'll want to fix that if this is // a go routine that goes forever. We'll want to fix that if this is
// to ever be used as any sort of library. // to ever be used as any sort of library. See what we've done with
// calling shutdown() on the netlink socket in conn_linux.go, and
// change this to be more like that.
for msg := make([]byte, 1<<16); ; { for msg := make([]byte, 1<<16); ; {