wintun: load from filesystem by default

We let people loading this from resources opt in via:

    go build -tags load_wintun_from_rsrc

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2020-11-11 18:51:44 +01:00
parent 82128c47d9
commit 60b3766b89
3 changed files with 109 additions and 39 deletions

View file

@ -0,0 +1,50 @@
// +build !load_wintun_from_rsrc
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"fmt"
"sync"
"sync/atomic"
"unsafe"
"golang.org/x/sys/windows"
)
type lazyDLL struct {
Name string
mu sync.Mutex
module windows.Handle
}
func (d *lazyDLL) Load() error {
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
return nil
}
d.mu.Lock()
defer d.mu.Unlock()
if d.module != 0 {
return nil
}
const (
LOAD_LIBRARY_SEARCH_APPLICATION_DIR = 0x00000200
LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
)
module, err := windows.LoadLibraryEx(d.Name, 0, LOAD_LIBRARY_SEARCH_APPLICATION_DIR|LOAD_LIBRARY_SEARCH_SYSTEM32)
if err != nil {
return fmt.Errorf("Unable to load library: %w", err)
}
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
return nil
}
func (p *lazyProc) nameToAddr() (uintptr, error) {
return windows.GetProcAddress(p.dll.module, p.Name)
}

View file

@ -0,0 +1,58 @@
// +build load_wintun_from_rsrc
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"fmt"
"sync"
"sync/atomic"
"unsafe"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/tun/wintun/memmod"
"golang.zx2c4.com/wireguard/tun/wintun/resource"
)
type lazyDLL struct {
Name string
mu sync.Mutex
module *memmod.Module
}
func (d *lazyDLL) Load() error {
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
return nil
}
d.mu.Lock()
defer d.mu.Unlock()
if d.module != nil {
return nil
}
const ourModule windows.Handle = 0
resInfo, err := resource.FindByName(ourModule, d.Name, resource.RT_RCDATA)
if err != nil {
return fmt.Errorf("Unable to find \"%v\" RCDATA resource: %w", d.Name, err)
}
data, err := resource.Load(ourModule, resInfo)
if err != nil {
return fmt.Errorf("Unable to load resource: %w", err)
}
module, err := memmod.LoadLibrary(data)
if err != nil {
return fmt.Errorf("Unable to load library: %w", err)
}
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
return nil
}
func (p *lazyProc) nameToAddr() (uintptr, error) {
return p.dll.module.ProcAddressByName(p.Name)
}

View file

@ -10,50 +10,12 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"unsafe" "unsafe"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/tun/wintun/memmod"
"golang.zx2c4.com/wireguard/tun/wintun/resource"
) )
type lazyDLL struct {
Name string
mu sync.Mutex
module *memmod.Module
}
func newLazyDLL(name string) *lazyDLL { func newLazyDLL(name string) *lazyDLL {
return &lazyDLL{Name: name} return &lazyDLL{Name: name}
} }
func (d *lazyDLL) Load() error {
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
return nil
}
d.mu.Lock()
defer d.mu.Unlock()
if d.module != nil {
return nil
}
const ourModule windows.Handle = 0
resInfo, err := resource.FindByName(ourModule, d.Name, resource.RT_RCDATA)
if err != nil {
return fmt.Errorf("Unable to find \"%v\" RCDATA resource: %w", d.Name, err)
}
data, err := resource.Load(ourModule, resInfo)
if err != nil {
return fmt.Errorf("Unable to load resource: %w", err)
}
module, err := memmod.LoadLibrary(data)
if err != nil {
return fmt.Errorf("Unable to load library: %w", err)
}
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
return nil
}
func (d *lazyDLL) NewProc(name string) *lazyProc { func (d *lazyDLL) NewProc(name string) *lazyProc {
return &lazyProc{dll: d, Name: name} return &lazyProc{dll: d, Name: name}
} }
@ -79,7 +41,7 @@ func (p *lazyProc) Find() error {
if err != nil { if err != nil {
return fmt.Errorf("Error loading %v DLL: %w", p.dll.Name, err) return fmt.Errorf("Error loading %v DLL: %w", p.dll.Name, err)
} }
addr, err := p.dll.module.ProcAddressByName(p.Name) addr, err := p.nameToAddr()
if err != nil { if err != nil {
return fmt.Errorf("Error getting %v address: %w", p.Name, err) return fmt.Errorf("Error getting %v address: %w", p.Name, err)
} }