tun: linux: work out netpoll trick

This commit is contained in:
Jason A. Donenfeld 2019-03-07 01:51:41 +01:00
parent 1fdf7b19a3
commit 92f72f5aa6

View file

@ -17,8 +17,8 @@ import (
"golang.zx2c4.com/wireguard/rwcancel" "golang.zx2c4.com/wireguard/rwcancel"
"net" "net"
"os" "os"
"strconv"
"sync" "sync"
"syscall"
"time" "time"
"unsafe" "unsafe"
) )
@ -30,8 +30,6 @@ const (
type NativeTun struct { type NativeTun struct {
tunFile *os.File tunFile *os.File
fd uintptr
fdCancel *rwcancel.RWCancel
index int32 // if index index int32 // if index
name string // name of interface name string // name of interface
errors chan error // async error handling errors chan error // async error handling
@ -52,9 +50,17 @@ func (tun *NativeTun) routineHackListener() {
/* This is needed for the detection to work across network namespaces /* This is needed for the detection to work across network namespaces
* If you are reading this and know a better method, please get in touch. * If you are reading this and know a better method, please get in touch.
*/ */
fd := int(tun.fd)
for { for {
_, err := unix.Write(fd, nil) sysconn, err := tun.tunFile.SyscallConn()
if err != nil {
return
}
err2 := sysconn.Control(func(fd uintptr) {
_, err = unix.Write(int(fd), nil)
})
if err2 != nil {
return
}
switch err { switch err {
case unix.EINVAL: case unix.EINVAL:
tun.events <- TUNEventUp tun.events <- TUNEventUp
@ -248,22 +254,32 @@ func (tun *NativeTun) MTU() (int, error) {
uintptr(unsafe.Pointer(&ifr[0])), uintptr(unsafe.Pointer(&ifr[0])),
) )
if errno != 0 { if errno != 0 {
return 0, errors.New("failed to get MTU of TUN device: " + strconv.FormatInt(int64(errno), 10)) return 0, errors.New("failed to get MTU of TUN device: " + errno.Error())
} }
return int(*(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil return int(*(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil
} }
func (tun *NativeTun) Name() (string, error) { func (tun *NativeTun) Name() (string, error) {
sysconn, err := tun.tunFile.SyscallConn()
if err != nil {
return "", err
}
var ifr [ifReqSize]byte var ifr [ifReqSize]byte
_, _, errno := unix.Syscall( var errno syscall.Errno
err = sysconn.Control(func(fd uintptr) {
_, _, errno = unix.Syscall(
unix.SYS_IOCTL, unix.SYS_IOCTL,
tun.fd, fd,
uintptr(unix.TUNGETIFF), uintptr(unix.TUNGETIFF),
uintptr(unsafe.Pointer(&ifr[0])), uintptr(unsafe.Pointer(&ifr[0])),
) )
})
if err != nil {
return "", errors.New("failed to get name of TUN device: " + err.Error())
}
if errno != 0 { if errno != 0 {
return "", errors.New("failed to get name of TUN device: " + strconv.FormatInt(int64(errno), 10)) return "", errors.New("failed to get name of TUN device: " + errno.Error())
} }
nullStr := ifr[:] nullStr := ifr[:]
i := bytes.IndexByte(nullStr, 0) i := bytes.IndexByte(nullStr, 0)
@ -302,7 +318,7 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
return tun.tunFile.Write(buff) return tun.tunFile.Write(buff)
} }
func (tun *NativeTun) doRead(buff []byte, offset int) (int, error) { func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
select { select {
case err := <-tun.errors: case err := <-tun.errors:
return 0, err return 0, err
@ -320,18 +336,6 @@ func (tun *NativeTun) doRead(buff []byte, offset int) (int, error) {
} }
} }
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
for {
n, err := tun.doRead(buff, offset)
if err == nil || !rwcancel.RetryAfterError(err) {
return n, err
}
if !tun.fdCancel.ReadyRead() {
return 0, errors.New("tun device closed")
}
}
}
func (tun *NativeTun) Events() chan TUNEvent { func (tun *NativeTun) Events() chan TUNEvent {
return tun.events return tun.events
} }
@ -347,15 +351,11 @@ func (tun *NativeTun) Close() error {
close(tun.events) close(tun.events)
} }
err2 := tun.tunFile.Close() err2 := tun.tunFile.Close()
err3 := tun.fdCancel.Cancel()
if err1 != nil { if err1 != nil {
return err1 return err1
} }
if err2 != nil {
return err2 return err2
}
return err3
} }
func CreateTUN(name string, mtu int) (TUNDevice, error) { func CreateTUN(name string, mtu int) (TUNDevice, error) {
@ -364,13 +364,6 @@ func CreateTUN(name string, mtu int) (TUNDevice, error) {
return nil, err return nil, err
} }
fd := os.NewFile(uintptr(nfd), cloneDevicePath)
if err != nil {
return nil, err
}
// create new device
var ifr [ifReqSize]byte var ifr [ifReqSize]byte
var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack) var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack)
nameBytes := []byte(name) nameBytes := []byte(name)
@ -382,13 +375,21 @@ func CreateTUN(name string, mtu int) (TUNDevice, error) {
_, _, errno := unix.Syscall( _, _, errno := unix.Syscall(
unix.SYS_IOCTL, unix.SYS_IOCTL,
fd.Fd(), uintptr(nfd),
uintptr(unix.TUNSETIFF), uintptr(unix.TUNSETIFF),
uintptr(unsafe.Pointer(&ifr[0])), uintptr(unsafe.Pointer(&ifr[0])),
) )
if errno != 0 { if errno != 0 {
return nil, errno return nil, errno
} }
err = unix.SetNonblock(nfd, true)
// Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line.
fd := os.NewFile(uintptr(nfd), cloneDevicePath)
if err != nil {
return nil, err
}
return CreateTUNFromFile(fd, mtu) return CreateTUNFromFile(fd, mtu)
} }
@ -396,7 +397,6 @@ func CreateTUN(name string, mtu int) (TUNDevice, error) {
func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) { func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) {
tun := &NativeTun{ tun := &NativeTun{
tunFile: file, tunFile: file,
fd: file.Fd(),
events: make(chan TUNEvent, 5), events: make(chan TUNEvent, 5),
errors: make(chan error, 5), errors: make(chan error, 5),
statusListenersShutdown: make(chan struct{}), statusListenersShutdown: make(chan struct{}),
@ -404,11 +404,6 @@ func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) {
} }
var err error var err error
tun.fdCancel, err = rwcancel.NewRWCancel(int(tun.fd))
if err != nil {
return nil, err
}
_, err = tun.Name() _, err = tun.Name()
if err != nil { if err != nil {
return nil, err return nil, err
@ -444,23 +439,20 @@ func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) {
return tun, nil return tun, nil
} }
func CreateUnmonitoredTUNFromFD(tunFd int) (TUNDevice, string, error) { func CreateUnmonitoredTUNFromFD(fd int) (TUNDevice, string, error) {
file := os.NewFile(uintptr(tunFd), "/dev/tun") err := unix.SetNonblock(fd, true)
if err != nil {
return nil, "", err
}
file := os.NewFile(uintptr(fd), "/dev/tun")
tun := &NativeTun{ tun := &NativeTun{
tunFile: file, tunFile: file,
fd: file.Fd(),
events: make(chan TUNEvent, 5), events: make(chan TUNEvent, 5),
errors: make(chan error, 5), errors: make(chan error, 5),
nopi: true, nopi: true,
} }
var err error
tun.fdCancel, err = rwcancel.NewRWCancel(int(tun.fd))
if err != nil {
return nil, "", err
}
name, err := tun.Name() name, err := tun.Name()
if err != nil { if err != nil {
tun.fdCancel.Cancel()
return nil, "", err return nil, "", err
} }
return tun, name, nil return tun, name, nil