winpipe: enforce ownership of client connection
This commit is contained in:
parent
950ca2ba8c
commit
e4b957183c
|
@ -211,7 +211,7 @@ func tryDialPipe(ctx context.Context, path *string) (syscall.Handle, error) {
|
||||||
// DialPipe connects to a named pipe by path, timing out if the connection
|
// DialPipe connects to a named pipe by path, timing out if the connection
|
||||||
// takes longer than the specified duration. If timeout is nil, then we use
|
// takes longer than the specified duration. If timeout is nil, then we use
|
||||||
// a default timeout of 2 seconds. (We do not use WaitNamedPipe.)
|
// a default timeout of 2 seconds. (We do not use WaitNamedPipe.)
|
||||||
func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
|
func DialPipe(path string, timeout *time.Duration, expectedOwner *syscall.SID) (net.Conn, error) {
|
||||||
var absTimeout time.Time
|
var absTimeout time.Time
|
||||||
if timeout != nil {
|
if timeout != nil {
|
||||||
absTimeout = time.Now().Add(*timeout)
|
absTimeout = time.Now().Add(*timeout)
|
||||||
|
@ -219,7 +219,7 @@ func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
|
||||||
absTimeout = time.Now().Add(time.Second * 2)
|
absTimeout = time.Now().Add(time.Second * 2)
|
||||||
}
|
}
|
||||||
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
|
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
|
||||||
conn, err := DialPipeContext(ctx, path)
|
conn, err := DialPipeContext(ctx, path, expectedOwner)
|
||||||
if err == context.DeadlineExceeded {
|
if err == context.DeadlineExceeded {
|
||||||
return nil, ErrTimeout
|
return nil, ErrTimeout
|
||||||
}
|
}
|
||||||
|
@ -228,7 +228,7 @@ func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
|
||||||
|
|
||||||
// DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
|
// DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
|
||||||
// cancellation or timeout.
|
// cancellation or timeout.
|
||||||
func DialPipeContext(ctx context.Context, path string) (net.Conn, error) {
|
func DialPipeContext(ctx context.Context, path string, expectedOwner *syscall.SID) (net.Conn, error) {
|
||||||
var err error
|
var err error
|
||||||
var h syscall.Handle
|
var h syscall.Handle
|
||||||
h, err = tryDialPipe(ctx, &path)
|
h, err = tryDialPipe(ctx, &path)
|
||||||
|
@ -236,9 +236,25 @@ func DialPipeContext(ctx context.Context, path string) (net.Conn, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if expectedOwner != nil {
|
||||||
|
var realOwner *syscall.SID
|
||||||
|
var realSd uintptr
|
||||||
|
err = getSecurityInfo(h, SE_FILE_OBJECT, OWNER_SECURITY_INFORMATION, &realOwner, nil, nil, nil, &realSd)
|
||||||
|
if err != nil {
|
||||||
|
syscall.Close(h)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer localFree(realSd)
|
||||||
|
if !equalSid(realOwner, expectedOwner) {
|
||||||
|
syscall.Close(h)
|
||||||
|
return nil, syscall.ERROR_ACCESS_DENIED
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var flags uint32
|
var flags uint32
|
||||||
err = getNamedPipeInfo(h, &flags, nil, nil, nil)
|
err = getNamedPipeInfo(h, &flags, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
syscall.Close(h)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -12,9 +12,16 @@ import (
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
//sys convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) = advapi32.ConvertStringSecurityDescriptorToSecurityDescriptorW
|
//sys convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) = advapi32.ConvertStringSecurityDescriptorToSecurityDescriptorW
|
||||||
//sys localFree(mem uintptr) = LocalFree
|
//sys localFree(mem uintptr) = LocalFree
|
||||||
//sys getSecurityDescriptorLength(sd uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength
|
//sys getSecurityDescriptorLength(sd uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength
|
||||||
|
//sys getSecurityInfo(handle syscall.Handle, objectType uint32, securityInformation uint32, owner **syscall.SID, group **syscall.SID, dacl *uintptr, sacl *uintptr, sd *uintptr) (ret error) = advapi32.GetSecurityInfo
|
||||||
|
//sys equalSid(sid1 *syscall.SID, sid2 *syscall.SID) (isEqual bool) = advapi32.EqualSid
|
||||||
|
|
||||||
|
const (
|
||||||
|
SE_FILE_OBJECT = 1
|
||||||
|
OWNER_SECURITY_INFORMATION = 1
|
||||||
|
)
|
||||||
|
|
||||||
func SddlToSecurityDescriptor(sddl string) ([]byte, error) {
|
func SddlToSecurityDescriptor(sddl string) ([]byte, error) {
|
||||||
var sdBuffer uintptr
|
var sdBuffer uintptr
|
||||||
|
@ -26,4 +33,4 @@ func SddlToSecurityDescriptor(sddl string) ([]byte, error) {
|
||||||
sd := make([]byte, getSecurityDescriptorLength(sdBuffer))
|
sd := make([]byte, getSecurityDescriptorLength(sdBuffer))
|
||||||
copy(sd, (*[0xffff]byte)(unsafe.Pointer(sdBuffer))[:len(sd)])
|
copy(sd, (*[0xffff]byte)(unsafe.Pointer(sdBuffer))[:len(sd)])
|
||||||
return sd, nil
|
return sd, nil
|
||||||
}
|
}
|
|
@ -55,6 +55,8 @@ var (
|
||||||
procConvertStringSecurityDescriptorToSecurityDescriptorW = modadvapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW")
|
procConvertStringSecurityDescriptorToSecurityDescriptorW = modadvapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW")
|
||||||
procLocalFree = modkernel32.NewProc("LocalFree")
|
procLocalFree = modkernel32.NewProc("LocalFree")
|
||||||
procGetSecurityDescriptorLength = modadvapi32.NewProc("GetSecurityDescriptorLength")
|
procGetSecurityDescriptorLength = modadvapi32.NewProc("GetSecurityDescriptorLength")
|
||||||
|
procGetSecurityInfo = modadvapi32.NewProc("GetSecurityInfo")
|
||||||
|
procEqualSid = modadvapi32.NewProc("EqualSid")
|
||||||
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
|
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
|
||||||
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
|
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
|
||||||
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
|
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
|
||||||
|
@ -206,6 +208,20 @@ func getSecurityDescriptorLength(sd uintptr) (len uint32) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getSecurityInfo(handle syscall.Handle, objectType uint32, securityInformation uint32, owner **syscall.SID, group **syscall.SID, dacl *uintptr, sacl *uintptr, sd *uintptr) (ret error) {
|
||||||
|
r0, _, _ := syscall.Syscall9(procGetSecurityInfo.Addr(), 8, uintptr(handle), uintptr(objectType), uintptr(securityInformation), uintptr(unsafe.Pointer(owner)), uintptr(unsafe.Pointer(group)), uintptr(unsafe.Pointer(dacl)), uintptr(unsafe.Pointer(sacl)), uintptr(unsafe.Pointer(sd)), 0)
|
||||||
|
if r0 != 0 {
|
||||||
|
ret = syscall.Errno(r0)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func equalSid(sid1 *syscall.SID, sid2 *syscall.SID) (isEqual bool) {
|
||||||
|
r0, _, _ := syscall.Syscall(procEqualSid.Addr(), 2, uintptr(unsafe.Pointer(sid1)), uintptr(unsafe.Pointer(sid2)), 0)
|
||||||
|
isEqual = r0 != 0
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) {
|
func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) {
|
||||||
r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0)
|
r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0)
|
||||||
if r1 == 0 {
|
if r1 == 0 {
|
||||||
|
|
Loading…
Reference in a new issue