diff --git a/tun/tun_windows.go b/tun/tun_windows.go index 2305b72..ff16e2f 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -40,10 +40,10 @@ type NativeTun struct { session wintun.Session readWait windows.Handle events chan Event - closing sync.RWMutex + running sync.WaitGroup closeOnce sync.Once + close int32 forcedMTU int - close bool } var WintunPool, _ = wintun.MakePool("WireGuard") @@ -111,9 +111,9 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu } func (tun *NativeTun) Name() (string, error) { - tun.closing.RLock() - defer tun.closing.RUnlock() - if tun.close { + tun.running.Add(1) + defer tun.running.Done() + if atomic.LoadInt32(&tun.close) == 1 { return "", os.ErrClosed } return tun.wt.Name() @@ -130,9 +130,9 @@ func (tun *NativeTun) Events() chan Event { func (tun *NativeTun) Close() error { var err error tun.closeOnce.Do(func() { - tun.closing.Lock() - defer tun.closing.Unlock() - tun.close = true + atomic.StoreInt32(&tun.close, 1) + windows.SetEvent(tun.readWait) + tun.running.Wait() tun.session.End() if tun.wt != nil { _, err = tun.wt.Delete(false) @@ -158,16 +158,16 @@ func (tun *NativeTun) ForceMTU(mtu int) { // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { - tun.closing.RLock() - defer tun.closing.RUnlock() + tun.running.Add(1) + defer tun.running.Done() retry: - if tun.close { + if atomic.LoadInt32(&tun.close) == 1 { return 0, os.ErrClosed } start := nanotime() shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2 for { - if tun.close { + if atomic.LoadInt32(&tun.close) == 1 { return 0, os.ErrClosed } packet, err := tun.session.ReceivePacket() @@ -199,9 +199,9 @@ func (tun *NativeTun) Flush() error { } func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { - tun.closing.RLock() - defer tun.closing.RUnlock() - if tun.close { + tun.running.Add(1) + defer tun.running.Done() + if atomic.LoadInt32(&tun.close) == 1 { return 0, os.ErrClosed } @@ -225,9 +225,9 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { // LUID returns Windows interface instance ID. func (tun *NativeTun) LUID() uint64 { - tun.closing.RLock() - defer tun.closing.RUnlock() - if tun.close { + tun.running.Add(1) + defer tun.running.Done() + if atomic.LoadInt32(&tun.close) == 1 { return 0 } return tun.wt.LUID()