wintun: add more retry loops

This commit is contained in:
Jason A. Donenfeld 2019-03-31 10:17:11 +02:00
parent 2e0ed4614a
commit 92f8474832
3 changed files with 56 additions and 30 deletions

View file

@ -75,18 +75,11 @@ func CreateTUN(ifname string) (TUNDevice, error) {
return nil, err return nil, err
} }
go func() { err = wt.SetInterfaceName(ifname)
retries := retryTimeout * retryRate if err != nil {
for { wt.DeleteInterface(0)
err := wt.SetInterfaceName(ifname) return nil, err
if err != nil && retries > 0 {
time.Sleep(time.Second / retryRate)
retries--
continue
} }
return
}
}()
err = wt.FlushInterface() err = wt.FlushInterface()
if err != nil { if err != nil {

View file

@ -0,0 +1,42 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"golang.org/x/sys/windows/registry"
"time"
)
const (
numRetries = 25
retryTimeout = 100 * time.Millisecond
)
func registryOpenKeyRetry(k registry.Key, path string, access uint32) (key registry.Key, err error) {
for i := 0; i < numRetries; i++ {
key, err = registry.OpenKey(k, path, access)
if err == nil {
break
}
if i != numRetries - 1 {
time.Sleep(retryTimeout)
}
}
return
}
func keyGetStringValueRetry(k registry.Key, name string) (val string, valtype uint32, err error) {
for i := 0; i < numRetries; i++ {
val, valtype, err = k.GetStringValue(name)
if err == nil {
break
}
if i != numRetries - 1 {
time.Sleep(retryTimeout)
}
}
return
}

View file

@ -48,22 +48,14 @@ func MakeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfo
var valueStr string var valueStr string
var valueType uint32 var valueType uint32
//TODO: Figure out a way to not need to loop like this.
for i := 0; i < 30; i++ {
// Read the NetCfgInstanceId value. // Read the NetCfgInstanceId value.
valueStr, valueType, err = key.GetStringValue("NetCfgInstanceId") valueStr, valueType, err = keyGetStringValueRetry(key, "NetCfgInstanceId")
if err != nil { if err != nil {
time.Sleep(time.Millisecond * 100) return nil, errors.New("RegQueryStringValue(\"NetCfgInstanceId\") failed: " + err.Error())
continue
} }
if valueType != registry.SZ { if valueType != registry.SZ {
return nil, fmt.Errorf("NetCfgInstanceId registry value is not REG_SZ (expected: %v, provided: %v)", registry.SZ, valueType) return nil, fmt.Errorf("NetCfgInstanceId registry value is not REG_SZ (expected: %v, provided: %v)", registry.SZ, valueType)
} }
break
}
if err != nil {
return nil, errors.New("RegQueryStringValue(\"NetCfgInstanceId\") failed: " + err.Error())
}
// Convert to windows.GUID. // Convert to windows.GUID.
ifid, err := guid.FromString(valueStr) ifid, err := guid.FromString(valueStr)
@ -117,7 +109,6 @@ func GetInterface(ifname string, hwndParent uintptr) (*Wintun, error) {
// "foobar" would cause conflict with "FooBar". // "foobar" would cause conflict with "FooBar".
ifname = strings.ToLower(ifname) ifname = strings.ToLower(ifname)
// Iterate.
for index := 0; ; index++ { for index := 0; ; index++ {
// Get the device from the list. Should anything be wrong with this device, continue with next. // Get the device from the list. Should anything be wrong with this device, continue with next.
deviceData, err := devInfoList.EnumDeviceInfo(index) deviceData, err := devInfoList.EnumDeviceInfo(index)
@ -174,7 +165,7 @@ func GetInterface(ifname string, hwndParent uintptr) (*Wintun, error) {
} }
// This interface is not using Wintun driver. // This interface is not using Wintun driver.
return wintun, errors.New("Foreign network interface with the same name exists") return nil, errors.New("Foreign network interface with the same name exists")
} }
} }
@ -444,7 +435,7 @@ func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInf
// GetInterfaceName returns network interface name. // GetInterfaceName returns network interface name.
// //
func (wintun *Wintun) GetInterfaceName() (string, error) { func (wintun *Wintun) GetInterfaceName() (string, error) {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.QUERY_VALUE) key, err := registryOpenKeyRetry(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.QUERY_VALUE)
if err != nil { if err != nil {
return "", errors.New("Network-specific registry key open failed: " + err.Error()) return "", errors.New("Network-specific registry key open failed: " + err.Error())
} }
@ -458,7 +449,7 @@ func (wintun *Wintun) GetInterfaceName() (string, error) {
// SetInterfaceName sets network interface name. // SetInterfaceName sets network interface name.
// //
func (wintun *Wintun) SetInterfaceName(ifname string) error { func (wintun *Wintun) SetInterfaceName(ifname string) error {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.SET_VALUE) key, err := registryOpenKeyRetry(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.SET_VALUE)
if err != nil { if err != nil {
return errors.New("Network-specific registry key open failed: " + err.Error()) return errors.New("Network-specific registry key open failed: " + err.Error())
} }
@ -483,7 +474,7 @@ func (wintun *Wintun) GetNetRegKeyName() string {
// //
func getRegStringValue(key registry.Key, name string) (string, error) { func getRegStringValue(key registry.Key, name string) (string, error) {
// Read string value. // Read string value.
value, valueType, err := key.GetStringValue(name) value, valueType, err := keyGetStringValueRetry(key, name)
if err != nil { if err != nil {
return "", err return "", err
} }