Clear src cache if route changes to new ifindex
This commit is contained in:
parent
92261b770f
commit
b34604245e
160
conn_linux.go
160
conn_linux.go
|
@ -53,12 +53,15 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
|
||||||
}
|
}
|
||||||
|
|
||||||
type NativeBind struct {
|
type NativeBind struct {
|
||||||
sock4 int
|
sock4 int
|
||||||
sock6 int
|
sock6 int
|
||||||
|
netlinkSock int
|
||||||
|
lastEndpoint *NativeEndpoint
|
||||||
|
lastMark uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Endpoint = (*NativeEndpoint)(nil)
|
var _ Endpoint = (*NativeEndpoint)(nil)
|
||||||
var _ Bind = NativeBind{}
|
var _ Bind = (*NativeBind)(nil)
|
||||||
|
|
||||||
func CreateEndpoint(s string) (Endpoint, error) {
|
func CreateEndpoint(s string) (Endpoint, error) {
|
||||||
var end NativeEndpoint
|
var end NativeEndpoint
|
||||||
|
@ -95,23 +98,50 @@ func CreateEndpoint(s string) (Endpoint, error) {
|
||||||
return nil, errors.New("Invalid IP address")
|
return nil, errors.New("Invalid IP address")
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateBind(port uint16) (Bind, uint16, error) {
|
func createNetlinkRouteSocket() (int, error) {
|
||||||
|
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
saddr := &unix.SockaddrNetlink{
|
||||||
|
Family: unix.AF_NETLINK,
|
||||||
|
Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
|
||||||
|
}
|
||||||
|
err = unix.Bind(sock, saddr)
|
||||||
|
if err != nil {
|
||||||
|
unix.Close(sock)
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
return sock, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateBind(port uint16) (*NativeBind, uint16, error) {
|
||||||
var err error
|
var err error
|
||||||
var bind NativeBind
|
var bind NativeBind
|
||||||
|
|
||||||
|
bind.netlinkSock, err = createNetlinkRouteSocket()
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
go bind.routineRouteListener()
|
||||||
|
|
||||||
bind.sock6, port, err = create6(port)
|
bind.sock6, port, err = create6(port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
unix.Close(bind.netlinkSock)
|
||||||
return nil, port, err
|
return nil, port, err
|
||||||
}
|
}
|
||||||
|
|
||||||
bind.sock4, port, err = create4(port)
|
bind.sock4, port, err = create4(port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
unix.Close(bind.netlinkSock)
|
||||||
unix.Close(bind.sock6)
|
unix.Close(bind.sock6)
|
||||||
}
|
}
|
||||||
return bind, port, err
|
return &bind, port, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind NativeBind) SetMark(value uint32) error {
|
func (bind *NativeBind) SetMark(value uint32) error {
|
||||||
err := unix.SetsockoptInt(
|
err := unix.SetsockoptInt(
|
||||||
bind.sock6,
|
bind.sock6,
|
||||||
unix.SOL_SOCKET,
|
unix.SOL_SOCKET,
|
||||||
|
@ -123,12 +153,19 @@ func (bind NativeBind) SetMark(value uint32) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return unix.SetsockoptInt(
|
err = unix.SetsockoptInt(
|
||||||
bind.sock4,
|
bind.sock4,
|
||||||
unix.SOL_SOCKET,
|
unix.SOL_SOCKET,
|
||||||
unix.SO_MARK,
|
unix.SO_MARK,
|
||||||
int(value),
|
int(value),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
bind.lastMark = value
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func closeUnblock(fd int) error {
|
func closeUnblock(fd int) error {
|
||||||
|
@ -137,16 +174,20 @@ func closeUnblock(fd int) error {
|
||||||
return unix.Close(fd)
|
return unix.Close(fd)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind NativeBind) Close() error {
|
func (bind *NativeBind) Close() error {
|
||||||
err1 := closeUnblock(bind.sock6)
|
err1 := closeUnblock(bind.sock6)
|
||||||
err2 := closeUnblock(bind.sock4)
|
err2 := closeUnblock(bind.sock4)
|
||||||
|
err3 := closeUnblock(bind.netlinkSock)
|
||||||
if err1 != nil {
|
if err1 != nil {
|
||||||
return err1
|
return err1
|
||||||
}
|
}
|
||||||
return err2
|
if err2 != nil {
|
||||||
|
return err2
|
||||||
|
}
|
||||||
|
return err3
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
||||||
var end NativeEndpoint
|
var end NativeEndpoint
|
||||||
n, err := receive6(
|
n, err := receive6(
|
||||||
bind.sock6,
|
bind.sock6,
|
||||||
|
@ -156,17 +197,18 @@ func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
||||||
return n, &end, err
|
return n, &end, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
||||||
var end NativeEndpoint
|
var end NativeEndpoint
|
||||||
n, err := receive4(
|
n, err := receive4(
|
||||||
bind.sock4,
|
bind.sock4,
|
||||||
buff,
|
buff,
|
||||||
&end,
|
&end,
|
||||||
)
|
)
|
||||||
|
bind.lastEndpoint = &end
|
||||||
return n, &end, err
|
return n, &end, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind NativeBind) Send(buff []byte, end Endpoint) error {
|
func (bind *NativeBind) Send(buff []byte, end Endpoint) error {
|
||||||
nend := end.(*NativeEndpoint)
|
nend := end.(*NativeEndpoint)
|
||||||
if !nend.isV6 {
|
if !nend.isV6 {
|
||||||
return send4(bind.sock4, nend, buff)
|
return send4(bind.sock4, nend, buff)
|
||||||
|
@ -506,3 +548,97 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
||||||
|
|
||||||
return size, nil
|
return size, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (bind *NativeBind) routineRouteListener() {
|
||||||
|
// TODO: this function doesn't lock the endpoint it modifies
|
||||||
|
|
||||||
|
for msg := make([]byte, 1<<16); ; {
|
||||||
|
msgn, _, _, _, err := unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
|
||||||
|
|
||||||
|
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
|
||||||
|
|
||||||
|
if uint(hdr.Len) > uint(len(remain)) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
switch hdr.Type {
|
||||||
|
case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
|
||||||
|
|
||||||
|
if bind.lastEndpoint == nil || bind.lastEndpoint.isV6 || bind.lastEndpoint.src4().ifindex == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if hdr.Seq == 0xff {
|
||||||
|
if uint(len(remain)) < uint(hdr.Len) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
|
||||||
|
attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
|
||||||
|
for {
|
||||||
|
if uint(len(attr)) < uint(unix.SizeofRtAttr) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
|
||||||
|
if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
|
||||||
|
ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
|
||||||
|
if uint32(bind.lastEndpoint.src4().ifindex) != ifidx {
|
||||||
|
bind.lastEndpoint.ClearSrc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr = attr[attrhdr.Len:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
nlmsg := struct {
|
||||||
|
hdr unix.NlMsghdr
|
||||||
|
msg unix.RtMsg
|
||||||
|
dsthdr unix.RtAttr
|
||||||
|
dst [4]byte
|
||||||
|
srchdr unix.RtAttr
|
||||||
|
src [4]byte
|
||||||
|
markhdr unix.RtAttr
|
||||||
|
mark uint32
|
||||||
|
}{
|
||||||
|
unix.NlMsghdr{
|
||||||
|
Type: uint16(unix.RTM_GETROUTE),
|
||||||
|
Flags: unix.NLM_F_REQUEST,
|
||||||
|
Seq: 0xff,
|
||||||
|
},
|
||||||
|
unix.RtMsg{
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Dst_len: 32,
|
||||||
|
Src_len: 32,
|
||||||
|
},
|
||||||
|
unix.RtAttr{
|
||||||
|
Len: 8,
|
||||||
|
Type: unix.RTA_DST,
|
||||||
|
},
|
||||||
|
bind.lastEndpoint.dst4().Addr,
|
||||||
|
unix.RtAttr{
|
||||||
|
Len: 8,
|
||||||
|
Type: unix.RTA_SRC,
|
||||||
|
},
|
||||||
|
bind.lastEndpoint.src4().src,
|
||||||
|
unix.RtAttr{
|
||||||
|
Len: 8,
|
||||||
|
Type: 0x10, //unix.RTA_MARK TODO: add this to x/sys/unix
|
||||||
|
},
|
||||||
|
uint32(bind.lastMark),
|
||||||
|
}
|
||||||
|
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
|
||||||
|
unix.Write(bind.netlinkSock, (*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
|
||||||
|
}
|
||||||
|
remain = remain[hdr.Len:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -79,7 +79,6 @@ func (tun *NativeTun) RoutineNetlinkListener() {
|
||||||
defer unix.Close(sock)
|
defer unix.Close(sock)
|
||||||
saddr := &unix.SockaddrNetlink{
|
saddr := &unix.SockaddrNetlink{
|
||||||
Family: unix.AF_NETLINK,
|
Family: unix.AF_NETLINK,
|
||||||
Pid: uint32(os.Getpid()),
|
|
||||||
Groups: uint32(groups),
|
Groups: uint32(groups),
|
||||||
}
|
}
|
||||||
err = unix.Bind(sock, saddr)
|
err = unix.Bind(sock, saddr)
|
||||||
|
@ -90,7 +89,9 @@ func (tun *NativeTun) RoutineNetlinkListener() {
|
||||||
|
|
||||||
// TODO: This function never actually exits in response to anything,
|
// TODO: This function never actually exits in response to anything,
|
||||||
// a go routine that goes forever. We'll want to fix that if this is
|
// a go routine that goes forever. We'll want to fix that if this is
|
||||||
// to ever be used as any sort of library.
|
// to ever be used as any sort of library. See what we've done with
|
||||||
|
// calling shutdown() on the netlink socket in conn_linux.go, and
|
||||||
|
// change this to be more like that.
|
||||||
|
|
||||||
for msg := make([]byte, 1<<16); ; {
|
for msg := make([]byte, 1<<16); ; {
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue