wintun: add more retry loops

This commit is contained in:
Jason A. Donenfeld 2019-03-31 10:17:11 +02:00
parent 2e0ed4614a
commit 92f8474832
3 changed files with 56 additions and 30 deletions

View file

@ -75,18 +75,11 @@ func CreateTUN(ifname string) (TUNDevice, error) {
return nil, err
}
go func() {
retries := retryTimeout * retryRate
for {
err := wt.SetInterfaceName(ifname)
if err != nil && retries > 0 {
time.Sleep(time.Second / retryRate)
retries--
continue
}
return
}
}()
err = wt.SetInterfaceName(ifname)
if err != nil {
wt.DeleteInterface(0)
return nil, err
}
err = wt.FlushInterface()
if err != nil {

View file

@ -0,0 +1,42 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"golang.org/x/sys/windows/registry"
"time"
)
const (
numRetries = 25
retryTimeout = 100 * time.Millisecond
)
func registryOpenKeyRetry(k registry.Key, path string, access uint32) (key registry.Key, err error) {
for i := 0; i < numRetries; i++ {
key, err = registry.OpenKey(k, path, access)
if err == nil {
break
}
if i != numRetries - 1 {
time.Sleep(retryTimeout)
}
}
return
}
func keyGetStringValueRetry(k registry.Key, name string) (val string, valtype uint32, err error) {
for i := 0; i < numRetries; i++ {
val, valtype, err = k.GetStringValue(name)
if err == nil {
break
}
if i != numRetries - 1 {
time.Sleep(retryTimeout)
}
}
return
}

View file

@ -48,22 +48,14 @@ func MakeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfo
var valueStr string
var valueType uint32
//TODO: Figure out a way to not need to loop like this.
for i := 0; i < 30; i++ {
// Read the NetCfgInstanceId value.
valueStr, valueType, err = key.GetStringValue("NetCfgInstanceId")
if err != nil {
time.Sleep(time.Millisecond * 100)
continue
}
if valueType != registry.SZ {
return nil, fmt.Errorf("NetCfgInstanceId registry value is not REG_SZ (expected: %v, provided: %v)", registry.SZ, valueType)
}
break
}
// Read the NetCfgInstanceId value.
valueStr, valueType, err = keyGetStringValueRetry(key, "NetCfgInstanceId")
if err != nil {
return nil, errors.New("RegQueryStringValue(\"NetCfgInstanceId\") failed: " + err.Error())
}
if valueType != registry.SZ {
return nil, fmt.Errorf("NetCfgInstanceId registry value is not REG_SZ (expected: %v, provided: %v)", registry.SZ, valueType)
}
// Convert to windows.GUID.
ifid, err := guid.FromString(valueStr)
@ -117,7 +109,6 @@ func GetInterface(ifname string, hwndParent uintptr) (*Wintun, error) {
// "foobar" would cause conflict with "FooBar".
ifname = strings.ToLower(ifname)
// Iterate.
for index := 0; ; index++ {
// Get the device from the list. Should anything be wrong with this device, continue with next.
deviceData, err := devInfoList.EnumDeviceInfo(index)
@ -174,7 +165,7 @@ func GetInterface(ifname string, hwndParent uintptr) (*Wintun, error) {
}
// This interface is not using Wintun driver.
return wintun, errors.New("Foreign network interface with the same name exists")
return nil, errors.New("Foreign network interface with the same name exists")
}
}
@ -444,7 +435,7 @@ func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInf
// GetInterfaceName returns network interface name.
//
func (wintun *Wintun) GetInterfaceName() (string, error) {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.QUERY_VALUE)
key, err := registryOpenKeyRetry(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.QUERY_VALUE)
if err != nil {
return "", errors.New("Network-specific registry key open failed: " + err.Error())
}
@ -458,7 +449,7 @@ func (wintun *Wintun) GetInterfaceName() (string, error) {
// SetInterfaceName sets network interface name.
//
func (wintun *Wintun) SetInterfaceName(ifname string) error {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.SET_VALUE)
key, err := registryOpenKeyRetry(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.SET_VALUE)
if err != nil {
return errors.New("Network-specific registry key open failed: " + err.Error())
}
@ -483,7 +474,7 @@ func (wintun *Wintun) GetNetRegKeyName() string {
//
func getRegStringValue(key registry.Key, name string) (string, error) {
// Read string value.
value, valueType, err := key.GetStringValue(name)
value, valueType, err := keyGetStringValueRetry(key, name)
if err != nil {
return "", err
}