diff --git a/tun/tun_windows.go b/tun/tun_windows.go index aa54ad5..c34204c 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -71,6 +71,8 @@ type NativeTun struct { rate rateJuggler } +const WintunPool = wintun.Pool("WireGuard") + //go:linkname procyield runtime.procyield func procyield(cycles uint32) @@ -98,22 +100,20 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Dev var wt *wintun.Wintun // Does an interface with this name already exist? - wt, err = wintun.GetInterface(ifname) + wt, err = WintunPool.GetInterface(ifname) if err == nil { // If so, we delete it, in case it has weird residual configuration. _, err = wt.DeleteInterface() if err != nil { return nil, fmt.Errorf("Unable to delete already existing Wintun interface: %v", err) } - } else if err == windows.ERROR_ALREADY_EXISTS { - return nil, fmt.Errorf("Foreign network interface with the same name exists") } - wt, _, err = wintun.CreateInterface(requestedGUID) + wt, _, err = WintunPool.CreateInterface(requestedGUID) if err != nil { return nil, fmt.Errorf("Unable to create Wintun interface: %v", err) } - err = wt.SetInterfaceName(ifname) + err = wt.SetInterfaceName(ifname, WintunPool) if err != nil { wt.DeleteInterface() return nil, fmt.Errorf("Unable to set name of Wintun interface: %v", err) diff --git a/tun/wintun/wintun_windows.go b/tun/wintun/wintun_windows.go index 5624525..857aceb 100644 --- a/tun/wintun/wintun_windows.go +++ b/tun/wintun/wintun_windows.go @@ -21,6 +21,8 @@ import ( "golang.zx2c4.com/wireguard/tun/wintun/setupapi" ) +type Pool string + // Wintun is a handle of a Wintun adapter. type Wintun struct { cfgInstanceID windows.GUID @@ -33,7 +35,6 @@ var deviceClassNetGUID = windows.GUID{Data1: 0x4d36e972, Data2: 0xe325, Data3: 0 var deviceInterfaceNetGUID = windows.GUID{Data1: 0xcac88484, Data2: 0x7515, Data3: 0x4c03, Data4: [8]byte{0x82, 0xe6, 0x71, 0xa8, 0x7a, 0xba, 0xc3, 0x61}} const ( - deviceTypeName = "WireGuard Tunnel" hardwareID = "Wintun" waitForRegistryTimeout = time.Second * 10 ) @@ -94,9 +95,9 @@ func removeNumberedSuffix(ifname string) string { // GetInterface finds a Wintun interface by its name. This function returns // the interface if found, or windows.ERROR_OBJECT_NOT_FOUND otherwise. If -// the interface is found but not a Wintun-class, this function returns -// windows.ERROR_ALREADY_EXISTS. -func GetInterface(ifname string) (*Wintun, error) { +// the interface is found but not a Wintun-class or a member of the pool, +// this function returns windows.ERROR_ALREADY_EXISTS. +func (pool Pool) GetInterface(ifname string) (*Wintun, error) { // Create a list of network devices. devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "") if err != nil { @@ -156,6 +157,14 @@ func GetInterface(ifname string) (*Wintun, error) { } if driverDetailData.IsCompatible(hardwareID) { + isMember, err := pool.isMember(devInfoList, deviceData) + if err != nil { + return nil, err + } + if !isMember { + return nil, windows.ERROR_ALREADY_EXISTS + } + return wintun, nil } } @@ -175,7 +184,7 @@ func GetInterface(ifname string) (*Wintun, error) { // called "requested" GUID because the API it uses is completely undocumented, // and so there could be minor interesting complications with its usage. This // function returns the network interface ID and a flag if reboot is required. -func CreateInterface(requestedGUID *windows.GUID) (wintun *Wintun, rebootRequired bool, err error) { +func (pool Pool) CreateInterface(requestedGUID *windows.GUID) (wintun *Wintun, rebootRequired bool, err error) { // Create an empty device info set for network adapter device class. devInfoList, err := setupapi.SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "") if err != nil { @@ -192,6 +201,7 @@ func CreateInterface(requestedGUID *windows.GUID) (wintun *Wintun, rebootRequire } // Create a new device info element and add it to the device info set. + deviceTypeName := pool.DeviceTypeName() deviceData, err := devInfoList.CreateDeviceInfo(className, &deviceClassNetGUID, deviceTypeName, 0, setupapi.DICD_GENERATE_ID) if err != nil { err = fmt.Errorf("SetupDiCreateDeviceInfo failed: %v", err) @@ -432,7 +442,7 @@ func (wintun *Wintun) DeleteInterface() (rebootRequired bool, err error) { // DeleteMatchingInterfaces deletes all Wintun interfaces, which match // given criteria, and returns which ones it deleted, whether a reboot // is required after, and which errors occurred during the process. -func DeleteMatchingInterfaces(matches func(wintun *Wintun) bool) (deviceInstancesDeleted []uint32, rebootRequired bool, errors []error) { +func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Wintun) bool) (deviceInstancesDeleted []uint32, rebootRequired bool, errors []error) { devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "") if err != nil { return nil, false, []error{fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())} @@ -476,20 +486,12 @@ func DeleteMatchingInterfaces(matches func(wintun *Wintun) bool) (deviceInstance continue } - deviceDescVal, err := devInfoList.DeviceRegistryProperty(deviceData, setupapi.SPDRP_DEVICEDESC) + isMember, err := pool.isMember(devInfoList, deviceData) if err != nil { - errors = append(errors, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err)) + errors = append(errors, err) continue } - deviceDesc, _ := deviceDescVal.(string) - friendlyNameVal, err := devInfoList.DeviceRegistryProperty(deviceData, setupapi.SPDRP_FRIENDLYNAME) - if err != nil { - errors = append(errors, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_FRIENDLYNAME) failed: %v", err)) - continue - } - friendlyName, _ := friendlyNameVal.(string) - if friendlyName != deviceTypeName && deviceDesc != deviceTypeName && - removeNumberedSuffix(friendlyName) != deviceTypeName && removeNumberedSuffix(deviceDesc) != deviceTypeName { + if !isMember { continue } @@ -532,12 +534,29 @@ func DeleteMatchingInterfaces(matches func(wintun *Wintun) bool) (deviceInstance // DeleteAllInterfaces deletes all Wintun interfaces, and returns which // ones it deleted, whether a reboot is required after, and which errors // occurred during the process. -func DeleteAllInterfaces() (deviceInstancesDeleted []uint32, rebootRequired bool, errors []error) { - return DeleteMatchingInterfaces(func(wintun *Wintun) bool { +func (pool Pool) DeleteAllInterfaces() (deviceInstancesDeleted []uint32, rebootRequired bool, errors []error) { + return pool.DeleteMatchingInterfaces(func(wintun *Wintun) bool { return true }) } +// isMember checks if SPDRP_DEVICEDESC or SPDRP_FRIENDLYNAME match device type name. +func (pool Pool) isMember(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) (bool, error) { + deviceDescVal, err := deviceInfoSet.DeviceRegistryProperty(deviceInfoData, setupapi.SPDRP_DEVICEDESC) + if err != nil { + return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err) + } + deviceDesc, _ := deviceDescVal.(string) + friendlyNameVal, err := deviceInfoSet.DeviceRegistryProperty(deviceInfoData, setupapi.SPDRP_FRIENDLYNAME) + if err != nil { + return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_FRIENDLYNAME) failed: %v", err) + } + friendlyName, _ := friendlyNameVal.(string) + deviceTypeName := pool.DeviceTypeName() + return friendlyName == deviceTypeName || deviceDesc == deviceTypeName || + removeNumberedSuffix(friendlyName) == deviceTypeName || removeNumberedSuffix(deviceDesc) == deviceTypeName, nil +} + // checkReboot checks device install parameters if a system reboot is required. func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) bool { devInstallParams, err := deviceInfoSet.DeviceInstallParams(deviceInfoData) @@ -559,13 +578,18 @@ func setQuietInstall(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.De return deviceInfoSet.SetDeviceInstallParams(deviceInfoData, devInstallParams) } +// DeviceTypeName returns pool-specific device type name. +func (pool Pool) DeviceTypeName() string { + return fmt.Sprintf("%s Tunnel", pool) +} + // InterfaceName returns the name of the Wintun interface. func (wintun *Wintun) InterfaceName() (string, error) { return nci.ConnectionName(&wintun.cfgInstanceID) } // SetInterfaceName sets name of the Wintun interface. -func (wintun *Wintun) SetInterfaceName(ifname string) error { +func (wintun *Wintun) SetInterfaceName(ifname string, pool Pool) error { const maxSuffix = 1000 availableIfname := ifname for i := 0; ; i++ { @@ -608,7 +632,7 @@ func (wintun *Wintun) SetInterfaceName(ifname string) error { return fmt.Errorf("Device-level registry key open failed: %v", err) } defer deviceRegKey.Close() - err = deviceRegKey.SetStringValue("FriendlyName", deviceTypeName) + err = deviceRegKey.SetStringValue("FriendlyName", pool.DeviceTypeName()) if err != nil { return fmt.Errorf("SetStringValue(FriendlyName) failed: %v", err) }