Added code from windows branch

This commit is contained in:
Mathias Hall-Andersen 2017-08-27 15:41:00 +02:00
parent eafa3df606
commit 6f5ef153c3
6 changed files with 896 additions and 337 deletions

6
src/build.cmd Executable file
View file

@ -0,0 +1,6 @@
@echo off
REM builds wireguard for windows
go get
go build -o wireguard-go.exe

View file

@ -6,6 +6,6 @@ import (
"net" "net"
) )
func setFwmark(conn *net.UDPConn, value int) error { func setMark(conn *net.UDPConn, value int) error {
return nil return nil
} }

34
src/daemon_windows.go Normal file
View file

@ -0,0 +1,34 @@
package main
import (
"os"
)
/* Daemonizes the process on windows
*
* This is done by spawning and releasing a copy with the --foreground flag
*/
func Daemonize() error {
argv := []string{os.Args[0], "--foreground"}
argv = append(argv, os.Args[1:]...)
attr := &os.ProcAttr{
Dir: ".",
Env: os.Environ(),
Files: []*os.File{
os.Stdin,
nil,
nil,
},
}
process, err := os.StartProcess(
argv[0],
argv,
attr,
)
if err != nil {
return err
}
process.Release()
return nil
}

View file

@ -1,336 +1,336 @@
package main package main
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
"math/rand" "math/rand"
"sync/atomic" "sync/atomic"
"time" "time"
) )
/* Called when a new authenticated message has been send /* Called when a new authenticated message has been send
* *
*/ */
func (peer *Peer) KeepKeyFreshSending() { func (peer *Peer) KeepKeyFreshSending() {
kp := peer.keyPairs.Current() kp := peer.keyPairs.Current()
if kp == nil { if kp == nil {
return return
} }
nonce := atomic.LoadUint64(&kp.sendNonce) nonce := atomic.LoadUint64(&kp.sendNonce)
if nonce > RekeyAfterMessages { if nonce > RekeyAfterMessages {
signalSend(peer.signal.handshakeBegin) signalSend(peer.signal.handshakeBegin)
} }
if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime { if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime {
signalSend(peer.signal.handshakeBegin) signalSend(peer.signal.handshakeBegin)
} }
} }
/* Called when a new authenticated message has been recevied /* Called when a new authenticated message has been recevied
* *
*/ */
func (peer *Peer) KeepKeyFreshReceiving() { func (peer *Peer) KeepKeyFreshReceiving() {
// TODO: Add a guard, clear on handshake complete (clear in TimerHandshakeComplete) // TODO: Add a guard, clear on handshake complete (clear in TimerHandshakeComplete)
kp := peer.keyPairs.Current() kp := peer.keyPairs.Current()
if kp == nil { if kp == nil {
return return
} }
if !kp.isInitiator { if !kp.isInitiator {
return return
} }
nonce := atomic.LoadUint64(&kp.sendNonce) nonce := atomic.LoadUint64(&kp.sendNonce)
send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
if send { if send {
signalSend(peer.signal.handshakeBegin) signalSend(peer.signal.handshakeBegin)
} }
} }
/* Queues a keep-alive if no packets are queued for peer /* Queues a keep-alive if no packets are queued for peer
*/ */
func (peer *Peer) SendKeepAlive() bool { func (peer *Peer) SendKeepAlive() bool {
elem := peer.device.NewOutboundElement() elem := peer.device.NewOutboundElement()
elem.packet = nil elem.packet = nil
if len(peer.queue.nonce) == 0 { if len(peer.queue.nonce) == 0 {
select { select {
case peer.queue.nonce <- elem: case peer.queue.nonce <- elem:
return true return true
default: default:
return false return false
} }
} }
return true return true
} }
/* Event: /* Event:
* Sent non-empty (authenticated) transport message * Sent non-empty (authenticated) transport message
*/ */
func (peer *Peer) TimerDataSent() { func (peer *Peer) TimerDataSent() {
timerStop(peer.timer.keepalivePassive) timerStop(peer.timer.keepalivePassive)
if !peer.timer.pendingNewHandshake { if !peer.timer.pendingNewHandshake {
peer.timer.pendingNewHandshake = true peer.timer.pendingNewHandshake = true
peer.timer.newHandshake.Reset(NewHandshakeTime) peer.timer.newHandshake.Reset(NewHandshakeTime)
} }
} }
/* Event: /* Event:
* Received non-empty (authenticated) transport message * Received non-empty (authenticated) transport message
*/ */
func (peer *Peer) TimerDataReceived() { func (peer *Peer) TimerDataReceived() {
if peer.timer.pendingKeepalivePassive { if peer.timer.pendingKeepalivePassive {
peer.timer.needAnotherKeepalive = true peer.timer.needAnotherKeepalive = true
return return
} }
peer.timer.pendingKeepalivePassive = false peer.timer.pendingKeepalivePassive = false
peer.timer.keepalivePassive.Reset(KeepaliveTimeout) peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
} }
/* Event: /* Event:
* Any (authenticated) packet received * Any (authenticated) packet received
*/ */
func (peer *Peer) TimerAnyAuthenticatedPacketReceived() { func (peer *Peer) TimerAnyAuthenticatedPacketReceived() {
timerStop(peer.timer.newHandshake) timerStop(peer.timer.newHandshake)
} }
/* Event: /* Event:
* Any authenticated packet send / received. * Any authenticated packet send / received.
*/ */
func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() { func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {
interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
if interval > 0 { if interval > 0 {
duration := time.Duration(interval) * time.Second duration := time.Duration(interval) * time.Second
peer.timer.keepalivePersistent.Reset(duration) peer.timer.keepalivePersistent.Reset(duration)
} }
} }
/* Called after succesfully completing a handshake. /* Called after succesfully completing a handshake.
* i.e. after: * i.e. after:
* *
* - Valid handshake response * - Valid handshake response
* - First transport message under the "next" key * - First transport message under the "next" key
*/ */
func (peer *Peer) TimerHandshakeComplete() { func (peer *Peer) TimerHandshakeComplete() {
atomic.StoreInt64( atomic.StoreInt64(
&peer.stats.lastHandshakeNano, &peer.stats.lastHandshakeNano,
time.Now().UnixNano(), time.Now().UnixNano(),
) )
signalSend(peer.signal.handshakeCompleted) signalSend(peer.signal.handshakeCompleted)
peer.device.log.Info.Println("Negotiated new handshake for", peer.String()) peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
} }
/* Event: /* Event:
* An ephemeral key is generated * An ephemeral key is generated
* *
* i.e after: * i.e after:
* *
* CreateMessageInitiation * CreateMessageInitiation
* CreateMessageResponse * CreateMessageResponse
* *
* Schedules the deletion of all key material * Schedules the deletion of all key material
* upon failure to complete a handshake * upon failure to complete a handshake
*/ */
func (peer *Peer) TimerEphemeralKeyCreated() { func (peer *Peer) TimerEphemeralKeyCreated() {
peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3) peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
} }
func (peer *Peer) RoutineTimerHandler() { func (peer *Peer) RoutineTimerHandler() {
device := peer.device device := peer.device
indices := &device.indices indices := &device.indices
logDebug := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, timer handler, started for peer", peer.String()) logDebug.Println("Routine, timer handler, started for peer", peer.String())
for { for {
select { select {
case <-peer.signal.stop: case <-peer.signal.stop:
return return
// keep-alives // keep-alives
case <-peer.timer.keepalivePersistent.C: case <-peer.timer.keepalivePersistent.C:
interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
if interval > 0 { if interval > 0 {
logDebug.Println("Sending keep-alive to", peer.String()) logDebug.Println("Sending keep-alive to", peer.String())
peer.SendKeepAlive() peer.SendKeepAlive()
} }
case <-peer.timer.keepalivePassive.C: case <-peer.timer.keepalivePassive.C:
logDebug.Println("Sending keep-alive to", peer.String()) logDebug.Println("Sending keep-alive to", peer.String())
peer.SendKeepAlive() peer.SendKeepAlive()
if peer.timer.needAnotherKeepalive { if peer.timer.needAnotherKeepalive {
peer.timer.keepalivePassive.Reset(KeepaliveTimeout) peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
peer.timer.needAnotherKeepalive = false peer.timer.needAnotherKeepalive = false
} }
// unresponsive session // unresponsive session
case <-peer.timer.newHandshake.C: case <-peer.timer.newHandshake.C:
logDebug.Println("Retrying handshake with", peer.String(), "due to lack of reply") logDebug.Println("Retrying handshake with", peer.String(), "due to lack of reply")
signalSend(peer.signal.handshakeBegin) signalSend(peer.signal.handshakeBegin)
// clear key material // clear key material
case <-peer.timer.zeroAllKeys.C: case <-peer.timer.zeroAllKeys.C:
logDebug.Println("Clearing all key material for", peer.String()) logDebug.Println("Clearing all key material for", peer.String())
hs := &peer.handshake hs := &peer.handshake
hs.mutex.Lock() hs.mutex.Lock()
kp := &peer.keyPairs kp := &peer.keyPairs
kp.mutex.Lock() kp.mutex.Lock()
// unmap indecies // unmap indecies
indices.mutex.Lock() indices.mutex.Lock()
if kp.previous != nil { if kp.previous != nil {
delete(indices.table, kp.previous.localIndex) delete(indices.table, kp.previous.localIndex)
} }
if kp.current != nil { if kp.current != nil {
delete(indices.table, kp.current.localIndex) delete(indices.table, kp.current.localIndex)
} }
if kp.next != nil { if kp.next != nil {
delete(indices.table, kp.next.localIndex) delete(indices.table, kp.next.localIndex)
} }
delete(indices.table, hs.localIndex) delete(indices.table, hs.localIndex)
indices.mutex.Unlock() indices.mutex.Unlock()
// zero out key pairs (TODO: better than wait for GC) // zero out key pairs (TODO: better than wait for GC)
kp.current = nil kp.current = nil
kp.previous = nil kp.previous = nil
kp.next = nil kp.next = nil
kp.mutex.Unlock() kp.mutex.Unlock()
// zero out handshake // zero out handshake
hs.localIndex = 0 hs.localIndex = 0
hs.localEphemeral = NoisePrivateKey{} hs.localEphemeral = NoisePrivateKey{}
hs.remoteEphemeral = NoisePublicKey{} hs.remoteEphemeral = NoisePublicKey{}
hs.chainKey = [blake2s.Size]byte{} hs.chainKey = [blake2s.Size]byte{}
hs.hash = [blake2s.Size]byte{} hs.hash = [blake2s.Size]byte{}
hs.mutex.Unlock() hs.mutex.Unlock()
} }
} }
} }
/* This is the state machine for handshake initiation /* This is the state machine for handshake initiation
* *
* Associated with this routine is the signal "handshakeBegin" * Associated with this routine is the signal "handshakeBegin"
* The routine will read from the "handshakeBegin" channel * The routine will read from the "handshakeBegin" channel
* at most every RekeyTimeout seconds * at most every RekeyTimeout seconds
*/ */
func (peer *Peer) RoutineHandshakeInitiator() { func (peer *Peer) RoutineHandshakeInitiator() {
device := peer.device device := peer.device
logInfo := device.log.Info logInfo := device.log.Info
logError := device.log.Error logError := device.log.Error
logDebug := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, handshake initator, started for", peer.String()) logDebug.Println("Routine, handshake initator, started for", peer.String())
var temp [256]byte var temp [256]byte
for { for {
// wait for signal // wait for signal
select { select {
case <-peer.signal.handshakeBegin: case <-peer.signal.handshakeBegin:
case <-peer.signal.stop: case <-peer.signal.stop:
return return
} }
// set deadline // set deadline
BeginHandshakes: BeginHandshakes:
signalClear(peer.signal.handshakeReset) signalClear(peer.signal.handshakeReset)
deadline := time.NewTimer(RekeyAttemptTime) deadline := time.NewTimer(RekeyAttemptTime)
AttemptHandshakes: AttemptHandshakes:
for attempts := uint(1); ; attempts++ { for attempts := uint(1); ; attempts++ {
// check if deadline reached // check if deadline reached
select { select {
case <-deadline.C: case <-deadline.C:
logInfo.Println("Handshake negotiation timed out for:", peer.String()) logInfo.Println("Handshake negotiation timed out for:", peer.String())
signalSend(peer.signal.flushNonceQueue) signalSend(peer.signal.flushNonceQueue)
timerStop(peer.timer.keepalivePersistent) timerStop(peer.timer.keepalivePersistent)
break break
case <-peer.signal.stop: case <-peer.signal.stop:
return return
default: default:
} }
signalClear(peer.signal.handshakeCompleted) signalClear(peer.signal.handshakeCompleted)
// create initiation message // create initiation message
msg, err := peer.device.CreateMessageInitiation(peer) msg, err := peer.device.CreateMessageInitiation(peer)
if err != nil { if err != nil {
logError.Println("Failed to create handshake initiation message:", err) logError.Println("Failed to create handshake initiation message:", err)
break AttemptHandshakes break AttemptHandshakes
} }
jitter := time.Millisecond * time.Duration(rand.Uint32()%334) jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
// marshal and send // marshal and send
writer := bytes.NewBuffer(temp[:0]) writer := bytes.NewBuffer(temp[:0])
binary.Write(writer, binary.LittleEndian, msg) binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes() packet := writer.Bytes()
peer.mac.AddMacs(packet) peer.mac.AddMacs(packet)
_, err = peer.SendBuffer(packet) _, err = peer.SendBuffer(packet)
if err != nil { if err != nil {
logError.Println( logError.Println(
"Failed to send handshake initiation message to", "Failed to send handshake initiation message to",
peer.String(), ":", err, peer.String(), ":", err,
) )
break continue
} }
peer.TimerAnyAuthenticatedPacketTraversal() peer.TimerAnyAuthenticatedPacketTraversal()
// set handshake timeout // set handshake timeout
timeout := time.NewTimer(RekeyTimeout + jitter) timeout := time.NewTimer(RekeyTimeout + jitter)
logDebug.Println( logDebug.Println(
"Handshake initiation attempt", "Handshake initiation attempt",
attempts, "sent to", peer.String(), attempts, "sent to", peer.String(),
) )
// wait for handshake or timeout // wait for handshake or timeout
select { select {
case <-peer.signal.stop: case <-peer.signal.stop:
return return
case <-peer.signal.handshakeCompleted: case <-peer.signal.handshakeCompleted:
<-timeout.C <-timeout.C
break AttemptHandshakes break AttemptHandshakes
case <-peer.signal.handshakeReset: case <-peer.signal.handshakeReset:
<-timeout.C <-timeout.C
goto BeginHandshakes goto BeginHandshakes
case <-timeout.C: case <-timeout.C:
// TODO: Clear source address for peer // TODO: Clear source address for peer
continue continue
} }
} }
// clear signal set in the meantime // clear signal set in the meantime
signalClear(peer.signal.handshakeBegin) signalClear(peer.signal.handshakeBegin)
} }
} }

