From 6f5ef153c3b578da99501cfcbe4f2e4a84d44708 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sun, 27 Aug 2017 15:41:00 +0200 Subject: [PATCH] Added code from windows branch --- src/build.cmd | 6 + src/conn_default.go | 2 +- src/daemon_windows.go | 34 +++ src/timers.go | 672 +++++++++++++++++++++--------------------- src/tun_windows.go | 475 +++++++++++++++++++++++++++++ src/uapi_windows.go | 44 +++ 6 files changed, 896 insertions(+), 337 deletions(-) create mode 100755 src/build.cmd create mode 100644 src/daemon_windows.go create mode 100644 src/tun_windows.go create mode 100644 src/uapi_windows.go diff --git a/src/build.cmd b/src/build.cmd new file mode 100755 index 0000000..52cb883 --- /dev/null +++ b/src/build.cmd @@ -0,0 +1,6 @@ +@echo off + +REM builds wireguard for windows + +go get +go build -o wireguard-go.exe diff --git a/src/conn_default.go b/src/conn_default.go index a6dc97d..5ef2659 100644 --- a/src/conn_default.go +++ b/src/conn_default.go @@ -6,6 +6,6 @@ import ( "net" ) -func setFwmark(conn *net.UDPConn, value int) error { +func setMark(conn *net.UDPConn, value int) error { return nil } diff --git a/src/daemon_windows.go b/src/daemon_windows.go new file mode 100644 index 0000000..d5ec1e8 --- /dev/null +++ b/src/daemon_windows.go @@ -0,0 +1,34 @@ +package main + +import ( + "os" +) + +/* Daemonizes the process on windows + * + * This is done by spawning and releasing a copy with the --foreground flag + */ + +func Daemonize() error { + argv := []string{os.Args[0], "--foreground"} + argv = append(argv, os.Args[1:]...) + attr := &os.ProcAttr{ + Dir: ".", + Env: os.Environ(), + Files: []*os.File{ + os.Stdin, + nil, + nil, + }, + } + process, err := os.StartProcess( + argv[0], + argv, + attr, + ) + if err != nil { + return err + } + process.Release() + return nil +} diff --git a/src/timers.go b/src/timers.go index ab2e7ad..de54a96 100644 --- a/src/timers.go +++ b/src/timers.go @@ -1,336 +1,336 @@ -package main - -import ( - "bytes" - "encoding/binary" - "golang.org/x/crypto/blake2s" - "math/rand" - "sync/atomic" - "time" -) - -/* Called when a new authenticated message has been send - * - */ -func (peer *Peer) KeepKeyFreshSending() { - kp := peer.keyPairs.Current() - if kp == nil { - return - } - nonce := atomic.LoadUint64(&kp.sendNonce) - if nonce > RekeyAfterMessages { - signalSend(peer.signal.handshakeBegin) - } - if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime { - signalSend(peer.signal.handshakeBegin) - } -} - -/* Called when a new authenticated message has been recevied - * - */ -func (peer *Peer) KeepKeyFreshReceiving() { - // TODO: Add a guard, clear on handshake complete (clear in TimerHandshakeComplete) - kp := peer.keyPairs.Current() - if kp == nil { - return - } - if !kp.isInitiator { - return - } - nonce := atomic.LoadUint64(&kp.sendNonce) - send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving - if send { - signalSend(peer.signal.handshakeBegin) - } -} - -/* Queues a keep-alive if no packets are queued for peer - */ -func (peer *Peer) SendKeepAlive() bool { - elem := peer.device.NewOutboundElement() - elem.packet = nil - if len(peer.queue.nonce) == 0 { - select { - case peer.queue.nonce <- elem: - return true - default: - return false - } - } - return true -} - -/* Event: - * Sent non-empty (authenticated) transport message - */ -func (peer *Peer) TimerDataSent() { - timerStop(peer.timer.keepalivePassive) - if !peer.timer.pendingNewHandshake { - peer.timer.pendingNewHandshake = true - peer.timer.newHandshake.Reset(NewHandshakeTime) - } -} - -/* Event: - * Received non-empty (authenticated) transport message - */ -func (peer *Peer) TimerDataReceived() { - if peer.timer.pendingKeepalivePassive { - peer.timer.needAnotherKeepalive = true - return - } - peer.timer.pendingKeepalivePassive = false - peer.timer.keepalivePassive.Reset(KeepaliveTimeout) -} - -/* Event: - * Any (authenticated) packet received - */ -func (peer *Peer) TimerAnyAuthenticatedPacketReceived() { - timerStop(peer.timer.newHandshake) -} - -/* Event: - * Any authenticated packet send / received. - */ -func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() { - interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) - if interval > 0 { - duration := time.Duration(interval) * time.Second - peer.timer.keepalivePersistent.Reset(duration) - } -} - -/* Called after succesfully completing a handshake. - * i.e. after: - * - * - Valid handshake response - * - First transport message under the "next" key - */ -func (peer *Peer) TimerHandshakeComplete() { - atomic.StoreInt64( - &peer.stats.lastHandshakeNano, - time.Now().UnixNano(), - ) - signalSend(peer.signal.handshakeCompleted) - peer.device.log.Info.Println("Negotiated new handshake for", peer.String()) -} - -/* Event: - * An ephemeral key is generated - * - * i.e after: - * - * CreateMessageInitiation - * CreateMessageResponse - * - * Schedules the deletion of all key material - * upon failure to complete a handshake - */ -func (peer *Peer) TimerEphemeralKeyCreated() { - peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3) -} - -func (peer *Peer) RoutineTimerHandler() { - device := peer.device - indices := &device.indices - - logDebug := device.log.Debug - logDebug.Println("Routine, timer handler, started for peer", peer.String()) - - for { - select { - - case <-peer.signal.stop: - return - - // keep-alives - - case <-peer.timer.keepalivePersistent.C: - - interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) - if interval > 0 { - logDebug.Println("Sending keep-alive to", peer.String()) - peer.SendKeepAlive() - } - - case <-peer.timer.keepalivePassive.C: - - logDebug.Println("Sending keep-alive to", peer.String()) - - peer.SendKeepAlive() - - if peer.timer.needAnotherKeepalive { - peer.timer.keepalivePassive.Reset(KeepaliveTimeout) - peer.timer.needAnotherKeepalive = false - } - - // unresponsive session - - case <-peer.timer.newHandshake.C: - - logDebug.Println("Retrying handshake with", peer.String(), "due to lack of reply") - - signalSend(peer.signal.handshakeBegin) - - // clear key material - - case <-peer.timer.zeroAllKeys.C: - - logDebug.Println("Clearing all key material for", peer.String()) - - hs := &peer.handshake - hs.mutex.Lock() - - kp := &peer.keyPairs - kp.mutex.Lock() - - // unmap indecies - - indices.mutex.Lock() - if kp.previous != nil { - delete(indices.table, kp.previous.localIndex) - } - if kp.current != nil { - delete(indices.table, kp.current.localIndex) - } - if kp.next != nil { - delete(indices.table, kp.next.localIndex) - } - delete(indices.table, hs.localIndex) - indices.mutex.Unlock() - - // zero out key pairs (TODO: better than wait for GC) - - kp.current = nil - kp.previous = nil - kp.next = nil - kp.mutex.Unlock() - - // zero out handshake - - hs.localIndex = 0 - hs.localEphemeral = NoisePrivateKey{} - hs.remoteEphemeral = NoisePublicKey{} - hs.chainKey = [blake2s.Size]byte{} - hs.hash = [blake2s.Size]byte{} - hs.mutex.Unlock() - } - } -} - -/* This is the state machine for handshake initiation - * - * Associated with this routine is the signal "handshakeBegin" - * The routine will read from the "handshakeBegin" channel - * at most every RekeyTimeout seconds - */ -func (peer *Peer) RoutineHandshakeInitiator() { - device := peer.device - - logInfo := device.log.Info - logError := device.log.Error - logDebug := device.log.Debug - logDebug.Println("Routine, handshake initator, started for", peer.String()) - - var temp [256]byte - - for { - - // wait for signal - - select { - case <-peer.signal.handshakeBegin: - case <-peer.signal.stop: - return - } - - // set deadline - - BeginHandshakes: - - signalClear(peer.signal.handshakeReset) - deadline := time.NewTimer(RekeyAttemptTime) - - AttemptHandshakes: - - for attempts := uint(1); ; attempts++ { - - // check if deadline reached - - select { - case <-deadline.C: - logInfo.Println("Handshake negotiation timed out for:", peer.String()) - signalSend(peer.signal.flushNonceQueue) - timerStop(peer.timer.keepalivePersistent) - break - case <-peer.signal.stop: - return - default: - } - - signalClear(peer.signal.handshakeCompleted) - - // create initiation message - - msg, err := peer.device.CreateMessageInitiation(peer) - if err != nil { - logError.Println("Failed to create handshake initiation message:", err) - break AttemptHandshakes - } - - jitter := time.Millisecond * time.Duration(rand.Uint32()%334) - - // marshal and send - - writer := bytes.NewBuffer(temp[:0]) - binary.Write(writer, binary.LittleEndian, msg) - packet := writer.Bytes() - peer.mac.AddMacs(packet) - - _, err = peer.SendBuffer(packet) - if err != nil { - logError.Println( - "Failed to send handshake initiation message to", - peer.String(), ":", err, - ) - break - } - - peer.TimerAnyAuthenticatedPacketTraversal() - - // set handshake timeout - - timeout := time.NewTimer(RekeyTimeout + jitter) - logDebug.Println( - "Handshake initiation attempt", - attempts, "sent to", peer.String(), - ) - - // wait for handshake or timeout - - select { - - case <-peer.signal.stop: - return - - case <-peer.signal.handshakeCompleted: - <-timeout.C - break AttemptHandshakes - - case <-peer.signal.handshakeReset: - <-timeout.C - goto BeginHandshakes - - case <-timeout.C: - // TODO: Clear source address for peer - continue - } - } - - // clear signal set in the meantime - - signalClear(peer.signal.handshakeBegin) - } -} +package main + +import ( + "bytes" + "encoding/binary" + "golang.org/x/crypto/blake2s" + "math/rand" + "sync/atomic" + "time" +) + +/* Called when a new authenticated message has been send + * + */ +func (peer *Peer) KeepKeyFreshSending() { + kp := peer.keyPairs.Current() + if kp == nil { + return + } + nonce := atomic.LoadUint64(&kp.sendNonce) + if nonce > RekeyAfterMessages { + signalSend(peer.signal.handshakeBegin) + } + if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime { + signalSend(peer.signal.handshakeBegin) + } +} + +/* Called when a new authenticated message has been recevied + * + */ +func (peer *Peer) KeepKeyFreshReceiving() { + // TODO: Add a guard, clear on handshake complete (clear in TimerHandshakeComplete) + kp := peer.keyPairs.Current() + if kp == nil { + return + } + if !kp.isInitiator { + return + } + nonce := atomic.LoadUint64(&kp.sendNonce) + send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving + if send { + signalSend(peer.signal.handshakeBegin) + } +} + +/* Queues a keep-alive if no packets are queued for peer + */ +func (peer *Peer) SendKeepAlive() bool { + elem := peer.device.NewOutboundElement() + elem.packet = nil + if len(peer.queue.nonce) == 0 { + select { + case peer.queue.nonce <- elem: + return true + default: + return false + } + } + return true +} + +/* Event: + * Sent non-empty (authenticated) transport message + */ +func (peer *Peer) TimerDataSent() { + timerStop(peer.timer.keepalivePassive) + if !peer.timer.pendingNewHandshake { + peer.timer.pendingNewHandshake = true + peer.timer.newHandshake.Reset(NewHandshakeTime) + } +} + +/* Event: + * Received non-empty (authenticated) transport message + */ +func (peer *Peer) TimerDataReceived() { + if peer.timer.pendingKeepalivePassive { + peer.timer.needAnotherKeepalive = true + return + } + peer.timer.pendingKeepalivePassive = false + peer.timer.keepalivePassive.Reset(KeepaliveTimeout) +} + +/* Event: + * Any (authenticated) packet received + */ +func (peer *Peer) TimerAnyAuthenticatedPacketReceived() { + timerStop(peer.timer.newHandshake) +} + +/* Event: + * Any authenticated packet send / received. + */ +func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() { + interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) + if interval > 0 { + duration := time.Duration(interval) * time.Second + peer.timer.keepalivePersistent.Reset(duration) + } +} + +/* Called after succesfully completing a handshake. + * i.e. after: + * + * - Valid handshake response + * - First transport message under the "next" key + */ +func (peer *Peer) TimerHandshakeComplete() { + atomic.StoreInt64( + &peer.stats.lastHandshakeNano, + time.Now().UnixNano(), + ) + signalSend(peer.signal.handshakeCompleted) + peer.device.log.Info.Println("Negotiated new handshake for", peer.String()) +} + +/* Event: + * An ephemeral key is generated + * + * i.e after: + * + * CreateMessageInitiation + * CreateMessageResponse + * + * Schedules the deletion of all key material + * upon failure to complete a handshake + */ +func (peer *Peer) TimerEphemeralKeyCreated() { + peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3) +} + +func (peer *Peer) RoutineTimerHandler() { + device := peer.device + indices := &device.indices + + logDebug := device.log.Debug + logDebug.Println("Routine, timer handler, started for peer", peer.String()) + + for { + select { + + case <-peer.signal.stop: + return + + // keep-alives + + case <-peer.timer.keepalivePersistent.C: + + interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) + if interval > 0 { + logDebug.Println("Sending keep-alive to", peer.String()) + peer.SendKeepAlive() + } + + case <-peer.timer.keepalivePassive.C: + + logDebug.Println("Sending keep-alive to", peer.String()) + + peer.SendKeepAlive() + + if peer.timer.needAnotherKeepalive { + peer.timer.keepalivePassive.Reset(KeepaliveTimeout) + peer.timer.needAnotherKeepalive = false + } + + // unresponsive session + + case <-peer.timer.newHandshake.C: + + logDebug.Println("Retrying handshake with", peer.String(), "due to lack of reply") + + signalSend(peer.signal.handshakeBegin) + + // clear key material + + case <-peer.timer.zeroAllKeys.C: + + logDebug.Println("Clearing all key material for", peer.String()) + + hs := &peer.handshake + hs.mutex.Lock() + + kp := &peer.keyPairs + kp.mutex.Lock() + + // unmap indecies + + indices.mutex.Lock() + if kp.previous != nil { + delete(indices.table, kp.previous.localIndex) + } + if kp.current != nil { + delete(indices.table, kp.current.localIndex) + } + if kp.next != nil { + delete(indices.table, kp.next.localIndex) + } + delete(indices.table, hs.localIndex) + indices.mutex.Unlock() + + // zero out key pairs (TODO: better than wait for GC) + + kp.current = nil + kp.previous = nil + kp.next = nil + kp.mutex.Unlock() + + // zero out handshake + + hs.localIndex = 0 + hs.localEphemeral = NoisePrivateKey{} + hs.remoteEphemeral = NoisePublicKey{} + hs.chainKey = [blake2s.Size]byte{} + hs.hash = [blake2s.Size]byte{} + hs.mutex.Unlock() + } + } +} + +/* This is the state machine for handshake initiation + * + * Associated with this routine is the signal "handshakeBegin" + * The routine will read from the "handshakeBegin" channel + * at most every RekeyTimeout seconds + */ +func (peer *Peer) RoutineHandshakeInitiator() { + device := peer.device + + logInfo := device.log.Info + logError := device.log.Error + logDebug := device.log.Debug + logDebug.Println("Routine, handshake initator, started for", peer.String()) + + var temp [256]byte + + for { + + // wait for signal + + select { + case <-peer.signal.handshakeBegin: + case <-peer.signal.stop: + return + } + + // set deadline + + BeginHandshakes: + + signalClear(peer.signal.handshakeReset) + deadline := time.NewTimer(RekeyAttemptTime) + + AttemptHandshakes: + + for attempts := uint(1); ; attempts++ { + + // check if deadline reached + + select { + case <-deadline.C: + logInfo.Println("Handshake negotiation timed out for:", peer.String()) + signalSend(peer.signal.flushNonceQueue) + timerStop(peer.timer.keepalivePersistent) + break + case <-peer.signal.stop: + return + default: + } + + signalClear(peer.signal.handshakeCompleted) + + // create initiation message + + msg, err := peer.device.CreateMessageInitiation(peer) + if err != nil { + logError.Println("Failed to create handshake initiation message:", err) + break AttemptHandshakes + } + + jitter := time.Millisecond * time.Duration(rand.Uint32()%334) + + // marshal and send + + writer := bytes.NewBuffer(temp[:0]) + binary.Write(writer, binary.LittleEndian, msg) + packet := writer.Bytes() + peer.mac.AddMacs(packet) + + _, err = peer.SendBuffer(packet) + if err != nil { + logError.Println( + "Failed to send handshake initiation message to", + peer.String(), ":", err, + ) + continue + } + + peer.TimerAnyAuthenticatedPacketTraversal() + + // set handshake timeout + + timeout := time.NewTimer(RekeyTimeout + jitter) + logDebug.Println( + "Handshake initiation attempt", + attempts, "sent to", peer.String(), + ) + + // wait for handshake or timeout + + select { + + case <-peer.signal.stop: + return + + case <-peer.signal.handshakeCompleted: + <-timeout.C + break AttemptHandshakes + + case <-peer.signal.handshakeReset: + <-timeout.C + goto BeginHandshakes + + case <-timeout.C: + // TODO: Clear source address for peer + continue + } + } + + // clear signal set in the meantime + + signalClear(peer.signal.handshakeBegin) + } +} diff --git a/src/tun_windows.go b/src/tun_windows.go new file mode 100644 index 0000000..0711032 --- /dev/null +++ b/src/tun_windows.go @@ -0,0 +1,475 @@ +package main + +import ( + "encoding/binary" + "errors" + "fmt" + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" + "net" + "sync" + "syscall" + "time" + "unsafe" +) + +/* Relies on the OpenVPN TAP-Windows driver (NDIS 6 version) + * + * https://github.com/OpenVPN/tap-windows + */ + +type NativeTUN struct { + fd windows.Handle + rl sync.Mutex + wl sync.Mutex + ro *windows.Overlapped + wo *windows.Overlapped + events chan TUNEvent + name string +} + +const ( + METHOD_BUFFERED = 0 + ComponentID = "tap0901" // tap0801 +) + +func ctl_code(device_type, function, method, access uint32) uint32 { + return (device_type << 16) | (access << 14) | (function << 2) | method +} + +func TAP_CONTROL_CODE(request, method uint32) uint32 { + return ctl_code(file_device_unknown, request, method, 0) +} + +var ( + errIfceNameNotFound = errors.New("Failed to find the name of interface") + + TAP_IOCTL_GET_MAC = TAP_CONTROL_CODE(1, METHOD_BUFFERED) + TAP_IOCTL_GET_VERSION = TAP_CONTROL_CODE(2, METHOD_BUFFERED) + TAP_IOCTL_GET_MTU = TAP_CONTROL_CODE(3, METHOD_BUFFERED) + TAP_IOCTL_GET_INFO = TAP_CONTROL_CODE(4, METHOD_BUFFERED) + TAP_IOCTL_CONFIG_POINT_TO_POINT = TAP_CONTROL_CODE(5, METHOD_BUFFERED) + TAP_IOCTL_SET_MEDIA_STATUS = TAP_CONTROL_CODE(6, METHOD_BUFFERED) + TAP_IOCTL_CONFIG_DHCP_MASQ = TAP_CONTROL_CODE(7, METHOD_BUFFERED) + TAP_IOCTL_GET_LOG_LINE = TAP_CONTROL_CODE(8, METHOD_BUFFERED) + TAP_IOCTL_CONFIG_DHCP_SET_OPT = TAP_CONTROL_CODE(9, METHOD_BUFFERED) + TAP_IOCTL_CONFIG_TUN = TAP_CONTROL_CODE(10, METHOD_BUFFERED) + + file_device_unknown = uint32(0x00000022) + nCreateEvent, + nResetEvent, + nGetOverlappedResult uintptr +) + +func init() { + k32, err := windows.LoadLibrary("kernel32.dll") + if err != nil { + panic("LoadLibrary " + err.Error()) + } + defer windows.FreeLibrary(k32) + nCreateEvent = getProcAddr(k32, "CreateEventW") + nResetEvent = getProcAddr(k32, "ResetEvent") + nGetOverlappedResult = getProcAddr(k32, "GetOverlappedResult") +} + +/* implementation of the read/write/closer interface */ + +func getProcAddr(lib windows.Handle, name string) uintptr { + addr, err := windows.GetProcAddress(lib, name) + if err != nil { + panic(name + " " + err.Error()) + } + return addr +} + +func resetEvent(h windows.Handle) error { + r, _, err := syscall.Syscall(nResetEvent, 1, uintptr(h), 0, 0) + if r == 0 { + return err + } + return nil +} + +func getOverlappedResult(h windows.Handle, overlapped *windows.Overlapped) (int, error) { + var n int + r, _, err := syscall.Syscall6( + nGetOverlappedResult, + 4, + uintptr(h), + uintptr(unsafe.Pointer(overlapped)), + uintptr(unsafe.Pointer(&n)), 1, 0, 0) + + if r == 0 { + return n, err + } + return n, nil +} + +func newOverlapped() (*windows.Overlapped, error) { + var overlapped windows.Overlapped + r, _, err := syscall.Syscall6(nCreateEvent, 4, 0, 1, 0, 0, 0, 0) + if r == 0 { + return nil, err + } + overlapped.HEvent = windows.Handle(r) + return &overlapped, nil +} + +func (f *NativeTUN) Events() chan TUNEvent { + return f.events +} + +func (f *NativeTUN) Close() error { + return windows.Close(f.fd) +} + +func (f *NativeTUN) Write(b []byte) (int, error) { + f.wl.Lock() + defer f.wl.Unlock() + + if err := resetEvent(f.wo.HEvent); err != nil { + return 0, err + } + var n uint32 + err := windows.WriteFile(f.fd, b, &n, f.wo) + if err != nil && err != windows.ERROR_IO_PENDING { + return int(n), err + } + return getOverlappedResult(f.fd, f.wo) +} + +func (f *NativeTUN) Read(b []byte) (int, error) { + f.rl.Lock() + defer f.rl.Unlock() + + if err := resetEvent(f.ro.HEvent); err != nil { + return 0, err + } + var done uint32 + err := windows.ReadFile(f.fd, b, &done, f.ro) + if err != nil && err != windows.ERROR_IO_PENDING { + return int(done), err + } + return getOverlappedResult(f.fd, f.ro) +} + +func getdeviceid( + targetComponentId string, + targetDeviceName string, +) (deviceid string, err error) { + + getName := func(instanceId string) (string, error) { + path := fmt.Sprintf( + `SYSTEM\CurrentControlSet\Control\Network\{4D36E972-E325-11CE-BFC1-08002BE10318}\%s\Connection`, + instanceId, + ) + + key, err := registry.OpenKey( + registry.LOCAL_MACHINE, + path, + registry.READ, + ) + + if err != nil { + return "", err + } + defer key.Close() + + val, _, err := key.GetStringValue("Name") + key.Close() + return val, err + } + + getInstanceId := func(keyName string) (string, string, error) { + path := fmt.Sprintf( + `SYSTEM\CurrentControlSet\Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}\%s`, + keyName, + ) + + key, err := registry.OpenKey( + registry.LOCAL_MACHINE, + path, + registry.READ, + ) + + if err != nil { + return "", "", err + } + defer key.Close() + + componentId, _, err := key.GetStringValue("ComponentId") + if err != nil { + return "", "", err + } + + instanceId, _, err := key.GetStringValue("NetCfgInstanceId") + + return componentId, instanceId, err + } + + // find list of all network devices + + k, err := registry.OpenKey( + registry.LOCAL_MACHINE, + `SYSTEM\CurrentControlSet\Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}`, + registry.READ, + ) + + if err != nil { + return "", fmt.Errorf("Failed to open the adapter registry, TAP driver may be not installed, %v", err) + } + + defer k.Close() + + keys, err := k.ReadSubKeyNames(-1) + + if err != nil { + return "", err + } + + // look for matching component id and name + + var componentFound bool + + for _, v := range keys { + + componentId, instanceId, err := getInstanceId(v) + if err != nil || componentId != targetComponentId { + continue + } + + componentFound = true + + deviceName, err := getName(instanceId) + if err != nil || deviceName != targetDeviceName { + continue + } + + return instanceId, nil + } + + // provide a descriptive error message + + if componentFound { + return "", fmt.Errorf("Unable to find tun/tap device with name = %s", targetDeviceName) + } + + return "", fmt.Errorf( + "Unable to find device in registry with ComponentId = %s, is tap-windows installed?", + targetComponentId, + ) +} + +// setStatus is used to bring up or bring down the interface +func setStatus(fd windows.Handle, status bool) error { + var code [4]byte + if status { + binary.LittleEndian.PutUint32(code[:], 1) + } + + var bytesReturned uint32 + rdbbuf := make([]byte, windows.MAXIMUM_REPARSE_DATA_BUFFER_SIZE) + return windows.DeviceIoControl( + fd, + TAP_IOCTL_SET_MEDIA_STATUS, + &code[0], + uint32(4), + &rdbbuf[0], + uint32(len(rdbbuf)), + &bytesReturned, + nil, + ) +} + +/* When operating in TUN mode we must assign an ip address & subnet to the device. + * + */ +func setTUN(fd windows.Handle, network string) error { + var bytesReturned uint32 + rdbbuf := make([]byte, windows.MAXIMUM_REPARSE_DATA_BUFFER_SIZE) + localIP, remoteNet, err := net.ParseCIDR(network) + + if err != nil { + return fmt.Errorf("Failed to parse network CIDR in config, %v", err) + } + + if localIP.To4() == nil { + return fmt.Errorf("Provided network(%s) is not a valid IPv4 address", network) + } + + var param [12]byte + + copy(param[0:4], localIP.To4()) + copy(param[4:8], remoteNet.IP.To4()) + copy(param[8:12], remoteNet.Mask) + + return windows.DeviceIoControl( + fd, + TAP_IOCTL_CONFIG_TUN, + ¶m[0], + uint32(12), + &rdbbuf[0], + uint32(len(rdbbuf)), + &bytesReturned, + nil, + ) +} + +func (tun *NativeTUN) MTU() (int, error) { + var mtu [4]byte + var bytesReturned uint32 + err := windows.DeviceIoControl( + tun.fd, + TAP_IOCTL_GET_MTU, + &mtu[0], + uint32(len(mtu)), + &mtu[0], + uint32(len(mtu)), + &bytesReturned, + nil, + ) + val := binary.LittleEndian.Uint32(mtu[:]) + return int(val), err +} + +func (tun *NativeTUN) Name() string { + return tun.name +} + +func CreateTUN(name string) (TUNDevice, error) { + + // find the device in registry. + + deviceid, err := getdeviceid(ComponentID, name) + if err != nil { + return nil, err + } + path := "\\\\.\\Global\\" + deviceid + ".tap" + pathp, err := windows.UTF16PtrFromString(path) + if err != nil { + return nil, err + } + + // create TUN device + + handle, err := windows.CreateFile( + pathp, + windows.GENERIC_READ|windows.GENERIC_WRITE, + 0, + nil, + windows.OPEN_EXISTING, + windows.FILE_ATTRIBUTE_SYSTEM|windows.FILE_FLAG_OVERLAPPED, + 0, + ) + + if err != nil { + return nil, err + } + + ro, err := newOverlapped() + if err != nil { + windows.Close(handle) + return nil, err + } + + wo, err := newOverlapped() + if err != nil { + windows.Close(handle) + return nil, err + } + + tun := &NativeTUN{ + fd: handle, + name: name, + ro: ro, + wo: wo, + events: make(chan TUNEvent, 5), + } + + // find addresses of interface + // TODO: fix this hack, the question is how + + inter, err := net.InterfaceByName(name) + if err != nil { + windows.Close(handle) + return nil, err + } + + addrs, err := inter.Addrs() + if err != nil { + windows.Close(handle) + return nil, err + } + + var ip net.IP + for _, addr := range addrs { + ip = func() net.IP { + switch v := addr.(type) { + case *net.IPNet: + return v.IP.To4() + case *net.IPAddr: + return v.IP.To4() + } + return nil + }() + if ip != nil { + break + } + } + + if ip == nil { + windows.Close(handle) + return nil, errors.New("No IPv4 address found for interface") + } + + // bring up device. + + if err := setStatus(handle, true); err != nil { + windows.Close(handle) + return nil, err + } + + // set tun mode + + mask := ip.String() + "/0" + if err := setTUN(handle, mask); err != nil { + windows.Close(handle) + return nil, err + } + + // start listener + + go func(native *NativeTUN, ifname string) { + // TODO: Fix this very niave implementation + var ( + statusUp bool + statusMTU int + ) + + for ; ; time.Sleep(time.Second) { + intr, err := net.InterfaceByName(name) + if err != nil { + // TODO: handle + return + } + + // Up / Down event + up := (intr.Flags & net.FlagUp) != 0 + if up != statusUp && up { + native.events <- TUNEventUp + } + if up != statusUp && !up { + native.events <- TUNEventDown + } + statusUp = up + + // MTU changes + if intr.MTU != statusMTU { + native.events <- TUNEventMTUUpdate + } + statusMTU = intr.MTU + } + }(tun, name) + + return tun, nil +} diff --git a/src/uapi_windows.go b/src/uapi_windows.go new file mode 100644 index 0000000..d56e965 --- /dev/null +++ b/src/uapi_windows.go @@ -0,0 +1,44 @@ +package main + +/* UAPI on windows uses a bidirectional named pipe + */ + +import ( + "fmt" + "github.com/Microsoft/go-winio" + "golang.org/x/sys/windows" + "net" +) + +const ( + ipcErrorIO = -int64(windows.ERROR_BROKEN_PIPE) + ipcErrorNotDefined = -int64(windows.ERROR_SERVICE_SPECIFIC_ERROR) + ipcErrorProtocol = -int64(windows.ERROR_SERVICE_SPECIFIC_ERROR) + ipcErrorInvalid = -int64(windows.ERROR_SERVICE_SPECIFIC_ERROR) +) + +const PipeNameFmt = "\\\\.\\pipe\\wireguard-ipc-%s" + +type UAPIListener struct { + listener net.Listener +} + +func (uapi *UAPIListener) Accept() (net.Conn, error) { + return nil, nil +} + +func (uapi *UAPIListener) Close() error { + return uapi.listener.Close() +} + +func (uapi *UAPIListener) Addr() net.Addr { + return nil +} + +func NewUAPIListener(name string) (net.Listener, error) { + path := fmt.Sprintf(PipeNameFmt, name) + return winio.ListenPipe(path, &winio.PipeConfig{ + InputBufferSize: 2048, + OutputBufferSize: 2048, + }) +}