wintun: use new swdevice-based API for upcoming Wintun 0.14

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2021-10-12 00:26:46 -06:00
parent 982d5d2e84
commit 82d2aa87aa
2 changed files with 65 additions and 153 deletions

View file

@ -8,7 +8,6 @@ package tun
import ( import (
"errors" "errors"
"fmt" "fmt"
"log"
"os" "os"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -35,6 +34,7 @@ type rateJuggler struct {
type NativeTun struct { type NativeTun struct {
wt *wintun.Adapter wt *wintun.Adapter
name string
handle windows.Handle handle windows.Handle
rate rateJuggler rate rateJuggler
session wintun.Session session wintun.Session
@ -46,7 +46,7 @@ type NativeTun struct {
forcedMTU int forcedMTU int
} }
var WintunPool, _ = wintun.MakePool("WireGuard") var WintunTunnelType = "WireGuard"
var WintunStaticRequestedGUID *windows.GUID var WintunStaticRequestedGUID *windows.GUID
//go:linkname procyield runtime.procyield //go:linkname procyield runtime.procyield
@ -68,25 +68,10 @@ func CreateTUN(ifname string, mtu int) (Device, error) {
// a requested GUID. Should a Wintun interface with the same name exist, it is reused. // a requested GUID. Should a Wintun interface with the same name exist, it is reused.
// //
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) { func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
var err error wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
var wt *wintun.Adapter
// Does an interface with this name already exist?
wt, err = WintunPool.OpenAdapter(ifname)
if err == nil {
// If so, we delete it, in case it has weird residual configuration.
_, err = wt.Delete(true)
if err != nil {
return nil, fmt.Errorf("Error deleting already existing interface: %w", err)
}
}
wt, rebootRequired, err := WintunPool.CreateAdapter(ifname, requestedGUID)
if err != nil { if err != nil {
return nil, fmt.Errorf("Error creating interface: %w", err) return nil, fmt.Errorf("Error creating interface: %w", err)
} }
if rebootRequired {
log.Println("Windows indicated a reboot is required.")
}
forcedMTU := 1420 forcedMTU := 1420
if mtu > 0 { if mtu > 0 {
@ -95,6 +80,7 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
tun := &NativeTun{ tun := &NativeTun{
wt: wt, wt: wt,
name: ifname,
handle: windows.InvalidHandle, handle: windows.InvalidHandle,
events: make(chan Event, 10), events: make(chan Event, 10),
forcedMTU: forcedMTU, forcedMTU: forcedMTU,
@ -102,7 +88,7 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
if err != nil { if err != nil {
tun.wt.Delete(false) tun.wt.Close()
close(tun.events) close(tun.events)
return nil, fmt.Errorf("Error starting session: %w", err) return nil, fmt.Errorf("Error starting session: %w", err)
} }
@ -111,12 +97,7 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
} }
func (tun *NativeTun) Name() (string, error) { func (tun *NativeTun) Name() (string, error) {
tun.running.Add(1) return tun.name, nil
defer tun.running.Done()
if atomic.LoadInt32(&tun.close) == 1 {
return "", os.ErrClosed
}
return tun.wt.Name()
} }
func (tun *NativeTun) File() *os.File { func (tun *NativeTun) File() *os.File {
@ -135,7 +116,7 @@ func (tun *NativeTun) Close() error {
tun.running.Wait() tun.running.Wait()
tun.session.End() tun.session.End()
if tun.wt != nil { if tun.wt != nil {
_, err = tun.wt.Delete(false) tun.wt.Close()
} }
close(tun.events) close(tun.events)
}) })

View file

@ -6,7 +6,6 @@
package wintun package wintun
import ( import (
"errors"
"log" "log"
"runtime" "runtime"
"syscall" "syscall"
@ -23,175 +22,107 @@ const (
logErr logErr
) )
const ( const AdapterNameMax = 128
PoolNameMax = 256
AdapterNameMax = 128
)
type Pool [PoolNameMax]uint16
type Adapter struct { type Adapter struct {
handle uintptr handle uintptr
} }
var ( var (
modwintun = newLazyDLL("wintun.dll", setupLogger) modwintun = newLazyDLL("wintun.dll", setupLogger)
procWintunCreateAdapter = modwintun.NewProc("WintunCreateAdapter") procWintunCreateAdapter = modwintun.NewProc("WintunCreateAdapter")
procWintunDeleteAdapter = modwintun.NewProc("WintunDeleteAdapter")
procWintunDeletePoolDriver = modwintun.NewProc("WintunDeletePoolDriver")
procWintunEnumAdapters = modwintun.NewProc("WintunEnumAdapters")
procWintunFreeAdapter = modwintun.NewProc("WintunFreeAdapter")
procWintunOpenAdapter = modwintun.NewProc("WintunOpenAdapter") procWintunOpenAdapter = modwintun.NewProc("WintunOpenAdapter")
procWintunCloseAdapter = modwintun.NewProc("WintunCloseAdapter")
procWintunDeleteDriver = modwintun.NewProc("WintunDeleteDriver")
procWintunGetAdapterLUID = modwintun.NewProc("WintunGetAdapterLUID") procWintunGetAdapterLUID = modwintun.NewProc("WintunGetAdapterLUID")
procWintunGetAdapterName = modwintun.NewProc("WintunGetAdapterName")
procWintunGetRunningDriverVersion = modwintun.NewProc("WintunGetRunningDriverVersion") procWintunGetRunningDriverVersion = modwintun.NewProc("WintunGetRunningDriverVersion")
procWintunSetAdapterName = modwintun.NewProc("WintunSetAdapterName")
) )
func setupLogger(dll *lazyDLL) { type TimestampedWriter interface {
syscall.Syscall(dll.NewProc("WintunSetLogger").Addr(), 1, windows.NewCallback(func(level loggerLevel, msg *uint16) int { WriteWithTimestamp(p []byte, ts int64) (n int, err error)
log.Println("[Wintun]", windows.UTF16PtrToString(msg)) }
func logMessage(level loggerLevel, timestamp uint64, msg *uint16) int {
if tw, ok := log.Default().Writer().(TimestampedWriter); ok {
tw.WriteWithTimestamp([]byte(log.Default().Prefix()+windows.UTF16PtrToString(msg)), (int64(timestamp)-116444736000000000)*100)
} else {
log.Println(windows.UTF16PtrToString(msg))
}
return 0 return 0
}), 0, 0)
} }
func MakePool(poolName string) (pool *Pool, err error) { func setupLogger(dll *lazyDLL) {
poolName16, err := windows.UTF16FromString(poolName) var callback uintptr
if runtime.GOARCH == "386" || runtime.GOARCH == "arm" {
callback = windows.NewCallback(func(level loggerLevel, timestampLow, timestampHigh uint32, msg *uint16) int {
return logMessage(level, uint64(timestampHigh)<<32|uint64(timestampLow), msg)
})
} else if runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" {
callback = windows.NewCallback(logMessage)
}
syscall.Syscall(dll.NewProc("WintunSetLogger").Addr(), 1, callback, 0, 0)
}
func closeAdapter(wintun *Adapter) {
syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0)
}
// CreateAdapter creates a Wintun adapter. name is the cosmetic name of the adapter.
// tunnelType represents the type of adapter and should be "Wintun". requestedGUID is
// the GUID of the created network adapter, which then influences NLA generation
// deterministically. If it is set to nil, the GUID is chosen by the system at random,
// and hence a new NLA entry is created for each new adapter.
func CreateAdapter(name string, tunnelType string, requestedGUID *windows.GUID) (wintun *Adapter, err error) {
var name16 *uint16
name16, err = windows.UTF16PtrFromString(name)
if err != nil { if err != nil {
return return
} }
if len(poolName16) > PoolNameMax { var tunnelType16 *uint16
err = errors.New("Pool name too long") tunnelType16, err = windows.UTF16PtrFromString(tunnelType)
return
}
pool = &Pool{}
copy(pool[:], poolName16)
return
}
func (pool *Pool) String() string {
return windows.UTF16ToString(pool[:])
}
func freeAdapter(wintun *Adapter) {
syscall.Syscall(procWintunFreeAdapter.Addr(), 1, uintptr(wintun.handle), 0, 0)
}
// OpenAdapter finds a Wintun adapter by its name. This function returns the adapter if found, or
// windows.ERROR_FILE_NOT_FOUND otherwise. If the adapter is found but not a Wintun-class or a
// member of the pool, this function returns windows.ERROR_ALREADY_EXISTS. The adapter must be
// released after use.
func (pool *Pool) OpenAdapter(ifname string) (wintun *Adapter, err error) {
ifname16, err := windows.UTF16PtrFromString(ifname)
if err != nil { if err != nil {
return nil, err return
} }
r0, _, e1 := syscall.Syscall(procWintunOpenAdapter.Addr(), 2, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(ifname16)), 0) r0, _, e1 := syscall.Syscall(procWintunCreateAdapter.Addr(), 3, uintptr(unsafe.Pointer(name16)), uintptr(unsafe.Pointer(tunnelType16)), uintptr(unsafe.Pointer(requestedGUID)))
if r0 == 0 { if r0 == 0 {
err = e1 err = e1
return return
} }
wintun = &Adapter{r0} wintun = &Adapter{handle: r0}
runtime.SetFinalizer(wintun, freeAdapter) runtime.SetFinalizer(wintun, closeAdapter)
return return
} }
// CreateAdapter creates a Wintun adapter. ifname is the requested name of the adapter, while // OpenAdapter opens an existing Wintun adapter by name.
// requestedGUID is the GUID of the created network adapter, which then influences NLA generation func OpenAdapter(name string) (wintun *Adapter, err error) {
// deterministically. If it is set to nil, the GUID is chosen by the system at random, and hence a var name16 *uint16
// new NLA entry is created for each new adapter. It is called "requested" GUID because the API it name16, err = windows.UTF16PtrFromString(name)
// uses is completely undocumented, and so there could be minor interesting complications with its
// usage. This function returns the network adapter ID and a flag if reboot is required.
func (pool *Pool) CreateAdapter(ifname string, requestedGUID *windows.GUID) (wintun *Adapter, rebootRequired bool, err error) {
var ifname16 *uint16
ifname16, err = windows.UTF16PtrFromString(ifname)
if err != nil { if err != nil {
return return
} }
var _p0 uint32 r0, _, e1 := syscall.Syscall(procWintunOpenAdapter.Addr(), 1, uintptr(unsafe.Pointer(name16)), 0, 0)
r0, _, e1 := syscall.Syscall6(procWintunCreateAdapter.Addr(), 4, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(ifname16)), uintptr(unsafe.Pointer(requestedGUID)), uintptr(unsafe.Pointer(&_p0)), 0, 0)
rebootRequired = _p0 != 0
if r0 == 0 { if r0 == 0 {
err = e1 err = e1
return return
} }
wintun = &Adapter{r0} wintun = &Adapter{handle: r0}
runtime.SetFinalizer(wintun, freeAdapter) runtime.SetFinalizer(wintun, closeAdapter)
return return
} }
// Delete deletes a Wintun adapter. This function succeeds if the adapter was not found. It returns // Close closes a Wintun adapter.
// a bool indicating whether a reboot is required. func (wintun *Adapter) Close() (err error) {
func (wintun *Adapter) Delete(forceCloseSessions bool) (rebootRequired bool, err error) { runtime.SetFinalizer(wintun, nil)
var _p0 uint32 r1, _, e1 := syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0)
if forceCloseSessions {
_p0 = 1
}
var _p1 uint32
r1, _, e1 := syscall.Syscall(procWintunDeleteAdapter.Addr(), 3, uintptr(wintun.handle), uintptr(_p0), uintptr(unsafe.Pointer(&_p1)))
rebootRequired = _p1 != 0
if r1 == 0 { if r1 == 0 {
err = e1 err = e1
} }
return return
} }
// DeleteMatchingAdapters deletes all Wintun adapters, which match // Uninstall removes the driver from the system if no drivers are currently in use.
// given criteria, and returns which ones it deleted, whether a reboot func Uninstall() (err error) {
// is required after, and which errors occurred during the process. r1, _, e1 := syscall.Syscall(procWintunDeleteDriver.Addr(), 0, 0, 0, 0)
func (pool *Pool) DeleteMatchingAdapters(matches func(adapter *Adapter) bool, forceCloseSessions bool) (rebootRequired bool, errors []error) {
cb := func(handle uintptr, _ uintptr) int {
adapter := &Adapter{handle}
if !matches(adapter) {
return 1
}
rebootRequired2, err := adapter.Delete(forceCloseSessions)
if err != nil {
errors = append(errors, err)
return 1
}
rebootRequired = rebootRequired || rebootRequired2
return 1
}
r1, _, e1 := syscall.Syscall(procWintunEnumAdapters.Addr(), 3, uintptr(unsafe.Pointer(pool)), uintptr(windows.NewCallback(cb)), 0)
if r1 == 0 {
errors = append(errors, e1)
}
return
}
// Name returns the name of the Wintun adapter.
func (wintun *Adapter) Name() (ifname string, err error) {
var ifname16 [AdapterNameMax]uint16
r1, _, e1 := syscall.Syscall(procWintunGetAdapterName.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&ifname16[0])), 0)
if r1 == 0 {
err = e1
return
}
ifname = windows.UTF16ToString(ifname16[:])
return
}
// DeleteDriver deletes all Wintun adapters in a pool and if there are no more adapters in any other
// pools, also removes Wintun from the driver store, usually called by uninstallers.
func (pool *Pool) DeleteDriver() (rebootRequired bool, err error) {
var _p0 uint32
r1, _, e1 := syscall.Syscall(procWintunDeletePoolDriver.Addr(), 2, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(&_p0)), 0)
rebootRequired = _p0 != 0
if r1 == 0 {
err = e1
}
return
}
// SetName sets name of the Wintun adapter.
func (wintun *Adapter) SetName(ifname string) (err error) {
ifname16, err := windows.UTF16FromString(ifname)
if err != nil {
return err
}
r1, _, e1 := syscall.Syscall(procWintunSetAdapterName.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&ifname16[0])), 0)
if r1 == 0 { if r1 == 0 {
err = e1 err = e1
} }