From e4b957183c4a330f020f5188f3b30b59355efb80 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Fri, 30 Aug 2019 13:21:47 -0600 Subject: [PATCH] winpipe: enforce ownership of client connection --- ipc/winpipe/pipe.go | 22 +++++++++++++++++++--- ipc/winpipe/sd.go | 15 +++++++++++---- ipc/winpipe/zsyscall_windows.go | 16 ++++++++++++++++ 3 files changed, 46 insertions(+), 7 deletions(-) diff --git a/ipc/winpipe/pipe.go b/ipc/winpipe/pipe.go index 1e99a93..39ccfa4 100644 --- a/ipc/winpipe/pipe.go +++ b/ipc/winpipe/pipe.go @@ -211,7 +211,7 @@ func tryDialPipe(ctx context.Context, path *string) (syscall.Handle, error) { // DialPipe connects to a named pipe by path, timing out if the connection // takes longer than the specified duration. If timeout is nil, then we use // a default timeout of 2 seconds. (We do not use WaitNamedPipe.) -func DialPipe(path string, timeout *time.Duration) (net.Conn, error) { +func DialPipe(path string, timeout *time.Duration, expectedOwner *syscall.SID) (net.Conn, error) { var absTimeout time.Time if timeout != nil { absTimeout = time.Now().Add(*timeout) @@ -219,7 +219,7 @@ func DialPipe(path string, timeout *time.Duration) (net.Conn, error) { absTimeout = time.Now().Add(time.Second * 2) } ctx, _ := context.WithDeadline(context.Background(), absTimeout) - conn, err := DialPipeContext(ctx, path) + conn, err := DialPipeContext(ctx, path, expectedOwner) if err == context.DeadlineExceeded { return nil, ErrTimeout } @@ -228,7 +228,7 @@ func DialPipe(path string, timeout *time.Duration) (net.Conn, error) { // DialPipeContext attempts to connect to a named pipe by `path` until `ctx` // cancellation or timeout. -func DialPipeContext(ctx context.Context, path string) (net.Conn, error) { +func DialPipeContext(ctx context.Context, path string, expectedOwner *syscall.SID) (net.Conn, error) { var err error var h syscall.Handle h, err = tryDialPipe(ctx, &path) @@ -236,9 +236,25 @@ func DialPipeContext(ctx context.Context, path string) (net.Conn, error) { return nil, err } + if expectedOwner != nil { + var realOwner *syscall.SID + var realSd uintptr + err = getSecurityInfo(h, SE_FILE_OBJECT, OWNER_SECURITY_INFORMATION, &realOwner, nil, nil, nil, &realSd) + if err != nil { + syscall.Close(h) + return nil, err + } + defer localFree(realSd) + if !equalSid(realOwner, expectedOwner) { + syscall.Close(h) + return nil, syscall.ERROR_ACCESS_DENIED + } + } + var flags uint32 err = getNamedPipeInfo(h, &flags, nil, nil, nil) if err != nil { + syscall.Close(h) return nil, err } diff --git a/ipc/winpipe/sd.go b/ipc/winpipe/sd.go index 75686b2..4456917 100644 --- a/ipc/winpipe/sd.go +++ b/ipc/winpipe/sd.go @@ -12,9 +12,16 @@ import ( "unsafe" ) -//sys convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) = advapi32.ConvertStringSecurityDescriptorToSecurityDescriptorW -//sys localFree(mem uintptr) = LocalFree -//sys getSecurityDescriptorLength(sd uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength +//sys convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) = advapi32.ConvertStringSecurityDescriptorToSecurityDescriptorW +//sys localFree(mem uintptr) = LocalFree +//sys getSecurityDescriptorLength(sd uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength +//sys getSecurityInfo(handle syscall.Handle, objectType uint32, securityInformation uint32, owner **syscall.SID, group **syscall.SID, dacl *uintptr, sacl *uintptr, sd *uintptr) (ret error) = advapi32.GetSecurityInfo +//sys equalSid(sid1 *syscall.SID, sid2 *syscall.SID) (isEqual bool) = advapi32.EqualSid + +const ( + SE_FILE_OBJECT = 1 + OWNER_SECURITY_INFORMATION = 1 +) func SddlToSecurityDescriptor(sddl string) ([]byte, error) { var sdBuffer uintptr @@ -26,4 +33,4 @@ func SddlToSecurityDescriptor(sddl string) ([]byte, error) { sd := make([]byte, getSecurityDescriptorLength(sdBuffer)) copy(sd, (*[0xffff]byte)(unsafe.Pointer(sdBuffer))[:len(sd)]) return sd, nil -} +} \ No newline at end of file diff --git a/ipc/winpipe/zsyscall_windows.go b/ipc/winpipe/zsyscall_windows.go index b8eedb4..ecf3e84 100644 --- a/ipc/winpipe/zsyscall_windows.go +++ b/ipc/winpipe/zsyscall_windows.go @@ -55,6 +55,8 @@ var ( procConvertStringSecurityDescriptorToSecurityDescriptorW = modadvapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW") procLocalFree = modkernel32.NewProc("LocalFree") procGetSecurityDescriptorLength = modadvapi32.NewProc("GetSecurityDescriptorLength") + procGetSecurityInfo = modadvapi32.NewProc("GetSecurityInfo") + procEqualSid = modadvapi32.NewProc("EqualSid") procCancelIoEx = modkernel32.NewProc("CancelIoEx") procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort") procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus") @@ -206,6 +208,20 @@ func getSecurityDescriptorLength(sd uintptr) (len uint32) { return } +func getSecurityInfo(handle syscall.Handle, objectType uint32, securityInformation uint32, owner **syscall.SID, group **syscall.SID, dacl *uintptr, sacl *uintptr, sd *uintptr) (ret error) { + r0, _, _ := syscall.Syscall9(procGetSecurityInfo.Addr(), 8, uintptr(handle), uintptr(objectType), uintptr(securityInformation), uintptr(unsafe.Pointer(owner)), uintptr(unsafe.Pointer(group)), uintptr(unsafe.Pointer(dacl)), uintptr(unsafe.Pointer(sacl)), uintptr(unsafe.Pointer(sd)), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func equalSid(sid1 *syscall.SID, sid2 *syscall.SID) (isEqual bool) { + r0, _, _ := syscall.Syscall(procEqualSid.Addr(), 2, uintptr(unsafe.Pointer(sid1)), uintptr(unsafe.Pointer(sid2)), 0) + isEqual = r0 != 0 + return +} + func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) { r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0) if r1 == 0 {