From 642a56e165e74a518fe986c2cf93dea62d6029b5 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Mon, 11 Oct 2021 14:53:36 -0600 Subject: [PATCH] memmod: import from wireguard-windows We'll eventually be getting rid of it here, but keep it sync'd up for now. Signed-off-by: Jason A. Donenfeld --- go.mod | 2 +- tun/wintun/memmod/memmod_windows.go | 124 +++++++++++++++++------- tun/wintun/memmod/memmod_windows_32.go | 1 + tun/wintun/memmod/memmod_windows_64.go | 1 + tun/wintun/memmod/syscall_windows_32.go | 1 + tun/wintun/memmod/syscall_windows_64.go | 1 + 6 files changed, 96 insertions(+), 34 deletions(-) diff --git a/go.mod b/go.mod index 5d8388b..e543167 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module golang.zx2c4.com/wireguard -go 1.16 +go 1.17 require ( golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 diff --git a/tun/wintun/memmod/memmod_windows.go b/tun/wintun/memmod/memmod_windows.go index 075c03a..da6ff9a 100644 --- a/tun/wintun/memmod/memmod_windows.go +++ b/tun/wintun/memmod/memmod_windows.go @@ -8,6 +8,8 @@ package memmod import ( "errors" "fmt" + "strings" + "sync" "syscall" "unsafe" @@ -62,8 +64,7 @@ func (module *Module) copySections(address uintptr, size uintptr, oldHeaders *IM dest = module.codeBase + uintptr(sections[i].VirtualAddress) // NOTE: On 64bit systems we truncate to 32bit here but expand again later when "PhysicalAddress" is used. sections[i].SetPhysicalAddress((uint32)(dest & 0xffffffff)) - var dst []byte - unsafeSlice(unsafe.Pointer(&dst), a2p(dest), int(sectionSize)) + dst := unsafe.Slice((*byte)(a2p(dest)), sectionSize) for j := range dst { dst[j] = 0 } @@ -245,11 +246,9 @@ func (module *Module) performBaseRelocation(delta uintptr) (relocated bool, err for relocationHdr.VirtualAddress > 0 { dest := module.codeBase + uintptr(relocationHdr.VirtualAddress) - var relInfos []uint16 - unsafeSlice( - unsafe.Pointer(&relInfos), - a2p(uintptr(unsafe.Pointer(relocationHdr))+unsafe.Sizeof(*relocationHdr)), - int((uintptr(relocationHdr.SizeOfBlock)-unsafe.Sizeof(*relocationHdr))/unsafe.Sizeof(relInfos[0]))) + relInfos := unsafe.Slice( + (*uint16)(a2p(uintptr(unsafe.Pointer(relocationHdr))+unsafe.Sizeof(*relocationHdr))), + (uintptr(relocationHdr.SizeOfBlock)-unsafe.Sizeof(*relocationHdr))/unsafe.Sizeof(uint16(0))) for _, relInfo := range relInfos { // The upper 4 bits define the type of relocation. relType := relInfo >> 12 @@ -370,10 +369,8 @@ func (module *Module) buildNameExports() error { if exports.NumberOfNames == 0 { return errors.New("No functions exported by name") } - var nameRefs []uint32 - unsafeSlice(unsafe.Pointer(&nameRefs), a2p(module.codeBase+uintptr(exports.AddressOfNames)), int(exports.NumberOfNames)) - var ordinals []uint16 - unsafeSlice(unsafe.Pointer(&ordinals), a2p(module.codeBase+uintptr(exports.AddressOfNameOrdinals)), int(exports.NumberOfNames)) + nameRefs := unsafe.Slice((*uint32)(a2p(module.codeBase+uintptr(exports.AddressOfNames))), exports.NumberOfNames) + ordinals := unsafe.Slice((*uint16)(a2p(module.codeBase+uintptr(exports.AddressOfNameOrdinals))), exports.NumberOfNames) module.nameExports = make(map[string]uint16) for i := range nameRefs { nameArray := windows.BytePtrToString((*byte)(a2p(module.codeBase + uintptr(nameRefs[i])))) @@ -382,6 +379,76 @@ func (module *Module) buildNameExports() error { return nil } +type addressRange struct { + start uintptr + end uintptr +} + +var loadedAddressRanges []addressRange +var loadedAddressRangesMu sync.RWMutex +var haveHookedRtlPcToFileHeader sync.Once +var hookRtlPcToFileHeaderResult error + +func hookRtlPcToFileHeader() error { + var kernelBase windows.Handle + err := windows.GetModuleHandleEx(windows.GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, windows.StringToUTF16Ptr("kernelbase.dll"), &kernelBase) + if err != nil { + return err + } + imageBase := unsafe.Pointer(kernelBase) + dosHeader := (*IMAGE_DOS_HEADER)(imageBase) + ntHeaders := (*IMAGE_NT_HEADERS)(unsafe.Add(imageBase, dosHeader.E_lfanew)) + importsDirectory := ntHeaders.OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT] + importDescriptor := (*IMAGE_IMPORT_DESCRIPTOR)(unsafe.Add(imageBase, importsDirectory.VirtualAddress)) + for ; importDescriptor.Name != 0; importDescriptor = (*IMAGE_IMPORT_DESCRIPTOR)(unsafe.Add(unsafe.Pointer(importDescriptor), unsafe.Sizeof(*importDescriptor))) { + libraryName := windows.BytePtrToString((*byte)(unsafe.Add(imageBase, importDescriptor.Name))) + if strings.EqualFold(libraryName, "ntdll.dll") { + break + } + } + if importDescriptor.Name == 0 { + return errors.New("ntdll.dll not found") + } + originalThunk := (*uintptr)(unsafe.Add(imageBase, importDescriptor.OriginalFirstThunk())) + thunk := (*uintptr)(unsafe.Add(imageBase, importDescriptor.FirstThunk)) + for ; *originalThunk != 0; originalThunk = (*uintptr)(unsafe.Add(unsafe.Pointer(originalThunk), unsafe.Sizeof(*originalThunk))) { + if *originalThunk&IMAGE_ORDINAL_FLAG == 0 { + function := (*IMAGE_IMPORT_BY_NAME)(unsafe.Add(imageBase, *originalThunk)) + name := windows.BytePtrToString(&function.Name[0]) + if name == "RtlPcToFileHeader" { + break + } + } + thunk = (*uintptr)(unsafe.Add(unsafe.Pointer(thunk), unsafe.Sizeof(*thunk))) + } + if *originalThunk == 0 { + return errors.New("RtlPcToFileHeader not found") + } + var oldProtect uint32 + err = windows.VirtualProtect(uintptr(unsafe.Pointer(thunk)), unsafe.Sizeof(*thunk), windows.PAGE_READWRITE, &oldProtect) + if err != nil { + return err + } + originalRtlPcToFileHeader := *thunk + *thunk = windows.NewCallback(func(pcValue uintptr, baseOfImage *uintptr) uintptr { + loadedAddressRangesMu.RLock() + for i := range loadedAddressRanges { + if pcValue >= loadedAddressRanges[i].start && pcValue < loadedAddressRanges[i].end { + pcValue = *thunk + break + } + } + loadedAddressRangesMu.RUnlock() + ret, _, _ := syscall.Syscall(originalRtlPcToFileHeader, 2, pcValue, uintptr(unsafe.Pointer(baseOfImage)), 0) + return ret + }) + err = windows.VirtualProtect(uintptr(unsafe.Pointer(thunk)), unsafe.Sizeof(*thunk), oldProtect, &oldProtect) + if err != nil { + return err + } + return nil +} + // LoadLibrary loads module image to memory. func LoadLibrary(data []byte) (module *Module, err error) { addr := uintptr(unsafe.Pointer(&data[0])) @@ -513,6 +580,18 @@ func LoadLibrary(data []byte) (module *Module, err error) { // Register exception tables, if they exist. module.registerExceptionHandlers() + // Register function PCs. + loadedAddressRangesMu.Lock() + loadedAddressRanges = append(loadedAddressRanges, addressRange{module.codeBase, module.codeBase + alignedImageSize}) + loadedAddressRangesMu.Unlock() + haveHookedRtlPcToFileHeader.Do(func() { + hookRtlPcToFileHeaderResult = hookRtlPcToFileHeader() + }) + err = hookRtlPcToFileHeaderResult + if err != nil { + return + } + // TLS callbacks are executed BEFORE the main loading. module.executeTLS() @@ -610,26 +689,5 @@ func a2p(addr uintptr) unsafe.Pointer { } func memcpy(dst, src, size uintptr) { - var d, s []byte - unsafeSlice(unsafe.Pointer(&d), a2p(dst), int(size)) - unsafeSlice(unsafe.Pointer(&s), a2p(src), int(size)) - copy(d, s) -} - -// unsafeSlice updates the slice slicePtr to be a slice -// referencing the provided data with its length & capacity set to -// lenCap. -// -// TODO: when Go 1.16 or Go 1.17 is the minimum supported version, -// update callers to use unsafe.Slice instead of this. -func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) { - type sliceHeader struct { - Data unsafe.Pointer - Len int - Cap int - } - h := (*sliceHeader)(slicePtr) - h.Data = data - h.Len = lenCap - h.Cap = lenCap + copy(unsafe.Slice((*byte)(a2p(dst)), size), unsafe.Slice((*byte)(a2p(src)), size)) } diff --git a/tun/wintun/memmod/memmod_windows_32.go b/tun/wintun/memmod/memmod_windows_32.go index ac76bdc..75d7ca1 100644 --- a/tun/wintun/memmod/memmod_windows_32.go +++ b/tun/wintun/memmod/memmod_windows_32.go @@ -1,3 +1,4 @@ +//go:build (windows && 386) || (windows && arm) // +build windows,386 windows,arm /* SPDX-License-Identifier: MIT diff --git a/tun/wintun/memmod/memmod_windows_64.go b/tun/wintun/memmod/memmod_windows_64.go index a620368..09e6e73 100644 --- a/tun/wintun/memmod/memmod_windows_64.go +++ b/tun/wintun/memmod/memmod_windows_64.go @@ -1,3 +1,4 @@ +//go:build (windows && amd64) || (windows && arm64) // +build windows,amd64 windows,arm64 /* SPDX-License-Identifier: MIT diff --git a/tun/wintun/memmod/syscall_windows_32.go b/tun/wintun/memmod/syscall_windows_32.go index 7abbac9..0072710 100644 --- a/tun/wintun/memmod/syscall_windows_32.go +++ b/tun/wintun/memmod/syscall_windows_32.go @@ -1,3 +1,4 @@ +//go:build (windows && 386) || (windows && arm) // +build windows,386 windows,arm /* SPDX-License-Identifier: MIT diff --git a/tun/wintun/memmod/syscall_windows_64.go b/tun/wintun/memmod/syscall_windows_64.go index 10c6533..b475202 100644 --- a/tun/wintun/memmod/syscall_windows_64.go +++ b/tun/wintun/memmod/syscall_windows_64.go @@ -1,3 +1,4 @@ +//go:build (windows && amd64) || (windows && arm64) // +build windows,amd64 windows,arm64 /* SPDX-License-Identifier: MIT