diff --git a/ipc/winpipe/file.go b/ipc/namedpipe/file.go similarity index 96% rename from ipc/winpipe/file.go rename to ipc/namedpipe/file.go index 319565f..9c2481d 100644 --- a/ipc/winpipe/file.go +++ b/ipc/namedpipe/file.go @@ -1,12 +1,12 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Copyright 2015 Microsoft +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + //go:build windows +// +build windows -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2005 Microsoft - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package winpipe +package namedpipe import ( "io" diff --git a/ipc/winpipe/winpipe.go b/ipc/namedpipe/namedpipe.go similarity index 83% rename from ipc/winpipe/winpipe.go rename to ipc/namedpipe/namedpipe.go index e3719d6..6db5ea3 100644 --- a/ipc/winpipe/winpipe.go +++ b/ipc/namedpipe/namedpipe.go @@ -1,13 +1,13 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Copyright 2015 Microsoft +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + //go:build windows +// +build windows -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2005 Microsoft - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -// Package winpipe implements a net.Conn and net.Listener around Windows named pipes. -package winpipe +// Package namedpipe implements a net.Conn and net.Listener around Windows named pipes. +package namedpipe import ( "context" @@ -15,6 +15,7 @@ import ( "net" "os" "runtime" + "sync/atomic" "time" "unsafe" @@ -28,7 +29,7 @@ type pipe struct { type messageBytePipe struct { pipe - writeClosed bool + writeClosed int32 readEOF bool } @@ -50,25 +51,26 @@ func (f *pipe) SetDeadline(t time.Time) error { // CloseWrite closes the write side of a message pipe in byte mode. func (f *messageBytePipe) CloseWrite() error { - if f.writeClosed { + if !atomic.CompareAndSwapInt32(&f.writeClosed, 0, 1) { return io.ErrClosedPipe } err := f.file.Flush() if err != nil { + atomic.StoreInt32(&f.writeClosed, 0) return err } _, err = f.file.Write(nil) if err != nil { + atomic.StoreInt32(&f.writeClosed, 0) return err } - f.writeClosed = true return nil } // Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since // they are used to implement CloseWrite. func (f *messageBytePipe) Write(b []byte) (int, error) { - if f.writeClosed { + if atomic.LoadInt32(&f.writeClosed) != 0 { return 0, io.ErrClosedPipe } if len(b) == 0 { @@ -142,30 +144,24 @@ type DialConfig struct { ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID. } -// Dial connects to the specified named pipe by path, timing out if the connection -// takes longer than the specified duration. If timeout is nil, then we use -// a default timeout of 2 seconds. -func Dial(path string, timeout *time.Duration, config *DialConfig) (net.Conn, error) { - var absTimeout time.Time - if timeout != nil { - absTimeout = time.Now().Add(*timeout) - } else { - absTimeout = time.Now().Add(2 * time.Second) +// DialTimeout connects to the specified named pipe by path, timing out if the +// connection takes longer than the specified duration. If timeout is zero, then +// we use a default timeout of 2 seconds. +func (config *DialConfig) DialTimeout(path string, timeout time.Duration) (net.Conn, error) { + if timeout == 0 { + timeout = time.Second * 2 } + absTimeout := time.Now().Add(timeout) ctx, _ := context.WithDeadline(context.Background(), absTimeout) - conn, err := DialContext(ctx, path, config) + conn, err := config.DialContext(ctx, path) if err == context.DeadlineExceeded { return nil, os.ErrDeadlineExceeded } return conn, err } -// DialContext attempts to connect to the specified named pipe by path -// cancellation or timeout. -func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn, error) { - if config == nil { - config = &DialConfig{} - } +// DialContext attempts to connect to the specified named pipe by path. +func (config *DialConfig) DialContext(ctx context.Context, path string) (net.Conn, error) { var err error var h windows.Handle h, err = tryDialPipe(ctx, &path) @@ -213,6 +209,18 @@ func DialContext(ctx context.Context, path string, config *DialConfig) (net.Conn return &pipe{file: f, path: path}, nil } +var defaultDialer DialConfig + +// DialTimeout calls DialConfig.DialTimeout using an empty configuration. +func DialTimeout(path string, timeout time.Duration) (net.Conn, error) { + return defaultDialer.DialTimeout(path, timeout) +} + +// DialContext calls DialConfig.DialContext using an empty configuration. +func DialContext(ctx context.Context, path string) (net.Conn, error) { + return defaultDialer.DialContext(ctx, path) +} + type acceptResponse struct { f *file err error @@ -222,12 +230,12 @@ type pipeListener struct { firstHandle windows.Handle path string config ListenConfig - acceptCh chan (chan acceptResponse) + acceptCh chan chan acceptResponse closeCh chan int doneCh chan int } -func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, first bool) (windows.Handle, error) { +func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, isFirstPipe bool) (windows.Handle, error) { path16, err := windows.UTF16PtrFromString(path) if err != nil { return 0, &os.PathError{Op: "open", Path: path, Err: err} @@ -247,7 +255,7 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste oa.ObjectName = &ntPath // The security descriptor is only needed for the first pipe. - if first { + if isFirstPipe { if sd != nil { oa.SecurityDescriptor = sd } else { @@ -257,7 +265,7 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste return 0, err } defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl))) - sd, err := windows.NewSecurityDescriptor() + sd, err = windows.NewSecurityDescriptor() if err != nil { return 0, err } @@ -275,11 +283,11 @@ func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *Liste disposition := uint32(windows.FILE_OPEN) access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE) - if first { + if isFirstPipe { disposition = windows.FILE_CREATE // By not asking for read or write access, the named pipe file system // will put this pipe into an initially disconnected state, blocking - // client connections until the next call with first == false. + // client connections until the next call with isFirstPipe == false. access = windows.SYNCHRONIZE } @@ -395,10 +403,7 @@ type ListenConfig struct { // Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe. // The pipe must not already exist. -func Listen(path string, c *ListenConfig) (net.Listener, error) { - if c == nil { - c = &ListenConfig{} - } +func (c *ListenConfig) Listen(path string) (net.Listener, error) { h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true) if err != nil { return nil, err @@ -407,12 +412,12 @@ func Listen(path string, c *ListenConfig) (net.Listener, error) { firstHandle: h, path: path, config: *c, - acceptCh: make(chan (chan acceptResponse)), + acceptCh: make(chan chan acceptResponse), closeCh: make(chan int), doneCh: make(chan int), } // The first connection is swallowed on Windows 7 & 8, so synthesize it. - if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 { + if maj, min, _ := windows.RtlGetNtVersionNumbers(); maj < 6 || (maj == 6 && min < 4) { path16, err := windows.UTF16PtrFromString(path) if err == nil { h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0) @@ -425,6 +430,13 @@ func Listen(path string, c *ListenConfig) (net.Listener, error) { return l, nil } +var defaultListener ListenConfig + +// Listen calls ListenConfig.Listen using an empty configuration. +func Listen(path string) (net.Listener, error) { + return defaultListener.Listen(path) +} + func connectPipe(p *file) error { c, err := p.prepareIo() if err != nil { diff --git a/ipc/winpipe/winpipe_test.go b/ipc/namedpipe/namedpipe_test.go similarity index 81% rename from ipc/winpipe/winpipe_test.go rename to ipc/namedpipe/namedpipe_test.go index ea515e3..0573d0f 100644 --- a/ipc/winpipe/winpipe_test.go +++ b/ipc/namedpipe/namedpipe_test.go @@ -1,12 +1,12 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Copyright 2015 Microsoft +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + //go:build windows +// +build windows -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2005 Microsoft - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package winpipe_test +package namedpipe_test import ( "bufio" @@ -22,7 +22,7 @@ import ( "time" "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/ipc/winpipe" + "golang.zx2c4.com/wireguard/ipc/namedpipe" ) func randomPipePath() string { @@ -30,7 +30,7 @@ func randomPipePath() string { if err != nil { panic(err) } - return `\\.\PIPE\go-winpipe-test-` + guid.String() + return `\\.\PIPE\go-namedpipe-test-` + guid.String() } func TestPingPong(t *testing.T) { @@ -39,7 +39,7 @@ func TestPingPong(t *testing.T) { pong = 24 ) pipePath := randomPipePath() - listener, err := winpipe.Listen(pipePath, nil) + listener, err := namedpipe.Listen(pipePath) if err != nil { t.Fatalf("unable to listen on pipe: %v", err) } @@ -64,11 +64,12 @@ func TestPingPong(t *testing.T) { t.Fatalf("unable to write pong to pipe: %v", err) } }() - client, err := winpipe.Dial(pipePath, nil, nil) + client, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatalf("unable to dial pipe: %v", err) } defer client.Close() + client.SetDeadline(time.Now().Add(time.Second * 5)) var data [1]byte data[0] = ping _, err = client.Write(data[:]) @@ -85,7 +86,7 @@ func TestPingPong(t *testing.T) { } func TestDialUnknownFailsImmediately(t *testing.T) { - _, err := winpipe.Dial(randomPipePath(), nil, nil) + _, err := namedpipe.DialTimeout(randomPipePath(), time.Duration(0)) if !errors.Is(err, syscall.ENOENT) { t.Fatalf("expected ENOENT got %v", err) } @@ -93,13 +94,15 @@ func TestDialUnknownFailsImmediately(t *testing.T) { func TestDialListenerTimesOut(t *testing.T) { pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) + l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } defer l.Close() - d := 10 * time.Millisecond - _, err = winpipe.Dial(pipePath, &d, nil) + pipe, err := namedpipe.DialTimeout(pipePath, 10*time.Millisecond) + if err == nil { + pipe.Close() + } if err != os.ErrDeadlineExceeded { t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) } @@ -107,14 +110,17 @@ func TestDialListenerTimesOut(t *testing.T) { func TestDialContextListenerTimesOut(t *testing.T) { pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) + l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } defer l.Close() d := 10 * time.Millisecond ctx, _ := context.WithTimeout(context.Background(), d) - _, err = winpipe.DialContext(ctx, pipePath, nil) + pipe, err := namedpipe.DialContext(ctx, pipePath) + if err == nil { + pipe.Close() + } if err != context.DeadlineExceeded { t.Fatalf("expected context.DeadlineExceeded, got %v", err) } @@ -123,14 +129,14 @@ func TestDialContextListenerTimesOut(t *testing.T) { func TestDialListenerGetsCancelled(t *testing.T) { pipePath := randomPipePath() ctx, cancel := context.WithCancel(context.Background()) - l, err := winpipe.Listen(pipePath, nil) + l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } - ch := make(chan error) defer l.Close() + ch := make(chan error) go func(ctx context.Context, ch chan error) { - _, err := winpipe.DialContext(ctx, pipePath, nil) + _, err := namedpipe.DialContext(ctx, pipePath) ch <- err }(ctx, ch) time.Sleep(time.Millisecond * 30) @@ -147,23 +153,28 @@ func TestDialAccessDeniedWithRestrictedSD(t *testing.T) { } pipePath := randomPipePath() sd, _ := windows.SecurityDescriptorFromString("D:") - c := winpipe.ListenConfig{ + l, err := (&namedpipe.ListenConfig{ SecurityDescriptor: sd, - } - l, err := winpipe.Listen(pipePath, &c) + }).Listen(pipePath) if err != nil { t.Fatal(err) } defer l.Close() - _, err = winpipe.Dial(pipePath, nil, nil) + pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) + if err == nil { + pipe.Close() + } if !errors.Is(err, windows.ERROR_ACCESS_DENIED) { t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err) } } -func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn, err error) { +func getConnection(cfg *namedpipe.ListenConfig) (client net.Conn, server net.Conn, err error) { pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, cfg) + if cfg == nil { + cfg = &namedpipe.ListenConfig{} + } + l, err := cfg.Listen(pipePath) if err != nil { return } @@ -179,7 +190,7 @@ func getConnection(cfg *winpipe.ListenConfig) (client net.Conn, server net.Conn, ch <- response{c, err} }() - c, err := winpipe.Dial(pipePath, nil, nil) + c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { return } @@ -236,7 +247,7 @@ func server(l net.Listener, ch chan int) { func TestFullListenDialReadWrite(t *testing.T) { pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) + l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } @@ -245,7 +256,7 @@ func TestFullListenDialReadWrite(t *testing.T) { ch := make(chan int) go server(l, ch) - c, err := winpipe.Dial(pipePath, nil, nil) + c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } @@ -275,7 +286,7 @@ func TestFullListenDialReadWrite(t *testing.T) { func TestCloseAbortsListen(t *testing.T) { pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) + l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } @@ -328,7 +339,7 @@ func TestCloseServerEOFClient(t *testing.T) { } func TestCloseWriteEOF(t *testing.T) { - cfg := &winpipe.ListenConfig{ + cfg := &namedpipe.ListenConfig{ MessageMode: true, } c, s, err := getConnection(cfg) @@ -356,7 +367,7 @@ func TestCloseWriteEOF(t *testing.T) { func TestAcceptAfterCloseFails(t *testing.T) { pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) + l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } @@ -369,12 +380,15 @@ func TestAcceptAfterCloseFails(t *testing.T) { func TestDialTimesOutByDefault(t *testing.T) { pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) + l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } defer l.Close() - _, err = winpipe.Dial(pipePath, nil, nil) + pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds. + if err == nil { + pipe.Close() + } if err != os.ErrDeadlineExceeded { t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) } @@ -382,7 +396,7 @@ func TestDialTimesOutByDefault(t *testing.T) { func TestTimeoutPendingRead(t *testing.T) { pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) + l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } @@ -400,7 +414,7 @@ func TestTimeoutPendingRead(t *testing.T) { close(serverDone) }() - client, err := winpipe.Dial(pipePath, nil, nil) + client, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } @@ -430,7 +444,7 @@ func TestTimeoutPendingRead(t *testing.T) { func TestTimeoutPendingWrite(t *testing.T) { pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) + l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } @@ -448,7 +462,7 @@ func TestTimeoutPendingWrite(t *testing.T) { close(serverDone) }() - client, err := winpipe.Dial(pipePath, nil, nil) + client, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } @@ -480,13 +494,12 @@ type CloseWriter interface { } func TestEchoWithMessaging(t *testing.T) { - c := winpipe.ListenConfig{ + pipePath := randomPipePath() + l, err := (&namedpipe.ListenConfig{ MessageMode: true, // Use message mode so that CloseWrite() is supported InputBufferSize: 65536, // Use 64KB buffers to improve performance OutputBufferSize: 65536, - } - pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, &c) + }).Listen(pipePath) if err != nil { t.Fatal(err) } @@ -496,19 +509,21 @@ func TestEchoWithMessaging(t *testing.T) { clientDone := make(chan bool) go func() { // server echo - conn, e := l.Accept() - if e != nil { - t.Fatal(e) + conn, err := l.Accept() + if err != nil { + t.Fatal(err) } defer conn.Close() time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent - io.Copy(conn, conn) + _, err = io.Copy(conn, conn) + if err != nil { + t.Fatal(err) + } conn.(CloseWriter).CloseWrite() close(listenerDone) }() - timeout := 1 * time.Second - client, err := winpipe.Dial(pipePath, &timeout, nil) + client, err := namedpipe.DialTimeout(pipePath, time.Second) if err != nil { t.Fatal(err) } @@ -521,7 +536,7 @@ func TestEchoWithMessaging(t *testing.T) { if e != nil { t.Fatal(e) } - if n != 2 { + if n != 2 || bytes[0] != 0 || bytes[1] != 1 { t.Fatalf("expected 2 bytes, got %v", n) } close(clientDone) @@ -545,7 +560,7 @@ func TestEchoWithMessaging(t *testing.T) { func TestConnectRace(t *testing.T) { pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, nil) + l, err := namedpipe.Listen(pipePath) if err != nil { t.Fatal(err) } @@ -565,7 +580,7 @@ func TestConnectRace(t *testing.T) { }() for i := 0; i < 1000; i++ { - c, err := winpipe.Dial(pipePath, nil, nil) + c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } @@ -580,7 +595,7 @@ func TestMessageReadMode(t *testing.T) { var wg sync.WaitGroup defer wg.Wait() pipePath := randomPipePath() - l, err := winpipe.Listen(pipePath, &winpipe.ListenConfig{MessageMode: true}) + l, err := (&namedpipe.ListenConfig{MessageMode: true}).Listen(pipePath) if err != nil { t.Fatal(err) } @@ -602,7 +617,7 @@ func TestMessageReadMode(t *testing.T) { s.Close() }() - c, err := winpipe.Dial(pipePath, nil, nil) + c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err != nil { t.Fatal(err) } @@ -643,13 +658,13 @@ func TestListenConnectRace(t *testing.T) { var wg sync.WaitGroup wg.Add(1) go func() { - c, err := winpipe.Dial(pipePath, nil, nil) + c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) if err == nil { c.Close() } wg.Done() }() - s, err := winpipe.Listen(pipePath, nil) + s, err := namedpipe.Listen(pipePath) if err != nil { t.Error(i, err) } else { diff --git a/ipc/uapi_windows.go b/ipc/uapi_windows.go index a4d68da..a1bfbd1 100644 --- a/ipc/uapi_windows.go +++ b/ipc/uapi_windows.go @@ -9,8 +9,7 @@ import ( "net" "golang.org/x/sys/windows" - - "golang.zx2c4.com/wireguard/ipc/winpipe" + "golang.zx2c4.com/wireguard/ipc/namedpipe" ) // TODO: replace these with actual standard windows error numbers from the win package @@ -61,10 +60,9 @@ func init() { } func UAPIListen(name string) (net.Listener, error) { - config := winpipe.ListenConfig{ + listener, err := (&namedpipe.ListenConfig{ SecurityDescriptor: UAPISecurityDescriptor, - } - listener, err := winpipe.Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config) + }).Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\` + name) if err != nil { return nil, err } diff --git a/tun/wintun/dll_windows.go b/tun/wintun/dll_windows.go deleted file mode 100644 index 3832c1e..0000000 --- a/tun/wintun/dll_windows.go +++ /dev/null @@ -1,128 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package wintun - -import ( - "fmt" - "sync" - "sync/atomic" - "unsafe" - - "golang.org/x/sys/windows" -) - -func newLazyDLL(name string, onLoad func(d *lazyDLL)) *lazyDLL { - return &lazyDLL{Name: name, onLoad: onLoad} -} - -func (d *lazyDLL) NewProc(name string) *lazyProc { - return &lazyProc{dll: d, Name: name} -} - -type lazyProc struct { - Name string - mu sync.Mutex - dll *lazyDLL - addr uintptr -} - -func (p *lazyProc) Find() error { - if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr))) != nil { - return nil - } - p.mu.Lock() - defer p.mu.Unlock() - if p.addr != 0 { - return nil - } - - err := p.dll.Load() - if err != nil { - return fmt.Errorf("Error loading %v DLL: %w", p.dll.Name, err) - } - addr, err := p.nameToAddr() - if err != nil { - return fmt.Errorf("Error getting %v address: %w", p.Name, err) - } - - atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr)), unsafe.Pointer(addr)) - return nil -} - -func (p *lazyProc) Addr() uintptr { - err := p.Find() - if err != nil { - panic(err) - } - return p.addr -} - -type lazyDLL struct { - Name string - mu sync.Mutex - module windows.Handle - onLoad func(d *lazyDLL) -} - -func (d *lazyDLL) Load() error { - if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil { - return nil - } - d.mu.Lock() - defer d.mu.Unlock() - if d.module != 0 { - return nil - } - - const ( - LOAD_LIBRARY_SEARCH_APPLICATION_DIR = 0x00000200 - LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800 - ) - module, err := windows.LoadLibraryEx(d.Name, 0, LOAD_LIBRARY_SEARCH_APPLICATION_DIR|LOAD_LIBRARY_SEARCH_SYSTEM32) - if err != nil { - return fmt.Errorf("Unable to load library: %w", err) - } - - atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module)) - if d.onLoad != nil { - d.onLoad(d) - } - return nil -} - -func (p *lazyProc) nameToAddr() (uintptr, error) { - return windows.GetProcAddress(p.dll.module, p.Name) -} - -// Version returns the version of the Wintun DLL. -func Version() string { - if modwintun.Load() != nil { - return "unknown" - } - resInfo, err := windows.FindResource(modwintun.module, windows.ResourceID(1), windows.RT_VERSION) - if err != nil { - return "unknown" - } - data, err := windows.LoadResourceData(modwintun.module, resInfo) - if err != nil { - return "unknown" - } - - var fixedInfo *windows.VS_FIXEDFILEINFO - fixedInfoLen := uint32(unsafe.Sizeof(*fixedInfo)) - err = windows.VerQueryValue(unsafe.Pointer(&data[0]), `\`, unsafe.Pointer(&fixedInfo), &fixedInfoLen) - if err != nil { - return "unknown" - } - version := fmt.Sprintf("%d.%d", (fixedInfo.FileVersionMS>>16)&0xff, (fixedInfo.FileVersionMS>>0)&0xff) - if nextNibble := (fixedInfo.FileVersionLS >> 16) & 0xff; nextNibble != 0 { - version += fmt.Sprintf(".%d", nextNibble) - } - if nextNibble := (fixedInfo.FileVersionLS >> 0) & 0xff; nextNibble != 0 { - version += fmt.Sprintf(".%d", nextNibble) - } - return version -} diff --git a/tun/wintun/session_windows.go b/tun/wintun/session_windows.go deleted file mode 100644 index f023baf..0000000 --- a/tun/wintun/session_windows.go +++ /dev/null @@ -1,90 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package wintun - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -type Session struct { - handle uintptr -} - -const ( - PacketSizeMax = 0xffff // Maximum packet size - RingCapacityMin = 0x20000 // Minimum ring capacity (128 kiB) - RingCapacityMax = 0x4000000 // Maximum ring capacity (64 MiB) -) - -// Packet with data -type Packet struct { - Next *Packet // Pointer to next packet in queue - Size uint32 // Size of packet (max WINTUN_MAX_IP_PACKET_SIZE) - Data *[PacketSizeMax]byte // Pointer to layer 3 IPv4 or IPv6 packet -} - -var ( - procWintunAllocateSendPacket = modwintun.NewProc("WintunAllocateSendPacket") - procWintunEndSession = modwintun.NewProc("WintunEndSession") - procWintunGetReadWaitEvent = modwintun.NewProc("WintunGetReadWaitEvent") - procWintunReceivePacket = modwintun.NewProc("WintunReceivePacket") - procWintunReleaseReceivePacket = modwintun.NewProc("WintunReleaseReceivePacket") - procWintunSendPacket = modwintun.NewProc("WintunSendPacket") - procWintunStartSession = modwintun.NewProc("WintunStartSession") -) - -func (wintun *Adapter) StartSession(capacity uint32) (session Session, err error) { - r0, _, e1 := syscall.Syscall(procWintunStartSession.Addr(), 2, uintptr(wintun.handle), uintptr(capacity), 0) - if r0 == 0 { - err = e1 - } else { - session = Session{r0} - } - return -} - -func (session Session) End() { - syscall.Syscall(procWintunEndSession.Addr(), 1, session.handle, 0, 0) - session.handle = 0 -} - -func (session Session) ReadWaitEvent() (handle windows.Handle) { - r0, _, _ := syscall.Syscall(procWintunGetReadWaitEvent.Addr(), 1, session.handle, 0, 0) - handle = windows.Handle(r0) - return -} - -func (session Session) ReceivePacket() (packet []byte, err error) { - var packetSize uint32 - r0, _, e1 := syscall.Syscall(procWintunReceivePacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packetSize)), 0) - if r0 == 0 { - err = e1 - return - } - packet = unsafe.Slice((*byte)(unsafe.Pointer(r0)), packetSize) - return -} - -func (session Session) ReleaseReceivePacket(packet []byte) { - syscall.Syscall(procWintunReleaseReceivePacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0) -} - -func (session Session) AllocateSendPacket(packetSize int) (packet []byte, err error) { - r0, _, e1 := syscall.Syscall(procWintunAllocateSendPacket.Addr(), 2, session.handle, uintptr(packetSize), 0) - if r0 == 0 { - err = e1 - return - } - packet = unsafe.Slice((*byte)(unsafe.Pointer(r0)), packetSize) - return -} - -func (session Session) SendPacket(packet []byte) { - syscall.Syscall(procWintunSendPacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0) -} diff --git a/tun/wintun/wintun_windows.go b/tun/wintun/wintun_windows.go deleted file mode 100644 index 2fe26a7..0000000 --- a/tun/wintun/wintun_windows.go +++ /dev/null @@ -1,150 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package wintun - -import ( - "log" - "runtime" - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -type loggerLevel int - -const ( - logInfo loggerLevel = iota - logWarn - logErr -) - -const AdapterNameMax = 128 - -type Adapter struct { - handle uintptr -} - -var ( - modwintun = newLazyDLL("wintun.dll", setupLogger) - procWintunCreateAdapter = modwintun.NewProc("WintunCreateAdapter") - procWintunOpenAdapter = modwintun.NewProc("WintunOpenAdapter") - procWintunCloseAdapter = modwintun.NewProc("WintunCloseAdapter") - procWintunDeleteDriver = modwintun.NewProc("WintunDeleteDriver") - procWintunGetAdapterLUID = modwintun.NewProc("WintunGetAdapterLUID") - procWintunGetRunningDriverVersion = modwintun.NewProc("WintunGetRunningDriverVersion") -) - -type TimestampedWriter interface { - WriteWithTimestamp(p []byte, ts int64) (n int, err error) -} - -func logMessage(level loggerLevel, timestamp uint64, msg *uint16) int { - if tw, ok := log.Default().Writer().(TimestampedWriter); ok { - tw.WriteWithTimestamp([]byte(log.Default().Prefix()+windows.UTF16PtrToString(msg)), (int64(timestamp)-116444736000000000)*100) - } else { - log.Println(windows.UTF16PtrToString(msg)) - } - return 0 -} - -func setupLogger(dll *lazyDLL) { - var callback uintptr - if runtime.GOARCH == "386" { - callback = windows.NewCallback(func(level loggerLevel, timestampLow, timestampHigh uint32, msg *uint16) int { - return logMessage(level, uint64(timestampHigh)<<32|uint64(timestampLow), msg) - }) - } else if runtime.GOARCH == "arm" { - callback = windows.NewCallback(func(level loggerLevel, _, timestampLow, timestampHigh uint32, msg *uint16) int { - return logMessage(level, uint64(timestampHigh)<<32|uint64(timestampLow), msg) - }) - } else if runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" { - callback = windows.NewCallback(logMessage) - } - syscall.Syscall(dll.NewProc("WintunSetLogger").Addr(), 1, callback, 0, 0) -} - -func closeAdapter(wintun *Adapter) { - syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0) -} - -// CreateAdapter creates a Wintun adapter. name is the cosmetic name of the adapter. -// tunnelType represents the type of adapter and should be "Wintun". requestedGUID is -// the GUID of the created network adapter, which then influences NLA generation -// deterministically. If it is set to nil, the GUID is chosen by the system at random, -// and hence a new NLA entry is created for each new adapter. -func CreateAdapter(name string, tunnelType string, requestedGUID *windows.GUID) (wintun *Adapter, err error) { - var name16 *uint16 - name16, err = windows.UTF16PtrFromString(name) - if err != nil { - return - } - var tunnelType16 *uint16 - tunnelType16, err = windows.UTF16PtrFromString(tunnelType) - if err != nil { - return - } - r0, _, e1 := syscall.Syscall(procWintunCreateAdapter.Addr(), 3, uintptr(unsafe.Pointer(name16)), uintptr(unsafe.Pointer(tunnelType16)), uintptr(unsafe.Pointer(requestedGUID))) - if r0 == 0 { - err = e1 - return - } - wintun = &Adapter{handle: r0} - runtime.SetFinalizer(wintun, closeAdapter) - return -} - -// OpenAdapter opens an existing Wintun adapter by name. -func OpenAdapter(name string) (wintun *Adapter, err error) { - var name16 *uint16 - name16, err = windows.UTF16PtrFromString(name) - if err != nil { - return - } - r0, _, e1 := syscall.Syscall(procWintunOpenAdapter.Addr(), 1, uintptr(unsafe.Pointer(name16)), 0, 0) - if r0 == 0 { - err = e1 - return - } - wintun = &Adapter{handle: r0} - runtime.SetFinalizer(wintun, closeAdapter) - return -} - -// Close closes a Wintun adapter. -func (wintun *Adapter) Close() (err error) { - runtime.SetFinalizer(wintun, nil) - r1, _, e1 := syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0) - if r1 == 0 { - err = e1 - } - return -} - -// Uninstall removes the driver from the system if no drivers are currently in use. -func Uninstall() (err error) { - r1, _, e1 := syscall.Syscall(procWintunDeleteDriver.Addr(), 0, 0, 0, 0) - if r1 == 0 { - err = e1 - } - return -} - -// RunningVersion returns the version of the loaded driver. -func RunningVersion() (version uint32, err error) { - r0, _, e1 := syscall.Syscall(procWintunGetRunningDriverVersion.Addr(), 0, 0, 0, 0) - version = uint32(r0) - if version == 0 { - err = e1 - } - return -} - -// LUID returns the LUID of the adapter. -func (wintun *Adapter) LUID() (luid uint64) { - syscall.Syscall(procWintunGetAdapterLUID.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&luid)), 0) - return -}