diff --git a/tun/tun_windows.go b/tun/tun_windows.go index 9d83db7..c8b8d39 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -37,8 +37,8 @@ type NativeTun struct { wt *wintun.Adapter handle windows.Handle close bool + closing sync.RWMutex events chan Event - errors chan error forcedMTU int rate rateJuggler session wintun.Session @@ -97,7 +97,6 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu wt: wt, handle: windows.InvalidHandle, events: make(chan Event, 10), - errors: make(chan error, 1), forcedMTU: forcedMTU, } @@ -112,6 +111,11 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu } func (tun *NativeTun) Name() (string, error) { + tun.closing.RLock() + defer tun.closing.RUnlock() + if tun.close { + return "", os.ErrClosed + } return tun.wt.Name() } @@ -126,6 +130,8 @@ func (tun *NativeTun) Events() chan Event { func (tun *NativeTun) Close() error { var err error tun.closeOnce.Do(func() { + tun.closing.Lock() + defer tun.closing.Unlock() tun.close = true tun.session.End() if tun.wt != nil { @@ -148,11 +154,11 @@ func (tun *NativeTun) ForceMTU(mtu int) { // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { + tun.closing.RLock() + defer tun.closing.RUnlock() retry: - select { - case err := <-tun.errors: - return 0, err - default: + if tun.close { + return 0, os.ErrClosed } start := nanotime() shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2 @@ -189,6 +195,8 @@ func (tun *NativeTun) Flush() error { } func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { + tun.closing.RLock() + defer tun.closing.RUnlock() if tun.close { return 0, os.ErrClosed } @@ -213,6 +221,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { // LUID returns Windows interface instance ID. func (tun *NativeTun) LUID() uint64 { + tun.closing.RLock() + defer tun.closing.RUnlock() + if tun.close { + return 0 + } return tun.wt.LUID() }