memmod: import from wireguard-windows

We'll eventually be getting rid of it here, but keep it sync'd up for
now.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2021-10-11 14:53:36 -06:00
parent bb745b2ea3
commit 642a56e165
6 changed files with 96 additions and 34 deletions

2
go.mod
View file

@ -1,6 +1,6 @@
module golang.zx2c4.com/wireguard module golang.zx2c4.com/wireguard
go 1.16 go 1.17
require ( require (
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 golang.org/x/crypto v0.0.0-20210921155107-089bfa567519

View file

@ -8,6 +8,8 @@ package memmod
import ( import (
"errors" "errors"
"fmt" "fmt"
"strings"
"sync"
"syscall" "syscall"
"unsafe" "unsafe"
@ -62,8 +64,7 @@ func (module *Module) copySections(address uintptr, size uintptr, oldHeaders *IM
dest = module.codeBase + uintptr(sections[i].VirtualAddress) dest = module.codeBase + uintptr(sections[i].VirtualAddress)
// NOTE: On 64bit systems we truncate to 32bit here but expand again later when "PhysicalAddress" is used. // NOTE: On 64bit systems we truncate to 32bit here but expand again later when "PhysicalAddress" is used.
sections[i].SetPhysicalAddress((uint32)(dest & 0xffffffff)) sections[i].SetPhysicalAddress((uint32)(dest & 0xffffffff))
var dst []byte dst := unsafe.Slice((*byte)(a2p(dest)), sectionSize)
unsafeSlice(unsafe.Pointer(&dst), a2p(dest), int(sectionSize))
for j := range dst { for j := range dst {
dst[j] = 0 dst[j] = 0
} }
@ -245,11 +246,9 @@ func (module *Module) performBaseRelocation(delta uintptr) (relocated bool, err
for relocationHdr.VirtualAddress > 0 { for relocationHdr.VirtualAddress > 0 {
dest := module.codeBase + uintptr(relocationHdr.VirtualAddress) dest := module.codeBase + uintptr(relocationHdr.VirtualAddress)
var relInfos []uint16 relInfos := unsafe.Slice(
unsafeSlice( (*uint16)(a2p(uintptr(unsafe.Pointer(relocationHdr))+unsafe.Sizeof(*relocationHdr))),
unsafe.Pointer(&relInfos), (uintptr(relocationHdr.SizeOfBlock)-unsafe.Sizeof(*relocationHdr))/unsafe.Sizeof(uint16(0)))
a2p(uintptr(unsafe.Pointer(relocationHdr))+unsafe.Sizeof(*relocationHdr)),
int((uintptr(relocationHdr.SizeOfBlock)-unsafe.Sizeof(*relocationHdr))/unsafe.Sizeof(relInfos[0])))
for _, relInfo := range relInfos { for _, relInfo := range relInfos {
// The upper 4 bits define the type of relocation. // The upper 4 bits define the type of relocation.
relType := relInfo >> 12 relType := relInfo >> 12
@ -370,10 +369,8 @@ func (module *Module) buildNameExports() error {
if exports.NumberOfNames == 0 { if exports.NumberOfNames == 0 {
return errors.New("No functions exported by name") return errors.New("No functions exported by name")
} }
var nameRefs []uint32 nameRefs := unsafe.Slice((*uint32)(a2p(module.codeBase+uintptr(exports.AddressOfNames))), exports.NumberOfNames)
unsafeSlice(unsafe.Pointer(&nameRefs), a2p(module.codeBase+uintptr(exports.AddressOfNames)), int(exports.NumberOfNames)) ordinals := unsafe.Slice((*uint16)(a2p(module.codeBase+uintptr(exports.AddressOfNameOrdinals))), exports.NumberOfNames)
var ordinals []uint16
unsafeSlice(unsafe.Pointer(&ordinals), a2p(module.codeBase+uintptr(exports.AddressOfNameOrdinals)), int(exports.NumberOfNames))
module.nameExports = make(map[string]uint16) module.nameExports = make(map[string]uint16)
for i := range nameRefs { for i := range nameRefs {
nameArray := windows.BytePtrToString((*byte)(a2p(module.codeBase + uintptr(nameRefs[i])))) nameArray := windows.BytePtrToString((*byte)(a2p(module.codeBase + uintptr(nameRefs[i]))))
@ -382,6 +379,76 @@ func (module *Module) buildNameExports() error {
return nil return nil
} }
type addressRange struct {
start uintptr
end uintptr
}
var loadedAddressRanges []addressRange
var loadedAddressRangesMu sync.RWMutex
var haveHookedRtlPcToFileHeader sync.Once
var hookRtlPcToFileHeaderResult error
func hookRtlPcToFileHeader() error {
var kernelBase windows.Handle
err := windows.GetModuleHandleEx(windows.GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, windows.StringToUTF16Ptr("kernelbase.dll"), &kernelBase)
if err != nil {
return err
}
imageBase := unsafe.Pointer(kernelBase)
dosHeader := (*IMAGE_DOS_HEADER)(imageBase)
ntHeaders := (*IMAGE_NT_HEADERS)(unsafe.Add(imageBase, dosHeader.E_lfanew))
importsDirectory := ntHeaders.OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT]
importDescriptor := (*IMAGE_IMPORT_DESCRIPTOR)(unsafe.Add(imageBase, importsDirectory.VirtualAddress))
for ; importDescriptor.Name != 0; importDescriptor = (*IMAGE_IMPORT_DESCRIPTOR)(unsafe.Add(unsafe.Pointer(importDescriptor), unsafe.Sizeof(*importDescriptor))) {
libraryName := windows.BytePtrToString((*byte)(unsafe.Add(imageBase, importDescriptor.Name)))
if strings.EqualFold(libraryName, "ntdll.dll") {
break
}
}
if importDescriptor.Name == 0 {
return errors.New("ntdll.dll not found")
}
originalThunk := (*uintptr)(unsafe.Add(imageBase, importDescriptor.OriginalFirstThunk()))
thunk := (*uintptr)(unsafe.Add(imageBase, importDescriptor.FirstThunk))
for ; *originalThunk != 0; originalThunk = (*uintptr)(unsafe.Add(unsafe.Pointer(originalThunk), unsafe.Sizeof(*originalThunk))) {
if *originalThunk&IMAGE_ORDINAL_FLAG == 0 {
function := (*IMAGE_IMPORT_BY_NAME)(unsafe.Add(imageBase, *originalThunk))
name := windows.BytePtrToString(&function.Name[0])
if name == "RtlPcToFileHeader" {
break
}
}
thunk = (*uintptr)(unsafe.Add(unsafe.Pointer(thunk), unsafe.Sizeof(*thunk)))
}
if *originalThunk == 0 {
return errors.New("RtlPcToFileHeader not found")
}
var oldProtect uint32
err = windows.VirtualProtect(uintptr(unsafe.Pointer(thunk)), unsafe.Sizeof(*thunk), windows.PAGE_READWRITE, &oldProtect)
if err != nil {
return err
}
originalRtlPcToFileHeader := *thunk
*thunk = windows.NewCallback(func(pcValue uintptr, baseOfImage *uintptr) uintptr {
loadedAddressRangesMu.RLock()
for i := range loadedAddressRanges {
if pcValue >= loadedAddressRanges[i].start && pcValue < loadedAddressRanges[i].end {
pcValue = *thunk
break
}
}
loadedAddressRangesMu.RUnlock()
ret, _, _ := syscall.Syscall(originalRtlPcToFileHeader, 2, pcValue, uintptr(unsafe.Pointer(baseOfImage)), 0)
return ret
})
err = windows.VirtualProtect(uintptr(unsafe.Pointer(thunk)), unsafe.Sizeof(*thunk), oldProtect, &oldProtect)
if err != nil {
return err
}
return nil
}
// LoadLibrary loads module image to memory. // LoadLibrary loads module image to memory.
func LoadLibrary(data []byte) (module *Module, err error) { func LoadLibrary(data []byte) (module *Module, err error) {
addr := uintptr(unsafe.Pointer(&data[0])) addr := uintptr(unsafe.Pointer(&data[0]))
@ -513,6 +580,18 @@ func LoadLibrary(data []byte) (module *Module, err error) {
// Register exception tables, if they exist. // Register exception tables, if they exist.
module.registerExceptionHandlers() module.registerExceptionHandlers()
// Register function PCs.
loadedAddressRangesMu.Lock()
loadedAddressRanges = append(loadedAddressRanges, addressRange{module.codeBase, module.codeBase + alignedImageSize})
loadedAddressRangesMu.Unlock()
haveHookedRtlPcToFileHeader.Do(func() {
hookRtlPcToFileHeaderResult = hookRtlPcToFileHeader()
})
err = hookRtlPcToFileHeaderResult
if err != nil {
return
}
// TLS callbacks are executed BEFORE the main loading. // TLS callbacks are executed BEFORE the main loading.
module.executeTLS() module.executeTLS()
@ -610,26 +689,5 @@ func a2p(addr uintptr) unsafe.Pointer {
} }
func memcpy(dst, src, size uintptr) { func memcpy(dst, src, size uintptr) {
var d, s []byte copy(unsafe.Slice((*byte)(a2p(dst)), size), unsafe.Slice((*byte)(a2p(src)), size))
unsafeSlice(unsafe.Pointer(&d), a2p(dst), int(size))
unsafeSlice(unsafe.Pointer(&s), a2p(src), int(size))
copy(d, s)
}
// unsafeSlice updates the slice slicePtr to be a slice
// referencing the provided data with its length & capacity set to
// lenCap.
//
// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
// update callers to use unsafe.Slice instead of this.
func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
type sliceHeader struct {
Data unsafe.Pointer
Len int
Cap int
}
h := (*sliceHeader)(slicePtr)
h.Data = data
h.Len = lenCap
h.Cap = lenCap
} }

View file

@ -1,3 +1,4 @@
//go:build (windows && 386) || (windows && arm)
// +build windows,386 windows,arm // +build windows,386 windows,arm
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT

View file

@ -1,3 +1,4 @@
//go:build (windows && amd64) || (windows && arm64)
// +build windows,amd64 windows,arm64 // +build windows,amd64 windows,arm64
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT

View file

@ -1,3 +1,4 @@
//go:build (windows && 386) || (windows && arm)
// +build windows,386 windows,arm // +build windows,386 windows,arm
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT

View file

@ -1,3 +1,4 @@
//go:build (windows && amd64) || (windows && arm64)
// +build windows,amd64 windows,arm64 // +build windows,amd64 windows,arm64
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT