tun: make NativeTun.Close well behaved, not crash on double close

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2021-02-18 14:53:22 -08:00 committed by Jason A. Donenfeld
parent fecb8f482a
commit 0f4809f366
5 changed files with 62 additions and 43 deletions

View file

@ -10,6 +10,7 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"sync"
"syscall" "syscall"
"time" "time"
"unsafe" "unsafe"
@ -26,6 +27,7 @@ type NativeTun struct {
events chan Event events chan Event
errors chan error errors chan error
routeSocket int routeSocket int
closeOnce sync.Once
} }
func retryInterfaceByIndex(index int) (iface *net.Interface, err error) { func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
@ -256,14 +258,16 @@ func (tun *NativeTun) Flush() error {
} }
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
var err2 error var err1, err2 error
err1 := tun.tunFile.Close() tun.closeOnce.Do(func() {
err1 = tun.tunFile.Close()
if tun.routeSocket != -1 { if tun.routeSocket != -1 {
unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
err2 = unix.Close(tun.routeSocket) err2 = unix.Close(tun.routeSocket)
} else if tun.events != nil { } else if tun.events != nil {
close(tun.events) close(tun.events)
} }
})
if err1 != nil { if err1 != nil {
return err1 return err1
} }

View file

@ -11,6 +11,7 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"sync"
"syscall" "syscall"
"unsafe" "unsafe"
@ -82,6 +83,7 @@ type NativeTun struct {
events chan Event events chan Event
errors chan error errors chan error
routeSocket int routeSocket int
closeOnce sync.Once
} }
func (tun *NativeTun) routineRouteListener(tunIfindex int) { func (tun *NativeTun) routineRouteListener(tunIfindex int) {
@ -472,9 +474,10 @@ func (tun *NativeTun) Flush() error {
} }
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
var err3 error var err1, err2, err3 error
err1 := tun.tunFile.Close() tun.closeOnce.Do(func() {
err2 := tunDestroy(tun.name) err1 = tun.tunFile.Close()
err2 = tunDestroy(tun.name)
if tun.routeSocket != -1 { if tun.routeSocket != -1 {
unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
err3 = unix.Close(tun.routeSocket) err3 = unix.Close(tun.routeSocket)
@ -482,6 +485,7 @@ func (tun *NativeTun) Close() error {
} else if tun.events != nil { } else if tun.events != nil {
close(tun.events) close(tun.events)
} }
})
if err1 != nil { if err1 != nil {
return err1 return err1
} }

View file

@ -39,6 +39,8 @@ type NativeTun struct {
hackListenerClosed sync.Mutex hackListenerClosed sync.Mutex
statusListenersShutdown chan struct{} statusListenersShutdown chan struct{}
closeOnce sync.Once
nameOnce sync.Once // guards calling initNameCache, which sets following fields nameOnce sync.Once // guards calling initNameCache, which sets following fields
nameCache string // name of interface nameCache string // name of interface
nameErr error nameErr error
@ -372,7 +374,8 @@ func (tun *NativeTun) Events() chan Event {
} }
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
var err1 error var err1, err2 error
tun.closeOnce.Do(func() {
if tun.statusListenersShutdown != nil { if tun.statusListenersShutdown != nil {
close(tun.statusListenersShutdown) close(tun.statusListenersShutdown)
if tun.netlinkCancel != nil { if tun.netlinkCancel != nil {
@ -381,8 +384,8 @@ func (tun *NativeTun) Close() error {
} else if tun.events != nil { } else if tun.events != nil {
close(tun.events) close(tun.events)
} }
err2 := tun.tunFile.Close() err2 = tun.tunFile.Close()
})
if err1 != nil { if err1 != nil {
return err1 return err1
} }

View file

@ -10,6 +10,7 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"sync"
"syscall" "syscall"
"unsafe" "unsafe"
@ -32,6 +33,7 @@ type NativeTun struct {
events chan Event events chan Event
errors chan error errors chan error
routeSocket int routeSocket int
closeOnce sync.Once
} }
func (tun *NativeTun) routineRouteListener(tunIfindex int) { func (tun *NativeTun) routineRouteListener(tunIfindex int) {
@ -245,8 +247,9 @@ func (tun *NativeTun) Flush() error {
} }
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
var err2 error var err1, err2 error
err1 := tun.tunFile.Close() tun.closeOnce.Do(func() {
err1 = tun.tunFile.Close()
if tun.routeSocket != -1 { if tun.routeSocket != -1 {
unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
err2 = unix.Close(tun.routeSocket) err2 = unix.Close(tun.routeSocket)
@ -254,6 +257,7 @@ func (tun *NativeTun) Close() error {
} else if tun.events != nil { } else if tun.events != nil {
close(tun.events) close(tun.events)
} }
})
if err1 != nil { if err1 != nil {
return err1 return err1
} }

View file

@ -10,6 +10,7 @@ import (
"fmt" "fmt"
"log" "log"
"os" "os"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
_ "unsafe" _ "unsafe"
@ -42,6 +43,7 @@ type NativeTun struct {
rate rateJuggler rate rateJuggler
session wintun.Session session wintun.Session
readWait windows.Handle readWait windows.Handle
closeOnce sync.Once
} }
var WintunPool, _ = wintun.MakePool("WireGuard") var WintunPool, _ = wintun.MakePool("WireGuard")
@ -122,13 +124,15 @@ func (tun *NativeTun) Events() chan Event {
} }
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
var err error
tun.closeOnce.Do(func() {
tun.close = true tun.close = true
tun.session.End() tun.session.End()
var err error
if tun.wt != nil { if tun.wt != nil {
_, err = tun.wt.Delete(false) _, err = tun.wt.Delete(false)
} }
close(tun.events) close(tun.events)
})
return err return err
} }