tun: windows: protect reads from closing

The code previously used the old errors channel for checking, rather
than the simpler boolean, which caused issues on shutdown, since the
errors channel was meaningless. However, looking at this exposed a more
basic problem: Close() and all the other functions that check the closed
boolean can race. So protect with a basic RW lock, to ensure that
Close() waits for all pending operations to complete.

Reported-by: Joshua Sjoding <joshua.sjoding@scjalliance.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2021-04-26 22:22:45 -04:00
parent 8246d251ea
commit 097af6e135
1 changed files with 19 additions and 6 deletions

View File

@ -37,8 +37,8 @@ type NativeTun struct {
wt *wintun.Adapter wt *wintun.Adapter
handle windows.Handle handle windows.Handle
close bool close bool
closing sync.RWMutex
events chan Event events chan Event
errors chan error
forcedMTU int forcedMTU int
rate rateJuggler rate rateJuggler
session wintun.Session session wintun.Session
@ -97,7 +97,6 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
wt: wt, wt: wt,
handle: windows.InvalidHandle, handle: windows.InvalidHandle,
events: make(chan Event, 10), events: make(chan Event, 10),
errors: make(chan error, 1),
forcedMTU: forcedMTU, forcedMTU: forcedMTU,
} }
@ -112,6 +111,11 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
} }
func (tun *NativeTun) Name() (string, error) { func (tun *NativeTun) Name() (string, error) {
tun.closing.RLock()
defer tun.closing.RUnlock()
if tun.close {
return "", os.ErrClosed
}
return tun.wt.Name() return tun.wt.Name()
} }
@ -126,6 +130,8 @@ func (tun *NativeTun) Events() chan Event {
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
var err error var err error
tun.closeOnce.Do(func() { tun.closeOnce.Do(func() {
tun.closing.Lock()
defer tun.closing.Unlock()
tun.close = true tun.close = true
tun.session.End() tun.session.End()
if tun.wt != nil { if tun.wt != nil {
@ -148,11 +154,11 @@ func (tun *NativeTun) ForceMTU(mtu int) {
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
tun.closing.RLock()
defer tun.closing.RUnlock()
retry: retry:
select { if tun.close {
case err := <-tun.errors: return 0, os.ErrClosed
return 0, err
default:
} }
start := nanotime() start := nanotime()
shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2 shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
@ -189,6 +195,8 @@ func (tun *NativeTun) Flush() error {
} }
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
tun.closing.RLock()
defer tun.closing.RUnlock()
if tun.close { if tun.close {
return 0, os.ErrClosed return 0, os.ErrClosed
} }
@ -213,6 +221,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
// LUID returns Windows interface instance ID. // LUID returns Windows interface instance ID.
func (tun *NativeTun) LUID() uint64 { func (tun *NativeTun) LUID() uint64 {
tun.closing.RLock()
defer tun.closing.RUnlock()
if tun.close {
return 0
}
return tun.wt.LUID() return tun.wt.LUID()
} }