475
src/tun_windows.go Normal file
View file

@ -0,0 +1,475 @@
package main
import (
"encoding/binary"
"errors"
"fmt"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
"net"
"sync"
"syscall"
"time"
"unsafe"
)
/* Relies on the OpenVPN TAP-Windows driver (NDIS 6 version)
*
* https://github.com/OpenVPN/tap-windows
*/
type NativeTUN struct {
fd windows.Handle
rl sync.Mutex
wl sync.Mutex
ro *windows.Overlapped
wo *windows.Overlapped
events chan TUNEvent
name string
}
const (
METHOD_BUFFERED = 0
ComponentID = "tap0901" // tap0801
)
func ctl_code(device_type, function, method, access uint32) uint32 {
return (device_type << 16) | (access << 14) | (function << 2) | method
}
func TAP_CONTROL_CODE(request, method uint32) uint32 {
return ctl_code(file_device_unknown, request, method, 0)
}
var (
errIfceNameNotFound = errors.New("Failed to find the name of interface")
TAP_IOCTL_GET_MAC = TAP_CONTROL_CODE(1, METHOD_BUFFERED)
TAP_IOCTL_GET_VERSION = TAP_CONTROL_CODE(2, METHOD_BUFFERED)
TAP_IOCTL_GET_MTU = TAP_CONTROL_CODE(3, METHOD_BUFFERED)
TAP_IOCTL_GET_INFO = TAP_CONTROL_CODE(4, METHOD_BUFFERED)
TAP_IOCTL_CONFIG_POINT_TO_POINT = TAP_CONTROL_CODE(5, METHOD_BUFFERED)
TAP_IOCTL_SET_MEDIA_STATUS = TAP_CONTROL_CODE(6, METHOD_BUFFERED)
TAP_IOCTL_CONFIG_DHCP_MASQ = TAP_CONTROL_CODE(7, METHOD_BUFFERED)
TAP_IOCTL_GET_LOG_LINE = TAP_CONTROL_CODE(8, METHOD_BUFFERED)
TAP_IOCTL_CONFIG_DHCP_SET_OPT = TAP_CONTROL_CODE(9, METHOD_BUFFERED)
TAP_IOCTL_CONFIG_TUN = TAP_CONTROL_CODE(10, METHOD_BUFFERED)
file_device_unknown = uint32(0x00000022)
nCreateEvent,
nResetEvent,
nGetOverlappedResult uintptr
)
func init() {
k32, err := windows.LoadLibrary("kernel32.dll")
if err != nil {
panic("LoadLibrary " + err.Error())
}
defer windows.FreeLibrary(k32)
nCreateEvent = getProcAddr(k32, "CreateEventW")
nResetEvent = getProcAddr(k32, "ResetEvent")
nGetOverlappedResult = getProcAddr(k32, "GetOverlappedResult")
}
/* implementation of the read/write/closer interface */
func getProcAddr(lib windows.Handle, name string) uintptr {
addr, err := windows.GetProcAddress(lib, name)
if err != nil {
panic(name + " " + err.Error())
}
return addr
}
func resetEvent(h windows.Handle) error {
r, _, err := syscall.Syscall(nResetEvent, 1, uintptr(h), 0, 0)
if r == 0 {
return err
}
return nil
}
func getOverlappedResult(h windows.Handle, overlapped *windows.Overlapped) (int, error) {
var n int
r, _, err := syscall.Syscall6(
nGetOverlappedResult,
4,
uintptr(h),
uintptr(unsafe.Pointer(overlapped)),
uintptr(unsafe.Pointer(&n)), 1, 0, 0)
if r == 0 {
return n, err
}
return n, nil
}
func newOverlapped() (*windows.Overlapped, error) {
var overlapped windows.Overlapped
r, _, err := syscall.Syscall6(nCreateEvent, 4, 0, 1, 0, 0, 0, 0)
if r == 0 {
return nil, err
}
overlapped.HEvent = windows.Handle(r)
return &overlapped, nil
}
func (f *NativeTUN) Events() chan TUNEvent {
return f.events
}
func (f *NativeTUN) Close() error {
return windows.Close(f.fd)
}
func (f *NativeTUN) Write(b []byte) (int, error) {
f.wl.Lock()
defer f.wl.Unlock()
if err := resetEvent(f.wo.HEvent); err != nil {
return 0, err
}
var n uint32
err := windows.WriteFile(f.fd, b, &n, f.wo)
if err != nil && err != windows.ERROR_IO_PENDING {
return int(n), err
}
return getOverlappedResult(f.fd, f.wo)
}
func (f *NativeTUN) Read(b []byte) (int, error) {
f.rl.Lock()
defer f.rl.Unlock()
if err := resetEvent(f.ro.HEvent); err != nil {
return 0, err
}
var done uint32
err := windows.ReadFile(f.fd, b, &done, f.ro)
if err != nil && err != windows.ERROR_IO_PENDING {
return int(done), err
}
return getOverlappedResult(f.fd, f.ro)
}
func getdeviceid(
targetComponentId string,
targetDeviceName string,
) (deviceid string, err error) {
getName := func(instanceId string) (string, error) {
path := fmt.Sprintf(
`SYSTEM\CurrentControlSet\Control\Network\{4D36E972-E325-11CE-BFC1-08002BE10318}\%s\Connection`,
instanceId,
)
key, err := registry.OpenKey(
registry.LOCAL_MACHINE,
path,
registry.READ,
)
if err != nil {
return "", err
}
defer key.Close()
val, _, err := key.GetStringValue("Name")
key.Close()
return val, err
}
getInstanceId := func(keyName string) (string, string, error) {
path := fmt.Sprintf(
`SYSTEM\CurrentControlSet\Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}\%s`,
keyName,
)
key, err := registry.OpenKey(
registry.LOCAL_MACHINE,
path,
registry.READ,
)
if err != nil {
return "", "", err
}
defer key.Close()
componentId, _, err := key.GetStringValue("ComponentId")
if err != nil {
return "", "", err
}
instanceId, _, err := key.GetStringValue("NetCfgInstanceId")
return componentId, instanceId, err
}
// find list of all network devices
k, err := registry.OpenKey(
registry.LOCAL_MACHINE,
`SYSTEM\CurrentControlSet\Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}`,
registry.READ,
)
if err != nil {
return "", fmt.Errorf("Failed to open the adapter registry, TAP driver may be not installed, %v", err)
}
defer k.Close()
keys, err := k.ReadSubKeyNames(-1)
if err != nil {
return "", err
}
// look for matching component id and name
var componentFound bool
for _, v := range keys {
componentId, instanceId, err := getInstanceId(v)
if err != nil || componentId != targetComponentId {
continue
}
componentFound = true
deviceName, err := getName(instanceId)
if err != nil || deviceName != targetDeviceName {
continue
}
return instanceId, nil
}
// provide a descriptive error message
if componentFound {
return "", fmt.Errorf("Unable to find tun/tap device with name = %s", targetDeviceName)
}
return "", fmt.Errorf(
"Unable to find device in registry with ComponentId = %s, is tap-windows installed?",
targetComponentId,
)
}
// setStatus is used to bring up or bring down the interface
func setStatus(fd windows.Handle, status bool) error {
var code [4]byte
if status {
binary.LittleEndian.PutUint32(code[:], 1)
}
var bytesReturned uint32
rdbbuf := make([]byte, windows.MAXIMUM_REPARSE_DATA_BUFFER_SIZE)
return windows.DeviceIoControl(
fd,
TAP_IOCTL_SET_MEDIA_STATUS,
&code[0],
uint32(4),
&rdbbuf[0],
uint32(len(rdbbuf)),
&bytesReturned,
nil,
)
}
/* When operating in TUN mode we must assign an ip address & subnet to the device.
*
*/
func setTUN(fd windows.Handle, network string) error {
var bytesReturned uint32
rdbbuf := make([]byte, windows.MAXIMUM_REPARSE_DATA_BUFFER_SIZE)
localIP, remoteNet, err := net.ParseCIDR(network)
if err != nil {
return fmt.Errorf("Failed to parse network CIDR in config, %v", err)
}
if localIP.To4() == nil {
return fmt.Errorf("Provided network(%s) is not a valid IPv4 address", network)
}
var param [12]byte
copy(param[0:4], localIP.To4())
copy(param[4:8], remoteNet.IP.To4())
copy(param[8:12], remoteNet.Mask)
return windows.DeviceIoControl(
fd,
TAP_IOCTL_CONFIG_TUN,
&param[0],
uint32(12),
&rdbbuf[0],
uint32(len(rdbbuf)),
&bytesReturned,
nil,
)
}
func (tun *NativeTUN) MTU() (int, error) {
var mtu [4]byte
var bytesReturned uint32
err := windows.DeviceIoControl(
tun.fd,
TAP_IOCTL_GET_MTU,
&mtu[0],
uint32(len(mtu)),
&mtu[0],
uint32(len(mtu)),
&bytesReturned,
nil,
)
val := binary.LittleEndian.Uint32(mtu[:])
return int(val), err
}
func (tun *NativeTUN) Name() string {
return tun.name
}
func CreateTUN(name string) (TUNDevice, error) {
// find the device in registry.
deviceid, err := getdeviceid(ComponentID, name)
if err != nil {
return nil, err
}
path := "\\\\.\\Global\\" + deviceid + ".tap"
pathp, err := windows.UTF16PtrFromString(path)
if err != nil {
return nil, err
}
// create TUN device
handle, err := windows.CreateFile(
pathp,
windows.GENERIC_READ|windows.GENERIC_WRITE,
0,
nil,
windows.OPEN_EXISTING,
windows.FILE_ATTRIBUTE_SYSTEM|windows.FILE_FLAG_OVERLAPPED,
0,
)
if err != nil {
return nil, err
}
ro, err := newOverlapped()
if err != nil {
windows.Close(handle)
return nil, err
}
wo, err := newOverlapped()
if err != nil {
windows.Close(handle)
return nil, err
}
tun := &NativeTUN{
fd: handle,
name: name,
ro: ro,
wo: wo,
events: make(chan TUNEvent, 5),
}
// find addresses of interface
// TODO: fix this hack, the question is how
inter, err := net.InterfaceByName(name)
if err != nil {
windows.Close(handle)
return nil, err
}
addrs, err := inter.Addrs()
if err != nil {
windows.Close(handle)
return nil, err
}
var ip net.IP
for _, addr := range addrs {
ip = func() net.IP {
switch v := addr.(type) {
case *net.IPNet:
return v.IP.To4()
case *net.IPAddr:
return v.IP.To4()
}
return nil
}()
if ip != nil {
break
}
}
if ip == nil {
windows.Close(handle)
return nil, errors.New("No IPv4 address found for interface")
}
// bring up device.
if err := setStatus(handle, true); err != nil {
windows.Close(handle)
return nil, err
}
// set tun mode
mask := ip.String() + "/0"
if err := setTUN(handle, mask); err != nil {
windows.Close(handle)
return nil, err
}
// start listener
go func(native *NativeTUN, ifname string) {
// TODO: Fix this very niave implementation
var (
statusUp bool
statusMTU int
)
for ; ; time.Sleep(time.Second) {
intr, err := net.InterfaceByName(name)
if err != nil {
// TODO: handle
return
}
// Up / Down event
up := (intr.Flags & net.FlagUp) != 0
if up != statusUp && up {
native.events <- TUNEventUp
}
if up != statusUp && !up {
native.events <- TUNEventDown
}
statusUp = up
// MTU changes
if intr.MTU != statusMTU {
native.events <- TUNEventMTUUpdate
}
statusMTU = intr.MTU
}
}(tun, name)
return tun, nil
}

