From 11f57802506045a137d0e2022bfb16da4fc624f1 Mon Sep 17 00:00:00 2001
From: Simon Rozman <simon@rozman.si>
Date: Thu, 7 Mar 2019 15:19:27 +0100
Subject: [PATCH] wintun: Revise interface creation wait

DIF_INSTALLDEVICE returns almost immediately, while the device
installation continues in the background. It might take a while, before
all registry keys and values are populated.

Previously, wireguard-go waited for HKLM\SYSTEM\CurrentControlSet\
Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}\<id> registry key
only.

Followed by a SetInterfaceName() method of Wintun struct which tried to
access HKLM\SYSTEM\CurrentControlSet\Control\Network\
{4D36E972-E325-11CE-BFC1-08002BE10318}\<id>\Connection registry key
might not be available yet.

This commit loops until both registry keys are available before
returning from CreateInterface() function.

Signed-off-by: Simon Rozman <simon@rozman.si>
---
 tun/wintun/setupapi/setupapi_windows.go      |  29 ++++
 tun/wintun/setupapi/setupapi_windows_test.go |   5 +
 tun/wintun/wintun_windows.go                 | 131 ++++++++++---------
 3 files changed, 106 insertions(+), 59 deletions(-)

diff --git a/tun/wintun/setupapi/setupapi_windows.go b/tun/wintun/setupapi/setupapi_windows.go
index 71732a4..5f9e05c 100644
--- a/tun/wintun/setupapi/setupapi_windows.go
+++ b/tun/wintun/setupapi/setupapi_windows.go
@@ -7,12 +7,14 @@ package setupapi
 
 import (
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"syscall"
 	"unsafe"
 
 	"golang.org/x/sys/windows"
 	"golang.org/x/sys/windows/registry"
+	"golang.zx2c4.com/wireguard/tun/wintun/guid"
 )
 
 //sys	setupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr, machineName *uint16, reserved uintptr) (handle DevInfo, err error) [failretval==DevInfo(windows.InvalidHandle)] = setupapi.SetupDiCreateDeviceInfoListExW
@@ -234,6 +236,33 @@ func (deviceInfoSet DevInfo) OpenDevRegKey(DeviceInfoData *DevInfoData, Scope DI
 	return SetupDiOpenDevRegKey(deviceInfoSet, DeviceInfoData, Scope, HwProfile, KeyType, samDesired)
 }
 
+// GetInterfaceID method returns network interface ID.
+func (deviceInfoSet DevInfo) GetInterfaceID(deviceInfoData *DevInfoData) (*windows.GUID, error) {
+	// Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key.
+	key, err := deviceInfoSet.OpenDevRegKey(deviceInfoData, DICS_FLAG_GLOBAL, 0, DIREG_DRV, registry.READ)
+	if err != nil {
+		return nil, errors.New("Device-specific registry key open failed: " + err.Error())
+	}
+	defer key.Close()
+
+	// Read the NetCfgInstanceId value.
+	value, valueType, err := key.GetStringValue("NetCfgInstanceId")
+	if err != nil {
+		return nil, errors.New("RegQueryStringValue(\"NetCfgInstanceId\") failed: " + err.Error())
+	}
+	if valueType != registry.SZ {
+		return nil, fmt.Errorf("NetCfgInstanceId registry value is not REG_SZ (expected: %v, provided: %v)", registry.SZ, valueType)
+	}
+
+	// Convert to windows.GUID.
+	ifid, err := guid.FromString(value)
+	if err != nil {
+		return nil, fmt.Errorf("NetCfgInstanceId registry value is not a GUID (expected: \"{...}\", provided: \"%v\")", value)
+	}
+
+	return ifid, nil
+}
+
 //sys	setupDiGetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP, propertyRegDataType *uint32, propertyBuffer *byte, propertyBufferSize uint32, requiredSize *uint32) (err error) = setupapi.SetupDiGetDeviceRegistryPropertyW
 
 // SetupDiGetDeviceRegistryProperty function retrieves a specified Plug and Play device property.
