From 0f4809f366daa77c6e2f5b09d3f05771fe9bf188 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 18 Feb 2021 14:53:22 -0800 Subject: [PATCH] tun: make NativeTun.Close well behaved, not crash on double close Signed-off-by: Brad Fitzpatrick --- tun/tun_darwin.go | 20 ++++++++++++-------- tun/tun_freebsd.go | 24 ++++++++++++++---------- tun/tun_linux.go | 23 +++++++++++++---------- tun/tun_openbsd.go | 22 +++++++++++++--------- tun/tun_windows.go | 16 ++++++++++------ 5 files changed, 62 insertions(+), 43 deletions(-) diff --git a/tun/tun_darwin.go b/tun/tun_darwin.go index 542f666..a703c8c 100644 --- a/tun/tun_darwin.go +++ b/tun/tun_darwin.go @@ -10,6 +10,7 @@ import ( "fmt" "net" "os" + "sync" "syscall" "time" "unsafe" @@ -26,6 +27,7 @@ type NativeTun struct { events chan Event errors chan error routeSocket int + closeOnce sync.Once } func retryInterfaceByIndex(index int) (iface *net.Interface, err error) { @@ -256,14 +258,16 @@ func (tun *NativeTun) Flush() error { } func (tun *NativeTun) Close() error { - var err2 error - err1 := tun.tunFile.Close() - if tun.routeSocket != -1 { - unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) - err2 = unix.Close(tun.routeSocket) - } else if tun.events != nil { - close(tun.events) - } + var err1, err2 error + tun.closeOnce.Do(func() { + err1 = tun.tunFile.Close() + if tun.routeSocket != -1 { + unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) + err2 = unix.Close(tun.routeSocket) + } else if tun.events != nil { + close(tun.events) + } + }) if err1 != nil { return err1 } diff --git a/tun/tun_freebsd.go b/tun/tun_freebsd.go index e0dc2e1..12b44da 100644 --- a/tun/tun_freebsd.go +++ b/tun/tun_freebsd.go @@ -11,6 +11,7 @@ import ( "fmt" "net" "os" + "sync" "syscall" "unsafe" @@ -82,6 +83,7 @@ type NativeTun struct { events chan Event errors chan error routeSocket int + closeOnce sync.Once } func (tun *NativeTun) routineRouteListener(tunIfindex int) { @@ -472,16 +474,18 @@ func (tun *NativeTun) Flush() error { } func (tun *NativeTun) Close() error { - var err3 error - err1 := tun.tunFile.Close() - err2 := tunDestroy(tun.name) - if tun.routeSocket != -1 { - unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) - err3 = unix.Close(tun.routeSocket) - tun.routeSocket = -1 - } else if tun.events != nil { - close(tun.events) - } + var err1, err2, err3 error + tun.closeOnce.Do(func() { + err1 = tun.tunFile.Close() + err2 = tunDestroy(tun.name) + if tun.routeSocket != -1 { + unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) + err3 = unix.Close(tun.routeSocket) + tun.routeSocket = -1 + } else if tun.events != nil { + close(tun.events) + } + }) if err1 != nil { return err1 } diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 501f3a3..e0c9878 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -39,6 +39,8 @@ type NativeTun struct { hackListenerClosed sync.Mutex statusListenersShutdown chan struct{} + closeOnce sync.Once + nameOnce sync.Once // guards calling initNameCache, which sets following fields nameCache string // name of interface nameErr error @@ -372,17 +374,18 @@ func (tun *NativeTun) Events() chan Event { } func (tun *NativeTun) Close() error { - var err1 error - if tun.statusListenersShutdown != nil { - close(tun.statusListenersShutdown) - if tun.netlinkCancel != nil { - err1 = tun.netlinkCancel.Cancel() + var err1, err2 error + tun.closeOnce.Do(func() { + if tun.statusListenersShutdown != nil { + close(tun.statusListenersShutdown) + if tun.netlinkCancel != nil { + err1 = tun.netlinkCancel.Cancel() + } + } else if tun.events != nil { + close(tun.events) } - } else if tun.events != nil { - close(tun.events) - } - err2 := tun.tunFile.Close() - + err2 = tun.tunFile.Close() + }) if err1 != nil { return err1 } diff --git a/tun/tun_openbsd.go b/tun/tun_openbsd.go index 8fca1e3..7ef62f4 100644 --- a/tun/tun_openbsd.go +++ b/tun/tun_openbsd.go @@ -10,6 +10,7 @@ import ( "fmt" "net" "os" + "sync" "syscall" "unsafe" @@ -32,6 +33,7 @@ type NativeTun struct { events chan Event errors chan error routeSocket int + closeOnce sync.Once } func (tun *NativeTun) routineRouteListener(tunIfindex int) { @@ -245,15 +247,17 @@ func (tun *NativeTun) Flush() error { } func (tun *NativeTun) Close() error { - var err2 error - err1 := tun.tunFile.Close() - if tun.routeSocket != -1 { - unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) - err2 = unix.Close(tun.routeSocket) - tun.routeSocket = -1 - } else if tun.events != nil { - close(tun.events) - } + var err1, err2 error + tun.closeOnce.Do(func() { + err1 = tun.tunFile.Close() + if tun.routeSocket != -1 { + unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) + err2 = unix.Close(tun.routeSocket) + tun.routeSocket = -1 + } else if tun.events != nil { + close(tun.events) + } + }) if err1 != nil { return err1 } diff --git a/tun/tun_windows.go b/tun/tun_windows.go index 081b5e2..9d83db7 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -10,6 +10,7 @@ import ( "fmt" "log" "os" + "sync" "sync/atomic" "time" _ "unsafe" @@ -42,6 +43,7 @@ type NativeTun struct { rate rateJuggler session wintun.Session readWait windows.Handle + closeOnce sync.Once } var WintunPool, _ = wintun.MakePool("WireGuard") @@ -122,13 +124,15 @@ func (tun *NativeTun) Events() chan Event { } func (tun *NativeTun) Close() error { - tun.close = true - tun.session.End() var err error - if tun.wt != nil { - _, err = tun.wt.Delete(false) - } - close(tun.events) + tun.closeOnce.Do(func() { + tun.close = true + tun.session.End() + if tun.wt != nil { + _, err = tun.wt.Delete(false) + } + close(tun.events) + }) return err }