44
src/uapi_windows.go Normal file
View file

@ -0,0 +1,44 @@
package main
/* UAPI on windows uses a bidirectional named pipe
*/
import (
"fmt"
"github.com/Microsoft/go-winio"
"golang.org/x/sys/windows"
"net"
)
const (
ipcErrorIO = -int64(windows.ERROR_BROKEN_PIPE)
ipcErrorNotDefined = -int64(windows.ERROR_SERVICE_SPECIFIC_ERROR)
ipcErrorProtocol = -int64(windows.ERROR_SERVICE_SPECIFIC_ERROR)
ipcErrorInvalid = -int64(windows.ERROR_SERVICE_SPECIFIC_ERROR)
)
const PipeNameFmt = "\\\\.\\pipe\\wireguard-ipc-%s"
type UAPIListener struct {
listener net.Listener
}
func (uapi *UAPIListener) Accept() (net.Conn, error) {
return nil, nil
}
func (uapi *UAPIListener) Close() error {
return uapi.listener.Close()
}
func (uapi *UAPIListener) Addr() net.Addr {
return nil
}
func NewUAPIListener(name string) (net.Listener, error) {
path := fmt.Sprintf(PipeNameFmt, name)
return winio.ListenPipe(path, &winio.PipeConfig{
InputBufferSize: 2048,
OutputBufferSize: 2048,
})
}