wintun: do not load dll in init()

This prevents linking to wintun.dll until it's actually needed, which
should improve startup time.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2020-12-09 01:46:55 +01:00
parent 347ce76bbc
commit ca9edf1c63
5 changed files with 21 additions and 14 deletions

View file

@ -20,6 +20,7 @@ type lazyDLL struct {
Name string Name string
mu sync.Mutex mu sync.Mutex
module windows.Handle module windows.Handle
onLoad func(d *lazyDLL)
} }
func (d *lazyDLL) Load() error { func (d *lazyDLL) Load() error {
@ -42,6 +43,9 @@ func (d *lazyDLL) Load() error {
} }
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module)) atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
if d.onLoad != nil {
d.onLoad(d)
}
return nil return nil
} }

View file

@ -23,6 +23,7 @@ type lazyDLL struct {
Name string Name string
mu sync.Mutex mu sync.Mutex
module *memmod.Module module *memmod.Module
onLoad func(d *lazyDLL)
} }
func (d *lazyDLL) Load() error { func (d *lazyDLL) Load() error {
@ -50,6 +51,9 @@ func (d *lazyDLL) Load() error {
} }
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module)) atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
if d.onLoad != nil {
d.onLoad(d)
}
return nil return nil
} }

View file

@ -12,8 +12,8 @@ import (
"unsafe" "unsafe"
) )
func newLazyDLL(name string) *lazyDLL { func newLazyDLL(name string, onLoad func(d *lazyDLL)) *lazyDLL {
return &lazyDLL{Name: name} return &lazyDLL{Name: name, onLoad: onLoad}
} }
func (d *lazyDLL) NewProc(name string) *lazyProc { func (d *lazyDLL) NewProc(name string) *lazyProc {

View file

@ -30,12 +30,12 @@ type Packet struct {
} }
var ( var (
procWintunAllocateSendPacket = modwintun.NewProc("WintunAllocateSendPacket").Addr() procWintunAllocateSendPacket = modwintun.NewProc("WintunAllocateSendPacket")
procWintunEndSession = modwintun.NewProc("WintunEndSession") procWintunEndSession = modwintun.NewProc("WintunEndSession")
procWintunGetReadWaitEvent = modwintun.NewProc("WintunGetReadWaitEvent") procWintunGetReadWaitEvent = modwintun.NewProc("WintunGetReadWaitEvent")
procWintunReceivePacket = modwintun.NewProc("WintunReceivePacket").Addr() procWintunReceivePacket = modwintun.NewProc("WintunReceivePacket")
procWintunReleaseReceivePacket = modwintun.NewProc("WintunReleaseReceivePacket").Addr() procWintunReleaseReceivePacket = modwintun.NewProc("WintunReleaseReceivePacket")
procWintunSendPacket = modwintun.NewProc("WintunSendPacket").Addr() procWintunSendPacket = modwintun.NewProc("WintunSendPacket")
procWintunStartSession = modwintun.NewProc("WintunStartSession") procWintunStartSession = modwintun.NewProc("WintunStartSession")
) )
@ -62,7 +62,7 @@ func (session Session) ReadWaitEvent() (handle windows.Handle) {
func (session Session) ReceivePacket() (packet []byte, err error) { func (session Session) ReceivePacket() (packet []byte, err error) {
var packetSize uint32 var packetSize uint32
r0, _, e1 := syscall.Syscall(procWintunReceivePacket, 2, session.handle, uintptr(unsafe.Pointer(&packetSize)), 0) r0, _, e1 := syscall.Syscall(procWintunReceivePacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packetSize)), 0)
if r0 == 0 { if r0 == 0 {
err = e1 err = e1
return return
@ -72,11 +72,11 @@ func (session Session) ReceivePacket() (packet []byte, err error) {
} }
func (session Session) ReleaseReceivePacket(packet []byte) { func (session Session) ReleaseReceivePacket(packet []byte) {
syscall.Syscall(procWintunReleaseReceivePacket, 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0) syscall.Syscall(procWintunReleaseReceivePacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0)
} }
func (session Session) AllocateSendPacket(packetSize int) (packet []byte, err error) { func (session Session) AllocateSendPacket(packetSize int) (packet []byte, err error) {
r0, _, e1 := syscall.Syscall(procWintunAllocateSendPacket, 2, session.handle, uintptr(packetSize), 0) r0, _, e1 := syscall.Syscall(procWintunAllocateSendPacket.Addr(), 2, session.handle, uintptr(packetSize), 0)
if r0 == 0 { if r0 == 0 {
err = e1 err = e1
return return
@ -86,7 +86,7 @@ func (session Session) AllocateSendPacket(packetSize int) (packet []byte, err er
} }
func (session Session) SendPacket(packet []byte) { func (session Session) SendPacket(packet []byte) {
syscall.Syscall(procWintunSendPacket, 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0) syscall.Syscall(procWintunSendPacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0)
} }
// unsafeSlice updates the slice slicePtr to be a slice // unsafeSlice updates the slice slicePtr to be a slice

View file

@ -34,7 +34,7 @@ type Adapter struct {
} }
var ( var (
modwintun = newLazyDLL("wintun.dll") modwintun = newLazyDLL("wintun.dll", setupLogger)
procWintunCreateAdapter = modwintun.NewProc("WintunCreateAdapter") procWintunCreateAdapter = modwintun.NewProc("WintunCreateAdapter")
procWintunDeleteAdapter = modwintun.NewProc("WintunDeleteAdapter") procWintunDeleteAdapter = modwintun.NewProc("WintunDeleteAdapter")
@ -46,11 +46,10 @@ var (
procWintunGetAdapterName = modwintun.NewProc("WintunGetAdapterName") procWintunGetAdapterName = modwintun.NewProc("WintunGetAdapterName")
procWintunGetRunningDriverVersion = modwintun.NewProc("WintunGetRunningDriverVersion") procWintunGetRunningDriverVersion = modwintun.NewProc("WintunGetRunningDriverVersion")
procWintunSetAdapterName = modwintun.NewProc("WintunSetAdapterName") procWintunSetAdapterName = modwintun.NewProc("WintunSetAdapterName")
procWintunSetLogger = modwintun.NewProc("WintunSetLogger")
) )
func init() { func setupLogger(dll *lazyDLL) {
syscall.Syscall(procWintunSetLogger.Addr(), 1, windows.NewCallback(func(level loggerLevel, msg *uint16) int { syscall.Syscall(dll.NewProc("WintunSetLogger").Addr(), 1, windows.NewCallback(func(level loggerLevel, msg *uint16) int {
log.Println("[Wintun]", windows.UTF16PtrToString(msg)) log.Println("[Wintun]", windows.UTF16PtrToString(msg))
return 0 return 0
}), 0, 0) }), 0, 0)