tun: make NativeTun.Close well behaved, not crash on double close

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2021-02-18 14:53:22 -08:00 committed by Jason A. Donenfeld
parent fecb8f482a
commit 0f4809f366
5 changed files with 62 additions and 43 deletions

View file

@ -10,6 +10,7 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"sync"
"syscall" "syscall"
"time" "time"
"unsafe" "unsafe"
@ -26,6 +27,7 @@ type NativeTun struct {
events chan Event events chan Event
errors chan error errors chan error
routeSocket int routeSocket int
closeOnce sync.Once
} }
func retryInterfaceByIndex(index int) (iface *net.Interface, err error) { func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
@ -256,14 +258,16 @@ func (tun *NativeTun) Flush() error {
} }
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
var err2 error var err1, err2 error
err1 := tun.tunFile.Close() tun.closeOnce.Do(func() {
if tun.routeSocket != -1 { err1 = tun.tunFile.Close()
unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) if tun.routeSocket != -1 {
err2 = unix.Close(tun.routeSocket) unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
} else if tun.events != nil { err2 = unix.Close(tun.routeSocket)
close(tun.events) } else if tun.events != nil {
} close(tun.events)
}
})
if err1 != nil { if err1 != nil {
return err1 return err1
} }

View file

@ -11,6 +11,7 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"sync"
"syscall" "syscall"
"unsafe" "unsafe"
@ -82,6 +83,7 @@ type NativeTun struct {
events chan Event events chan Event
errors chan error errors chan error
routeSocket int routeSocket int
closeOnce sync.Once
} }
func (tun *NativeTun) routineRouteListener(tunIfindex int) { func (tun *NativeTun) routineRouteListener(tunIfindex int) {
@ -472,16 +474,18 @@ func (tun *NativeTun) Flush() error {
} }
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
var err3 error var err1, err2, err3 error
err1 := tun.tunFile.Close() tun.closeOnce.Do(func() {
err2 := tunDestroy(tun.name) err1 = tun.tunFile.Close()
if tun.routeSocket != -1 { err2 = tunDestroy(tun.name)
unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) if tun.routeSocket != -1 {
err3 = unix.Close(tun.routeSocket) unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
tun.routeSocket = -1 err3 = unix.Close(tun.routeSocket)
} else if tun.events != nil { tun.routeSocket = -1
close(tun.events) } else if tun.events != nil {
} close(tun.events)
}
})
if err1 != nil { if err1 != nil {
return err1 return err1
} }

View file

@ -39,6 +39,8 @@ type NativeTun struct {
hackListenerClosed sync.Mutex hackListenerClosed sync.Mutex
statusListenersShutdown chan struct{} statusListenersShutdown chan struct{}
closeOnce sync.Once
nameOnce sync.Once // guards calling initNameCache, which sets following fields nameOnce sync.Once // guards calling initNameCache, which sets following fields
nameCache string // name of interface nameCache string // name of interface
nameErr error nameErr error
@ -372,17 +374,18 @@ func (tun *NativeTun) Events() chan Event {
} }
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
var err1 error var err1, err2 error
if tun.statusListenersShutdown != nil { tun.closeOnce.Do(func() {
close(tun.statusListenersShutdown) if tun.statusListenersShutdown != nil {
if tun.netlinkCancel != nil { close(tun.statusListenersShutdown)
err1 = tun.netlinkCancel.Cancel() if tun.netlinkCancel != nil {
err1 = tun.netlinkCancel.Cancel()
}
} else if tun.events != nil {
close(tun.events)
} }
} else if tun.events != nil { err2 = tun.tunFile.Close()
close(tun.events) })
}
err2 := tun.tunFile.Close()
if err1 != nil { if err1 != nil {
return err1 return err1
} }

View file

@ -10,6 +10,7 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"sync"
"syscall" "syscall"
"unsafe" "unsafe"
@ -32,6 +33,7 @@ type NativeTun struct {
events chan Event events chan Event
errors chan error errors chan error
routeSocket int routeSocket int
closeOnce sync.Once
} }
func (tun *NativeTun) routineRouteListener(tunIfindex int) { func (tun *NativeTun) routineRouteListener(tunIfindex int) {
@ -245,15 +247,17 @@ func (tun *NativeTun) Flush() error {
} }
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
var err2 error var err1, err2 error
err1 := tun.tunFile.Close() tun.closeOnce.Do(func() {
if tun.routeSocket != -1 { err1 = tun.tunFile.Close()
unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) if tun.routeSocket != -1 {
err2 = unix.Close(tun.routeSocket) unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
tun.routeSocket = -1 err2 = unix.Close(tun.routeSocket)
} else if tun.events != nil { tun.routeSocket = -1
close(tun.events) } else if tun.events != nil {
} close(tun.events)
}
})
if err1 != nil { if err1 != nil {
return err1 return err1
} }

View file

@ -10,6 +10,7 @@ import (
"fmt" "fmt"
"log" "log"
"os" "os"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
_ "unsafe" _ "unsafe"
@ -42,6 +43,7 @@ type NativeTun struct {
rate rateJuggler rate rateJuggler
session wintun.Session session wintun.Session
readWait windows.Handle readWait windows.Handle
closeOnce sync.Once
} }
var WintunPool, _ = wintun.MakePool("WireGuard") var WintunPool, _ = wintun.MakePool("WireGuard")
@ -122,13 +124,15 @@ func (tun *NativeTun) Events() chan Event {
} }
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
tun.close = true
tun.session.End()
var err error var err error
if tun.wt != nil { tun.closeOnce.Do(func() {
_, err = tun.wt.Delete(false) tun.close = true
} tun.session.End()
close(tun.events) if tun.wt != nil {
_, err = tun.wt.Delete(false)
}
close(tun.events)
})
return err return err
} }