wintun: registry: fix nits

This commit is contained in:
Jason A. Donenfeld 2019-05-11 17:25:48 +02:00
parent 6c1b66802f
commit 3147f00089

View file

@ -10,7 +10,6 @@ import (
"fmt" "fmt"
"runtime" "runtime"
"strings" "strings"
"syscall"
"time" "time"
"unsafe" "unsafe"
@ -104,26 +103,26 @@ func WaitForKey(k registry.Key, path string, timeout time.Duration) error {
} }
// //
// getValue is the same as windows/registry's getValue, which is unfortunately // getValue is more or less the same as windows/registry's getValue.
// private.
// //
func getValue(k registry.Key, name string, buf []byte) ([]byte, uint32, error) { func getValue(k registry.Key, name string, buf []byte) (value []byte, valueType uint32, err error) {
p, err := syscall.UTF16PtrFromString(name) var name16 *uint16
name16, err = windows.UTF16PtrFromString(name)
if err != nil { if err != nil {
return nil, 0, err return
} }
var t uint32
n := uint32(len(buf)) n := uint32(len(buf))
for { for {
err = syscall.RegQueryValueEx(syscall.Handle(k), p, nil, &t, (*byte)(unsafe.Pointer(&buf[0])), &n) err = windows.RegQueryValueEx(windows.Handle(k), name16, nil, &valueType, (*byte)(unsafe.Pointer(&buf[0])), &n)
if err == nil { if err == nil {
return buf[:n], t, nil value = buf[:n]
return
} }
if err != syscall.ERROR_MORE_DATA { if err != windows.ERROR_MORE_DATA {
return nil, 0, err return
} }
if n <= uint32(len(buf)) { if n <= uint32(len(buf)) {
return nil, 0, err return
} }
buf = make([]byte, n) buf = make([]byte, n)
} }
@ -184,7 +183,7 @@ func toString(buf []byte, valueType uint32, err error) (string, error) {
if len(buf) == 0 { if len(buf) == 0 {
return "", nil return "", nil
} }
value = syscall.UTF16ToString((*[1 << 29]uint16)(unsafe.Pointer(&buf[0]))[:len(buf)/2]) value = windows.UTF16ToString((*[(1 << 30) - 1]uint16)(unsafe.Pointer(&buf[0]))[:len(buf)/2])
default: default:
return "", registry.ErrUnexpectedType return "", registry.ErrUnexpectedType
@ -215,13 +214,17 @@ func toInteger(buf []byte, valueType uint32, err error) (uint64, error) {
if len(buf) != 4 { if len(buf) != 4 {
return 0, errors.New("DWORD value is not 4 bytes long") return 0, errors.New("DWORD value is not 4 bytes long")
} }
return uint64(*(*uint32)(unsafe.Pointer(&buf[0]))), nil var val uint32
copy((*[4]byte)(unsafe.Pointer(&val))[:], buf)
return uint64(val), nil
case registry.QWORD: case registry.QWORD:
if len(buf) != 8 { if len(buf) != 8 {
return 0, errors.New("QWORD value is not 8 bytes long") return 0, errors.New("QWORD value is not 8 bytes long")
} }
return uint64(*(*uint64)(unsafe.Pointer(&buf[0]))), nil var val uint64
copy((*[8]byte)(unsafe.Pointer(&val))[:], buf)
return val, nil
default: default:
return 0, registry.ErrUnexpectedType return 0, registry.ErrUnexpectedType
@ -240,7 +243,7 @@ func toInteger(buf []byte, valueType uint32, err error) (uint64, error) {
// If the value type is REG_MULTI_SZ only the first string is returned. // If the value type is REG_MULTI_SZ only the first string is returned.
// //
func GetStringValueWait(key registry.Key, name string, timeout time.Duration) (string, error) { func GetStringValueWait(key registry.Key, name string, timeout time.Duration) (string, error) {
return toString(getValueRetry(key, name, make([]byte, 64), timeout)) return toString(getValueRetry(key, name, make([]byte, 256), timeout))
} }
// //
@ -254,7 +257,7 @@ func GetStringValueWait(key registry.Key, name string, timeout time.Duration) (s
// If the value type is REG_MULTI_SZ only the first string is returned. // If the value type is REG_MULTI_SZ only the first string is returned.
// //
func GetStringValue(key registry.Key, name string) (string, error) { func GetStringValue(key registry.Key, name string) (string, error) {
return toString(getValue(key, name, make([]byte, 64))) return toString(getValue(key, name, make([]byte, 256)))
} }
// //