diff --git a/tun/wintun/wintun_windows.go b/tun/wintun/wintun_windows.go index 7550263..55d4ed0 100644 --- a/tun/wintun/wintun_windows.go +++ b/tun/wintun/wintun_windows.go @@ -262,14 +262,23 @@ func CreateInterface(description string, requestedGUID *windows.GUID) (wintun *W devInfoList.CallClassInstaller(setupapi.DIF_REGISTER_COINSTALLERS, deviceData) var key registry.Key - if requestedGUID != nil { - key, err = devInfoList.OpenDevRegKey(deviceData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.SET_VALUE) - if err != nil { - err = fmt.Errorf("OpenDevRegKey failed: %v", err) - return + const pollTimeout = time.Millisecond * 50 + for i := 0; i < int(waitForRegistryTimeout/pollTimeout); i++ { + if i != 0 { + time.Sleep(pollTimeout) } + key, err = devInfoList.OpenDevRegKey(deviceData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.SET_VALUE|registry.QUERY_VALUE|registry.NOTIFY) + if err == nil { + break + } + } + if err != nil { + err = fmt.Errorf("SetupDiOpenDevRegKey failed: %v", err) + return + } + defer key.Close() + if requestedGUID != nil { err = key.SetStringValue("NetSetupAnticipatedInstanceId", requestedGUID.String()) - key.Close() if err != nil { err = fmt.Errorf("SetStringValue(NetSetupAnticipatedInstanceId) failed: %v", err) return @@ -309,21 +318,6 @@ func CreateInterface(description string, requestedGUID *windows.GUID) (wintun *W // DIF_INSTALLDEVICE returns almost immediately, while the device installation // continues in the background. It might take a while, before all registry // keys and values are populated. - const pollTimeout = time.Millisecond * 50 - for i := 0; i < int(waitForRegistryTimeout/pollTimeout); i++ { - if i != 0 { - time.Sleep(pollTimeout) - } - key, err = devInfoList.OpenDevRegKey(deviceData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.QUERY_VALUE|registry.NOTIFY) - if err == nil { - break - } - } - if err != nil { - err = fmt.Errorf("SetupDiOpenDevRegKey failed: %v", err) - return - } - defer key.Close() _, err = registryEx.GetStringValueWait(key, "NetCfgInstanceId", waitForRegistryTimeout) if err != nil { err = fmt.Errorf("GetStringValueWait(NetCfgInstanceId) failed: %v", err)