diff --git a/device.go b/device.go index cc12ac9..08a6471 100644 --- a/device.go +++ b/device.go @@ -293,7 +293,7 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device { // prepare signals - device.signals.stop = make(chan struct{}, 1) + device.signals.stop = make(chan struct{}, 0) // prepare net diff --git a/tun_linux.go b/tun_linux.go index 3510f94..c642fe7 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -31,13 +31,15 @@ const ( ) type NativeTun struct { - fd *os.File - index int32 // if index - name string // name of interface - errors chan error // async error handling - events chan TUNEvent // device related events - nopi bool // the device was pased IFF_NO_PI - rwcancel *rwcancel.RWCancel + fd *os.File + index int32 // if index + name string // name of interface + errors chan error // async error handling + events chan TUNEvent // device related events + nopi bool // the device was pased IFF_NO_PI + rwcancel *rwcancel.RWCancel + netlinkSock int + shutdownHackListener chan struct{} } func (tun *NativeTun) File() *os.File { @@ -45,10 +47,6 @@ func (tun *NativeTun) File() *os.File { } func (tun *NativeTun) RoutineHackListener() { - // TODO: This function never actually exits in response to anything, - // a go routine that goes forever. We'll want to fix that if this is - // to ever be used as any sort of library. - /* This is needed for the detection to work across network namespaces * If you are reading this and know a better method, please get in touch. */ @@ -61,47 +59,38 @@ func (tun *NativeTun) RoutineHackListener() { case unix.EIO: tun.events <- TUNEventDown default: + return + } + select { + case <-time.After(time.Second / 10): + case <-tun.shutdownHackListener: + return } - time.Sleep(time.Second / 10) } } -func toRTMGRP(sc uint) uint { - return 1 << (sc - 1) -} - -func (tun *NativeTun) RoutineNetlinkListener() { - - groups := toRTMGRP(unix.RTNLGRP_LINK) - groups |= toRTMGRP(unix.RTNLGRP_IPV4_IFADDR) - groups |= toRTMGRP(unix.RTNLGRP_IPV6_IFADDR) +func createNetlinkSocket() (int, error) { sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) if err != nil { - tun.errors <- errors.New("Failed to create netlink event listener socket") - return + return -1, err } - defer unix.Close(sock) saddr := &unix.SockaddrNetlink{ Family: unix.AF_NETLINK, - Groups: uint32(groups), + Groups: uint32((1 << (unix.RTNLGRP_LINK - 1)) | (1 << (unix.RTNLGRP_IPV4_IFADDR - 1)) | (1 << (unix.RTNLGRP_IPV6_IFADDR - 1))), } err = unix.Bind(sock, saddr) if err != nil { - tun.errors <- errors.New("Failed to bind netlink event listener socket") - return + return -1, err } + return sock, nil +} - // TODO: This function never actually exits in response to anything, - // a go routine that goes forever. We'll want to fix that if this is - // to ever be used as any sort of library. See what we've done with - // calling shutdown() on the netlink socket in conn_linux.go, and - // change this to be more like that. - +func (tun *NativeTun) RoutineNetlinkListener() { for msg := make([]byte, 1<<16); ; { - msgn, _, _, _, err := unix.Recvmsg(sock, msg[:], nil, 0) + msgn, _, _, _, err := unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0) if err != nil { - tun.errors <- fmt.Errorf("Failed to receive netlink message: %s", err.Error()) + tun.errors <- fmt.Errorf("failed to receive netlink message: %s", err.Error()) return } @@ -339,13 +328,16 @@ func (tun *NativeTun) Events() chan TUNEvent { } func (tun *NativeTun) Close() error { - err := tun.fd.Close() - if err != nil { - return err - } + err1 := tun.fd.Close() + err2 := closeUnblock(tun.netlinkSock) tun.rwcancel.Cancel() close(tun.events) - return nil + close(tun.shutdownHackListener) + + if err1 != nil { + return err1 + } + return err2 } func CreateTUN(name string) (TUNDevice, error) { @@ -375,7 +367,7 @@ func CreateTUN(name string) (TUNDevice, error) { var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack) nameBytes := []byte(name) if len(nameBytes) >= unix.IFNAMSIZ { - return nil, errors.New("Interface name too long") + return nil, errors.New("interface name too long") } copy(ifr[:], nameBytes) binary.LittleEndian.PutUint16(ifr[16:], flags) @@ -395,10 +387,11 @@ func CreateTUN(name string) (TUNDevice, error) { func CreateTUNFromFile(fd *os.File) (TUNDevice, error) { device := &NativeTun{ - fd: fd, - events: make(chan TUNEvent, 5), - errors: make(chan error, 5), - nopi: false, + fd: fd, + events: make(chan TUNEvent, 5), + errors: make(chan error, 5), + shutdownHackListener: make(chan struct{}, 0), + nopi: false, } var err error @@ -419,10 +412,20 @@ func CreateTUNFromFile(fd *os.File) (TUNDevice, error) { return nil, err } + // set default MTU + + err = device.setMTU(DefaultMTU) + if err != nil { + return nil, err + } + + device.netlinkSock, err = createNetlinkSocket() + if err != nil { + return nil, err + } + go device.RoutineNetlinkListener() go device.RoutineHackListener() // cross namespace - // set default MTU - - return device, device.setMTU(DefaultMTU) + return device, nil }