From 82d2aa87aa623cb5143a41c3345da4fb875ad85d Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Tue, 12 Oct 2021 00:26:46 -0600 Subject: [PATCH] wintun: use new swdevice-based API for upcoming Wintun 0.14 Signed-off-by: Jason A. Donenfeld --- tun/tun_windows.go | 33 ++----- tun/wintun/wintun_windows.go | 185 +++++++++++------------------------ 2 files changed, 65 insertions(+), 153 deletions(-) diff --git a/tun/tun_windows.go b/tun/tun_windows.go index ff16e2f..381a842 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -8,7 +8,6 @@ package tun import ( "errors" "fmt" - "log" "os" "sync" "sync/atomic" @@ -35,6 +34,7 @@ type rateJuggler struct { type NativeTun struct { wt *wintun.Adapter + name string handle windows.Handle rate rateJuggler session wintun.Session @@ -46,7 +46,7 @@ type NativeTun struct { forcedMTU int } -var WintunPool, _ = wintun.MakePool("WireGuard") +var WintunTunnelType = "WireGuard" var WintunStaticRequestedGUID *windows.GUID //go:linkname procyield runtime.procyield @@ -68,25 +68,10 @@ func CreateTUN(ifname string, mtu int) (Device, error) { // a requested GUID. Should a Wintun interface with the same name exist, it is reused. // func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) { - var err error - var wt *wintun.Adapter - - // Does an interface with this name already exist? - wt, err = WintunPool.OpenAdapter(ifname) - if err == nil { - // If so, we delete it, in case it has weird residual configuration. - _, err = wt.Delete(true) - if err != nil { - return nil, fmt.Errorf("Error deleting already existing interface: %w", err) - } - } - wt, rebootRequired, err := WintunPool.CreateAdapter(ifname, requestedGUID) + wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID) if err != nil { return nil, fmt.Errorf("Error creating interface: %w", err) } - if rebootRequired { - log.Println("Windows indicated a reboot is required.") - } forcedMTU := 1420 if mtu > 0 { @@ -95,6 +80,7 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu tun := &NativeTun{ wt: wt, + name: ifname, handle: windows.InvalidHandle, events: make(chan Event, 10), forcedMTU: forcedMTU, @@ -102,7 +88,7 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB if err != nil { - tun.wt.Delete(false) + tun.wt.Close() close(tun.events) return nil, fmt.Errorf("Error starting session: %w", err) } @@ -111,12 +97,7 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu } func (tun *NativeTun) Name() (string, error) { - tun.running.Add(1) - defer tun.running.Done() - if atomic.LoadInt32(&tun.close) == 1 { - return "", os.ErrClosed - } - return tun.wt.Name() + return tun.name, nil } func (tun *NativeTun) File() *os.File { @@ -135,7 +116,7 @@ func (tun *NativeTun) Close() error { tun.running.Wait() tun.session.End() if tun.wt != nil { - _, err = tun.wt.Delete(false) + tun.wt.Close() } close(tun.events) }) diff --git a/tun/wintun/wintun_windows.go b/tun/wintun/wintun_windows.go index 6c5a00d..4edad91 100644 --- a/tun/wintun/wintun_windows.go +++ b/tun/wintun/wintun_windows.go @@ -6,7 +6,6 @@ package wintun import ( - "errors" "log" "runtime" "syscall" @@ -23,175 +22,107 @@ const ( logErr ) -const ( - PoolNameMax = 256 - AdapterNameMax = 128 -) +const AdapterNameMax = 128 -type Pool [PoolNameMax]uint16 type Adapter struct { handle uintptr } var ( - modwintun = newLazyDLL("wintun.dll", setupLogger) - + modwintun = newLazyDLL("wintun.dll", setupLogger) procWintunCreateAdapter = modwintun.NewProc("WintunCreateAdapter") - procWintunDeleteAdapter = modwintun.NewProc("WintunDeleteAdapter") - procWintunDeletePoolDriver = modwintun.NewProc("WintunDeletePoolDriver") - procWintunEnumAdapters = modwintun.NewProc("WintunEnumAdapters") - procWintunFreeAdapter = modwintun.NewProc("WintunFreeAdapter") procWintunOpenAdapter = modwintun.NewProc("WintunOpenAdapter") + procWintunCloseAdapter = modwintun.NewProc("WintunCloseAdapter") + procWintunDeleteDriver = modwintun.NewProc("WintunDeleteDriver") procWintunGetAdapterLUID = modwintun.NewProc("WintunGetAdapterLUID") - procWintunGetAdapterName = modwintun.NewProc("WintunGetAdapterName") procWintunGetRunningDriverVersion = modwintun.NewProc("WintunGetRunningDriverVersion") - procWintunSetAdapterName = modwintun.NewProc("WintunSetAdapterName") ) +type TimestampedWriter interface { + WriteWithTimestamp(p []byte, ts int64) (n int, err error) +} + +func logMessage(level loggerLevel, timestamp uint64, msg *uint16) int { + if tw, ok := log.Default().Writer().(TimestampedWriter); ok { + tw.WriteWithTimestamp([]byte(log.Default().Prefix()+windows.UTF16PtrToString(msg)), (int64(timestamp)-116444736000000000)*100) + } else { + log.Println(windows.UTF16PtrToString(msg)) + } + return 0 +} + func setupLogger(dll *lazyDLL) { - syscall.Syscall(dll.NewProc("WintunSetLogger").Addr(), 1, windows.NewCallback(func(level loggerLevel, msg *uint16) int { - log.Println("[Wintun]", windows.UTF16PtrToString(msg)) - return 0 - }), 0, 0) + var callback uintptr + if runtime.GOARCH == "386" || runtime.GOARCH == "arm" { + callback = windows.NewCallback(func(level loggerLevel, timestampLow, timestampHigh uint32, msg *uint16) int { + return logMessage(level, uint64(timestampHigh)<<32|uint64(timestampLow), msg) + }) + } else if runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" { + callback = windows.NewCallback(logMessage) + } + syscall.Syscall(dll.NewProc("WintunSetLogger").Addr(), 1, callback, 0, 0) } -func MakePool(poolName string) (pool *Pool, err error) { - poolName16, err := windows.UTF16FromString(poolName) +func closeAdapter(wintun *Adapter) { + syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0) +} + +// CreateAdapter creates a Wintun adapter. name is the cosmetic name of the adapter. +// tunnelType represents the type of adapter and should be "Wintun". requestedGUID is +// the GUID of the created network adapter, which then influences NLA generation +// deterministically. If it is set to nil, the GUID is chosen by the system at random, +// and hence a new NLA entry is created for each new adapter. +func CreateAdapter(name string, tunnelType string, requestedGUID *windows.GUID) (wintun *Adapter, err error) { + var name16 *uint16 + name16, err = windows.UTF16PtrFromString(name) if err != nil { return } - if len(poolName16) > PoolNameMax { - err = errors.New("Pool name too long") + var tunnelType16 *uint16 + tunnelType16, err = windows.UTF16PtrFromString(tunnelType) + if err != nil { return } - pool = &Pool{} - copy(pool[:], poolName16) - return -} - -func (pool *Pool) String() string { - return windows.UTF16ToString(pool[:]) -} - -func freeAdapter(wintun *Adapter) { - syscall.Syscall(procWintunFreeAdapter.Addr(), 1, uintptr(wintun.handle), 0, 0) -} - -// OpenAdapter finds a Wintun adapter by its name. This function returns the adapter if found, or -// windows.ERROR_FILE_NOT_FOUND otherwise. If the adapter is found but not a Wintun-class or a -// member of the pool, this function returns windows.ERROR_ALREADY_EXISTS. The adapter must be -// released after use. -func (pool *Pool) OpenAdapter(ifname string) (wintun *Adapter, err error) { - ifname16, err := windows.UTF16PtrFromString(ifname) - if err != nil { - return nil, err - } - r0, _, e1 := syscall.Syscall(procWintunOpenAdapter.Addr(), 2, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(ifname16)), 0) + r0, _, e1 := syscall.Syscall(procWintunCreateAdapter.Addr(), 3, uintptr(unsafe.Pointer(name16)), uintptr(unsafe.Pointer(tunnelType16)), uintptr(unsafe.Pointer(requestedGUID))) if r0 == 0 { err = e1 return } - wintun = &Adapter{r0} - runtime.SetFinalizer(wintun, freeAdapter) + wintun = &Adapter{handle: r0} + runtime.SetFinalizer(wintun, closeAdapter) return } -// CreateAdapter creates a Wintun adapter. ifname is the requested name of the adapter, while -// requestedGUID is the GUID of the created network adapter, which then influences NLA generation -// deterministically. If it is set to nil, the GUID is chosen by the system at random, and hence a -// new NLA entry is created for each new adapter. It is called "requested" GUID because the API it -// uses is completely undocumented, and so there could be minor interesting complications with its -// usage. This function returns the network adapter ID and a flag if reboot is required. -func (pool *Pool) CreateAdapter(ifname string, requestedGUID *windows.GUID) (wintun *Adapter, rebootRequired bool, err error) { - var ifname16 *uint16 - ifname16, err = windows.UTF16PtrFromString(ifname) +// OpenAdapter opens an existing Wintun adapter by name. +func OpenAdapter(name string) (wintun *Adapter, err error) { + var name16 *uint16 + name16, err = windows.UTF16PtrFromString(name) if err != nil { return } - var _p0 uint32 - r0, _, e1 := syscall.Syscall6(procWintunCreateAdapter.Addr(), 4, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(ifname16)), uintptr(unsafe.Pointer(requestedGUID)), uintptr(unsafe.Pointer(&_p0)), 0, 0) - rebootRequired = _p0 != 0 + r0, _, e1 := syscall.Syscall(procWintunOpenAdapter.Addr(), 1, uintptr(unsafe.Pointer(name16)), 0, 0) if r0 == 0 { err = e1 return } - wintun = &Adapter{r0} - runtime.SetFinalizer(wintun, freeAdapter) + wintun = &Adapter{handle: r0} + runtime.SetFinalizer(wintun, closeAdapter) return } -// Delete deletes a Wintun adapter. This function succeeds if the adapter was not found. It returns -// a bool indicating whether a reboot is required. -func (wintun *Adapter) Delete(forceCloseSessions bool) (rebootRequired bool, err error) { - var _p0 uint32 - if forceCloseSessions { - _p0 = 1 - } - var _p1 uint32 - r1, _, e1 := syscall.Syscall(procWintunDeleteAdapter.Addr(), 3, uintptr(wintun.handle), uintptr(_p0), uintptr(unsafe.Pointer(&_p1))) - rebootRequired = _p1 != 0 +// Close closes a Wintun adapter. +func (wintun *Adapter) Close() (err error) { + runtime.SetFinalizer(wintun, nil) + r1, _, e1 := syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0) if r1 == 0 { err = e1 } return } -// DeleteMatchingAdapters deletes all Wintun adapters, which match -// given criteria, and returns which ones it deleted, whether a reboot -// is required after, and which errors occurred during the process. -func (pool *Pool) DeleteMatchingAdapters(matches func(adapter *Adapter) bool, forceCloseSessions bool) (rebootRequired bool, errors []error) { - cb := func(handle uintptr, _ uintptr) int { - adapter := &Adapter{handle} - if !matches(adapter) { - return 1 - } - rebootRequired2, err := adapter.Delete(forceCloseSessions) - if err != nil { - errors = append(errors, err) - return 1 - } - rebootRequired = rebootRequired || rebootRequired2 - return 1 - } - r1, _, e1 := syscall.Syscall(procWintunEnumAdapters.Addr(), 3, uintptr(unsafe.Pointer(pool)), uintptr(windows.NewCallback(cb)), 0) - if r1 == 0 { - errors = append(errors, e1) - } - return -} - -// Name returns the name of the Wintun adapter. -func (wintun *Adapter) Name() (ifname string, err error) { - var ifname16 [AdapterNameMax]uint16 - r1, _, e1 := syscall.Syscall(procWintunGetAdapterName.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&ifname16[0])), 0) - if r1 == 0 { - err = e1 - return - } - ifname = windows.UTF16ToString(ifname16[:]) - return -} - -// DeleteDriver deletes all Wintun adapters in a pool and if there are no more adapters in any other -// pools, also removes Wintun from the driver store, usually called by uninstallers. -func (pool *Pool) DeleteDriver() (rebootRequired bool, err error) { - var _p0 uint32 - r1, _, e1 := syscall.Syscall(procWintunDeletePoolDriver.Addr(), 2, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(&_p0)), 0) - rebootRequired = _p0 != 0 - if r1 == 0 { - err = e1 - } - return - -} - -// SetName sets name of the Wintun adapter. -func (wintun *Adapter) SetName(ifname string) (err error) { - ifname16, err := windows.UTF16FromString(ifname) - if err != nil { - return err - } - r1, _, e1 := syscall.Syscall(procWintunSetAdapterName.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&ifname16[0])), 0) +// Uninstall removes the driver from the system if no drivers are currently in use. +func Uninstall() (err error) { + r1, _, e1 := syscall.Syscall(procWintunDeleteDriver.Addr(), 0, 0, 0, 0) if r1 == 0 { err = e1 }