diff --git a/tun/wintun/setupapi/setupapi_windows_test.go b/tun/wintun/setupapi/setupapi_windows_test.go
index 30f3692..c6f4a15 100644
--- a/tun/wintun/setupapi/setupapi_windows_test.go
+++ b/tun/wintun/setupapi/setupapi_windows_test.go
@@ -291,6 +291,11 @@ func TestSetupDiOpenDevRegKey(t *testing.T) {
 			t.Errorf("Error calling SetupDiOpenDevRegKey: %s", err.Error())
 		}
 		defer key.Close()
+
+		_, err = devInfoList.GetInterfaceID(data)
+		if err != nil {
+			t.Errorf("Error calling GetInterfaceID: %s", err.Error())
+		}
 	}
 }
 
diff --git a/tun/wintun/wintun_windows.go b/tun/wintun/wintun_windows.go
index 85d29f4..69fd30c 100644
--- a/tun/wintun/wintun_windows.go
+++ b/tun/wintun/wintun_windows.go
@@ -58,27 +58,24 @@ func GetInterface(ifname string, hwndParent uintptr) (*Wintun, error) {
 
 	// Iterate.
 	for index := 0; ; index++ {
-		// Get the device from the list.
+		// Get the device from the list. Should anything be wrong with this device, continue with next.
 		deviceData, err := devInfoList.EnumDeviceInfo(index)
 		if err != nil {
 			if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ {
 				break
 			}
-			// Something is wrong with this device. Skip it.
 			continue
 		}
 
 		// Get interface ID.
-		ifid, err := getInterfaceID(devInfoList, deviceData, 1)
+		ifid, err := devInfoList.GetInterfaceID(deviceData)
 		if err != nil {
-			// Something is wrong with this device. Skip it.
 			continue
 		}
 
 		// Get interface name.
 		ifname2, err := ((*Wintun)(ifid)).GetInterfaceName()
 		if err != nil {
-			// Something is wrong with this device. Skip it.
 			continue
 		}
 
@@ -243,8 +240,74 @@ func CreateInterface(description string, hwndParent uintptr) (*Wintun, bool, err
 			rebootRequired = true
 		}
 
-		// Get network interface ID from registry. Retry for max 30sec.
-		ifid, err = getInterfaceID(devInfoList, deviceData, 30)
+		// Get network interface ID from registry. DIF_INSTALLDEVICE returns almost immediately,
+		// while the device installation continues in the background. It might take a while, before
+		// all registry keys and values are populated.
+		getInterfaceID := func() (*windows.GUID, error) {
+			// Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key.
+			keyDev, err := devInfoList.OpenDevRegKey(deviceData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.READ)
+			if err != nil {
+				return nil, errors.New("Device-specific registry key open failed: " + err.Error())
+			}
+			defer keyDev.Close()
+
+			// Read the NetCfgInstanceId value.
+			value, err := getRegStringValue(keyDev, "NetCfgInstanceId")
+			if err != nil {
+				if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_FILE_NOT_FOUND {
+					return nil, err
+				}
+
+				return nil, errors.New("RegQueryStringValue(\"NetCfgInstanceId\") failed: " + err.Error())
+			}
+
+			// Convert to windows.GUID.
+			ifid, err := guid.FromString(value)
+			if err != nil {
+				return nil, fmt.Errorf("NetCfgInstanceId registry value is not a GUID (expected: \"{...}\", provided: \"%v\")", value)
+			}
+
+			keyNetName := fmt.Sprintf("SYSTEM\\CurrentControlSet\\Control\\Network\\%v\\%v\\Connection", guid.ToString(&deviceClassNetGUID), value)
+			keyNet, err := registry.OpenKey(registry.LOCAL_MACHINE, keyNetName, registry.QUERY_VALUE)
+			if err != nil {
+				if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_FILE_NOT_FOUND {
+					return nil, err
+				}
+
+				return nil, errors.New(fmt.Sprintf("RegOpenKeyEx(\"%v\") failed: ", keyNetName) + err.Error())
+			}
+			defer keyNet.Close()
+
+			// Query the interface name.
+			_, valueType, err := keyNet.GetValue("Name", nil)
+			if err != nil {
+				if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_FILE_NOT_FOUND {
+					return nil, err
+				}
+
+				return nil, errors.New("RegQueryValueEx(\"Name\") failed: " + err.Error())
+			}
+			switch valueType {
+			case registry.SZ, registry.EXPAND_SZ:
+			default:
+				return nil, fmt.Errorf("Interface name registry value is not REG_SZ or REG_EXPAND_SZ (expected: %v or %v, provided: %v)", registry.SZ, registry.EXPAND_SZ, valueType)
+			}
+
+			// TUN interface is ready. (As far as we need it.)
+			return ifid, nil
+		}
+		for numAttempts := 0; numAttempts < 30; numAttempts++ {
+			ifid, err = getInterfaceID()
+			if err != nil {
+				if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_FILE_NOT_FOUND {
+					// Wait and retry. TODO: Wait for a cancellable event instead.
+					time.Sleep(1000 * time.Millisecond)
+					continue
+				}
+			}
+
+			break
+		}
 	}
 
 	if err == nil {
@@ -294,20 +357,18 @@ func (wintun *Wintun) DeleteInterface(hwndParent uintptr) (bool, bool, error) {
 
 	// Iterate.
 	for index := 0; ; index++ {
-		// Get the device from the list.
+		// Get the device from the list. Should anything be wrong with this device, continue with next.
 		deviceData, err := devInfoList.EnumDeviceInfo(index)
 		if err != nil {
 			if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ {
 				break
 			}
-			// Something is wrong with this device. Skip it.
 			continue
 		}
 
 		// Get interface ID.
-		ifid2, err := getInterfaceID(devInfoList, deviceData, 1)
+		ifid2, err := devInfoList.GetInterfaceID(deviceData)
 		if err != nil {
-			// Something is wrong with this device. Skip it.
 			continue
 		}
 
@@ -367,54 +428,6 @@ func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInf
 	return false, nil
 }
 
-// getInterfaceID returns network interface ID.
-//
-// After the device is created, it might take some time before the registry
-// key is populated. numAttempts parameter specifies the number of attempts
-// to read NetCfgInstanceId value from registry. A 1sec sleep is inserted
-// between retry attempts.
-//
-// Function returns the network interface ID.
-//
-func getInterfaceID(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData, numAttempts int) (*windows.GUID, error) {
-	if numAttempts < 1 {
-		return nil, fmt.Errorf("Invalid numAttempts (expected: >=1, provided: %v)", numAttempts)
-	}
-
-	// Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key.
-	key, err := deviceInfoSet.OpenDevRegKey(deviceInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.READ)
-	if err != nil {
-		return nil, errors.New("Device-specific registry key open failed: " + err.Error())
-	}
-	defer key.Close()
-
-	for {
-		// Read the NetCfgInstanceId value.
-		value, err := getRegStringValue(key, "NetCfgInstanceId")
-		if err != nil {
-			if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_FILE_NOT_FOUND {
-				numAttempts--
-				if numAttempts > 0 {
-					// Wait and retry.
-					// TODO: Wait for a cancellable event instead.
-					time.Sleep(1000 * time.Millisecond)
-					continue
-				}
-			}
-
-			return nil, errors.New("RegQueryStringValue(\"NetCfgInstanceId\") failed: " + err.Error())
-		}
-
-		// Convert to windows.GUID.
-		ifid, err := guid.FromString(value)
-		if err != nil {
-			return nil, fmt.Errorf("NetCfgInstanceId registry value is not a GUID (expected: \"{...}\", provided: \"%v\")", value)
-		}
-
-		return ifid, err
-	}
-}
-
 //
 // GetInterfaceName returns network interface name.
 //