First set of code review patches
This commit is contained in:
parent
22c83f4b8d
commit
8c34c4cbb3
229
src/config.go
229
src/config.go
|
@ -61,6 +61,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||||
send(fmt.Sprintf("persistent_keepalive_interval=%d",
|
send(fmt.Sprintf("persistent_keepalive_interval=%d",
|
||||||
atomic.LoadUint64(&peer.persistentKeepaliveInterval),
|
atomic.LoadUint64(&peer.persistentKeepaliveInterval),
|
||||||
))
|
))
|
||||||
|
|
||||||
for _, ip := range device.routingTable.AllowedIPs(peer) {
|
for _, ip := range device.routingTable.AllowedIPs(peer) {
|
||||||
send("allowed_ip=" + ip.String())
|
send("allowed_ip=" + ip.String())
|
||||||
}
|
}
|
||||||
|
@ -89,6 +90,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||||
logDebug := device.log.Debug
|
logDebug := device.log.Debug
|
||||||
|
|
||||||
var peer *Peer
|
var peer *Peer
|
||||||
|
|
||||||
|
deviceConfig := true
|
||||||
|
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
|
|
||||||
// parse line
|
// parse line
|
||||||
|
@ -99,86 +103,110 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||||
}
|
}
|
||||||
parts := strings.Split(line, "=")
|
parts := strings.Split(line, "=")
|
||||||
if len(parts) != 2 {
|
if len(parts) != 2 {
|
||||||
return &IPCError{Code: ipcErrorNoKeyValue}
|
return &IPCError{Code: ipcErrorProtocol}
|
||||||
}
|
}
|
||||||
key := parts[0]
|
key := parts[0]
|
||||||
value := parts[1]
|
value := parts[1]
|
||||||
|
|
||||||
switch key {
|
/* device configuration */
|
||||||
|
|
||||||
/* interface configuration */
|
if deviceConfig {
|
||||||
|
|
||||||
case "private_key":
|
switch key {
|
||||||
var sk NoisePrivateKey
|
case "private_key":
|
||||||
if value == "" {
|
var sk NoisePrivateKey
|
||||||
device.SetPrivateKey(sk)
|
if value == "" {
|
||||||
} else {
|
device.SetPrivateKey(sk)
|
||||||
err := sk.FromHex(value)
|
} else {
|
||||||
|
err := sk.FromHex(value)
|
||||||
|
if err != nil {
|
||||||
|
logError.Println("Failed to set private_key:", err)
|
||||||
|
return &IPCError{Code: ipcErrorInvalid}
|
||||||
|
}
|
||||||
|
device.SetPrivateKey(sk)
|
||||||
|
}
|
||||||
|
|
||||||
|
case "listen_port":
|
||||||
|
port, err := strconv.ParseUint(value, 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to set private_key:", err)
|
logError.Println("Failed to set listen_port:", err)
|
||||||
return &IPCError{Code: ipcErrorInvalidValue}
|
return &IPCError{Code: ipcErrorInvalid}
|
||||||
}
|
}
|
||||||
device.SetPrivateKey(sk)
|
netc := &device.net
|
||||||
}
|
netc.mutex.Lock()
|
||||||
|
if netc.addr.Port != int(port) {
|
||||||
case "listen_port":
|
if netc.conn != nil {
|
||||||
port, err := strconv.ParseUint(value, 10, 16)
|
netc.conn.Close()
|
||||||
if err != nil {
|
}
|
||||||
logError.Println("Failed to set listen_port:", err)
|
netc.addr.Port = int(port)
|
||||||
return &IPCError{Code: ipcErrorInvalidValue}
|
netc.conn, err = net.ListenUDP("udp", netc.addr)
|
||||||
}
|
|
||||||
netc := &device.net
|
|
||||||
netc.mutex.Lock()
|
|
||||||
if netc.addr.Port != int(port) {
|
|
||||||
if netc.conn != nil {
|
|
||||||
netc.conn.Close()
|
|
||||||
}
|
}
|
||||||
netc.addr.Port = int(port)
|
netc.mutex.Unlock()
|
||||||
netc.conn, err = net.ListenUDP("udp", netc.addr)
|
if err != nil {
|
||||||
}
|
logError.Println("Failed to create UDP listener:", err)
|
||||||
netc.mutex.Unlock()
|
return &IPCError{Code: ipcErrorIO}
|
||||||
if err != nil {
|
}
|
||||||
logError.Println("Failed to create UDP listener:", err)
|
// TODO: Clear source address of all peers
|
||||||
return &IPCError{Code: ipcErrorInvalidValue}
|
|
||||||
}
|
|
||||||
|
|
||||||
case "fwmark":
|
case "fwmark":
|
||||||
logError.Println("FWMark not handled yet")
|
logError.Println("FWMark not handled yet")
|
||||||
|
// TODO: Clear source address of all peers
|
||||||
|
|
||||||
case "public_key":
|
case "public_key":
|
||||||
var pubKey NoisePublicKey
|
|
||||||
err := pubKey.FromHex(value)
|
|
||||||
if err != nil {
|
|
||||||
logError.Println("Failed to get peer by public_key:", err)
|
|
||||||
return &IPCError{Code: ipcErrorInvalidValue}
|
|
||||||
}
|
|
||||||
device.mutex.RLock()
|
|
||||||
peer, _ = device.peers[pubKey]
|
|
||||||
device.mutex.RUnlock()
|
|
||||||
if peer == nil {
|
|
||||||
peer = device.NewPeer(pubKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
case "replace_peers":
|
// switch to peer configuration
|
||||||
if value == "true" {
|
|
||||||
|
deviceConfig = false
|
||||||
|
|
||||||
|
case "replace_peers":
|
||||||
|
if value != "true" {
|
||||||
|
logError.Println("Failed to set replace_peers, invalid value:", value)
|
||||||
|
return &IPCError{Code: ipcErrorInvalid}
|
||||||
|
}
|
||||||
device.RemoveAllPeers()
|
device.RemoveAllPeers()
|
||||||
} else {
|
|
||||||
logError.Println("Failed to set replace_peers, invalid value:", value)
|
default:
|
||||||
return &IPCError{Code: ipcErrorInvalidValue}
|
logError.Println("Invalid UAPI key (device configuration):", key)
|
||||||
|
return &IPCError{Code: ipcErrorInvalid}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
default:
|
/* peer configuration */
|
||||||
|
|
||||||
/* peer configuration */
|
if !deviceConfig {
|
||||||
|
|
||||||
if peer == nil {
|
|
||||||
logError.Println("No peer referenced, before peer operation")
|
|
||||||
return &IPCError{Code: ipcErrorNoPeer}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch key {
|
switch key {
|
||||||
|
|
||||||
|
case "public_key":
|
||||||
|
var pubKey NoisePublicKey
|
||||||
|
err := pubKey.FromHex(value)
|
||||||
|
if err != nil {
|
||||||
|
logError.Println("Failed to get peer by public_key:", err)
|
||||||
|
return &IPCError{Code: ipcErrorInvalid}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if public key of peer equal to device
|
||||||
|
|
||||||
|
device.mutex.RLock()
|
||||||
|
if device.publicKey.Equals(pubKey) {
|
||||||
|
device.mutex.RUnlock()
|
||||||
|
logError.Println("Public key of peer matches private key of device")
|
||||||
|
return &IPCError{Code: ipcErrorInvalid}
|
||||||
|
}
|
||||||
|
|
||||||
|
// find peer referenced
|
||||||
|
|
||||||
|
peer, _ = device.peers[pubKey]
|
||||||
|
device.mutex.RUnlock()
|
||||||
|
if peer == nil {
|
||||||
|
peer = device.NewPeer(pubKey)
|
||||||
|
}
|
||||||
|
|
||||||
case "remove":
|
case "remove":
|
||||||
|
if value != "true" {
|
||||||
|
logError.Println("Failed to set remove, invalid value:", value)
|
||||||
|
return &IPCError{Code: ipcErrorInvalid}
|
||||||
|
}
|
||||||
device.RemovePeer(peer.handshake.remoteStatic)
|
device.RemovePeer(peer.handshake.remoteStatic)
|
||||||
logDebug.Println("Removing", peer.String())
|
logDebug.Println("Removing", peer.String())
|
||||||
peer = nil
|
peer = nil
|
||||||
|
@ -191,50 +219,67 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||||
}()
|
}()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to set preshared_key:", err)
|
logError.Println("Failed to set preshared_key:", err)
|
||||||
return &IPCError{Code: ipcErrorInvalidValue}
|
return &IPCError{Code: ipcErrorInvalid}
|
||||||
}
|
}
|
||||||
|
|
||||||
case "endpoint":
|
case "endpoint":
|
||||||
|
// TODO: Only IP and port
|
||||||
addr, err := net.ResolveUDPAddr("udp", value)
|
addr, err := net.ResolveUDPAddr("udp", value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to set endpoint:", value)
|
logError.Println("Failed to set endpoint:", value)
|
||||||
return &IPCError{Code: ipcErrorInvalidValue}
|
return &IPCError{Code: ipcErrorInvalid}
|
||||||
}
|
}
|
||||||
peer.mutex.Lock()
|
peer.mutex.Lock()
|
||||||
peer.endpoint = addr
|
peer.endpoint = addr
|
||||||
peer.mutex.Unlock()
|
peer.mutex.Unlock()
|
||||||
|
|
||||||
case "persistent_keepalive_interval":
|
case "persistent_keepalive_interval":
|
||||||
secs, err := strconv.ParseInt(value, 10, 64)
|
|
||||||
if secs < 0 || err != nil {
|
// update keep-alive interval
|
||||||
|
|
||||||
|
secs, err := strconv.ParseUint(value, 10, 16)
|
||||||
|
if err != nil {
|
||||||
logError.Println("Failed to set persistent_keepalive_interval:", err)
|
logError.Println("Failed to set persistent_keepalive_interval:", err)
|
||||||
return &IPCError{Code: ipcErrorInvalidValue}
|
return &IPCError{Code: ipcErrorInvalid}
|
||||||
}
|
}
|
||||||
atomic.StoreUint64(
|
|
||||||
|
old := atomic.SwapUint64(
|
||||||
&peer.persistentKeepaliveInterval,
|
&peer.persistentKeepaliveInterval,
|
||||||
uint64(secs),
|
secs,
|
||||||
)
|
)
|
||||||
|
|
||||||
case "replace_allowed_ips":
|
// send immediate keep-alive
|
||||||
if value == "true" {
|
|
||||||
device.routingTable.RemovePeer(peer)
|
if old == 0 && secs != 0 {
|
||||||
} else {
|
up, err := device.tun.IsUp()
|
||||||
logError.Println("Failed to set replace_allowed_ips, invalid value:", value)
|
if err != nil {
|
||||||
return &IPCError{Code: ipcErrorInvalidValue}
|
logError.Println("Failed to get tun device status:", err)
|
||||||
|
return &IPCError{Code: ipcErrorIO}
|
||||||
|
}
|
||||||
|
if up {
|
||||||
|
peer.SendKeepAlive()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case "replace_allowed_ips":
|
||||||
|
if value != "true" {
|
||||||
|
logError.Println("Failed to set replace_allowed_ips, invalid value:", value)
|
||||||
|
return &IPCError{Code: ipcErrorInvalid}
|
||||||
|
}
|
||||||
|
device.routingTable.RemovePeer(peer)
|
||||||
|
|
||||||
case "allowed_ip":
|
case "allowed_ip":
|
||||||
_, network, err := net.ParseCIDR(value)
|
_, network, err := net.ParseCIDR(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to set allowed_ip:", err)
|
logError.Println("Failed to set allowed_ip:", err)
|
||||||
return &IPCError{Code: ipcErrorInvalidValue}
|
return &IPCError{Code: ipcErrorInvalid}
|
||||||
}
|
}
|
||||||
ones, _ := network.Mask.Size()
|
ones, _ := network.Mask.Size()
|
||||||
device.routingTable.Insert(network.IP, uint(ones), peer)
|
device.routingTable.Insert(network.IP, uint(ones), peer)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
logError.Println("Invalid UAPI key:", key)
|
logError.Println("Invalid UAPI key (peer configuration):", key)
|
||||||
return &IPCError{Code: ipcErrorInvalidKey}
|
return &IPCError{Code: ipcErrorInvalid}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -244,6 +289,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
||||||
|
|
||||||
func ipcHandle(device *Device, socket net.Conn) {
|
func ipcHandle(device *Device, socket net.Conn) {
|
||||||
|
|
||||||
|
// create buffered read/writer
|
||||||
|
|
||||||
defer socket.Close()
|
defer socket.Close()
|
||||||
|
|
||||||
buffered := func(s io.ReadWriter) *bufio.ReadWriter {
|
buffered := func(s io.ReadWriter) *bufio.ReadWriter {
|
||||||
|
@ -259,30 +306,30 @@ func ipcHandle(device *Device, socket net.Conn) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
switch op {
|
// handle operation
|
||||||
|
|
||||||
|
var status *IPCError
|
||||||
|
|
||||||
|
switch op {
|
||||||
case "set=1\n":
|
case "set=1\n":
|
||||||
device.log.Debug.Println("Config, set operation")
|
device.log.Debug.Println("Config, set operation")
|
||||||
err := ipcSetOperation(device, buffered)
|
status = ipcSetOperation(device, buffered)
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(buffered, "errno=0\n\n")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
|
|
||||||
case "get=1\n":
|
case "get=1\n":
|
||||||
device.log.Debug.Println("Config, get operation")
|
device.log.Debug.Println("Config, get operation")
|
||||||
err := ipcGetOperation(device, buffered)
|
status = ipcGetOperation(device, buffered)
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(buffered, "errno=0\n\n")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
|
|
||||||
default:
|
default:
|
||||||
device.log.Error.Println("Invalid UAPI operation:", op)
|
device.log.Error.Println("Invalid UAPI operation:", op)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// write status
|
||||||
|
|
||||||
|
if status != nil {
|
||||||
|
device.log.Error.Println(status)
|
||||||
|
fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(buffered, "errno=0\n\n")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@ const (
|
||||||
KeepaliveTimeout = time.Second * 10
|
KeepaliveTimeout = time.Second * 10
|
||||||
CookieRefreshTime = time.Second * 120
|
CookieRefreshTime = time.Second * 120
|
||||||
MaxHandshakeAttemptTime = time.Second * 90
|
MaxHandshakeAttemptTime = time.Second * 90
|
||||||
|
PaddingMultiple = 16
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -31,5 +32,5 @@ const (
|
||||||
QueueHandshakeSize = 1024
|
QueueHandshakeSize = 1024
|
||||||
QueueHandshakeBusySize = QueueHandshakeSize / 8
|
QueueHandshakeBusySize = QueueHandshakeSize / 8
|
||||||
MinMessageSize = MessageTransportSize // size of keep-alive
|
MinMessageSize = MessageTransportSize // size of keep-alive
|
||||||
MaxMessageSize = (1 << 16) - 1
|
MaxMessageSize = ((1 << 16) - 1) + MessageTransportHeaderSize
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -10,6 +12,7 @@ import (
|
||||||
|
|
||||||
type Device struct {
|
type Device struct {
|
||||||
mtu int32
|
mtu int32
|
||||||
|
tun TUNDevice
|
||||||
log *Logger // collection of loggers for levels
|
log *Logger // collection of loggers for levels
|
||||||
idCounter uint // for assigning debug ids to peers
|
idCounter uint // for assigning debug ids to peers
|
||||||
fwMark uint32
|
fwMark uint32
|
||||||
|
@ -43,24 +46,46 @@ type Device struct {
|
||||||
mac MACStateDevice
|
mac MACStateDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
|
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
||||||
device.mutex.Lock()
|
device.mutex.Lock()
|
||||||
defer device.mutex.Unlock()
|
defer device.mutex.Unlock()
|
||||||
|
|
||||||
|
// check if public key is matching any peer
|
||||||
|
|
||||||
|
publicKey := sk.publicKey()
|
||||||
|
for _, peer := range device.peers {
|
||||||
|
h := &peer.handshake
|
||||||
|
h.mutex.RLock()
|
||||||
|
if h.remoteStatic.Equals(publicKey) {
|
||||||
|
h.mutex.RUnlock()
|
||||||
|
return errors.New("Private key matches public key of peer")
|
||||||
|
}
|
||||||
|
h.mutex.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
// update key material
|
// update key material
|
||||||
|
|
||||||
device.privateKey = sk
|
device.privateKey = sk
|
||||||
device.publicKey = sk.publicKey()
|
device.publicKey = publicKey
|
||||||
device.mac.Init(device.publicKey)
|
device.mac.Init(publicKey)
|
||||||
|
|
||||||
// do DH precomputations
|
// do DH precomputations
|
||||||
|
|
||||||
|
isZero := device.privateKey.IsZero()
|
||||||
|
|
||||||
for _, peer := range device.peers {
|
for _, peer := range device.peers {
|
||||||
h := &peer.handshake
|
h := &peer.handshake
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
|
if isZero {
|
||||||
|
h.precomputedStaticStatic = [NoisePublicKeySize]byte{}
|
||||||
|
} else {
|
||||||
|
h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
|
||||||
|
}
|
||||||
|
fmt.Println(h.precomputedStaticStatic)
|
||||||
h.mutex.Unlock()
|
h.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
|
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
|
||||||
|
@ -77,6 +102,7 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
|
||||||
device.mutex.Lock()
|
device.mutex.Lock()
|
||||||
defer device.mutex.Unlock()
|
defer device.mutex.Unlock()
|
||||||
|
|
||||||
|
device.tun = tun
|
||||||
device.log = NewLogger(logLevel)
|
device.log = NewLogger(logLevel)
|
||||||
device.peers = make(map[NoisePublicKey]*Peer)
|
device.peers = make(map[NoisePublicKey]*Peer)
|
||||||
device.indices.Init()
|
device.indices.Init()
|
||||||
|
@ -119,22 +145,22 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
|
||||||
}
|
}
|
||||||
|
|
||||||
go device.RoutineBusyMonitor()
|
go device.RoutineBusyMonitor()
|
||||||
go device.RoutineMTUUpdater(tun)
|
go device.RoutineMTUUpdater()
|
||||||
go device.RoutineWriteToTUN(tun)
|
go device.RoutineWriteToTUN()
|
||||||
go device.RoutineReadFromTUN(tun)
|
go device.RoutineReadFromTUN()
|
||||||
go device.RoutineReceiveIncomming()
|
go device.RoutineReceiveIncomming()
|
||||||
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
|
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
|
||||||
|
|
||||||
return device
|
return device
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) RoutineMTUUpdater(tun TUNDevice) {
|
func (device *Device) RoutineMTUUpdater() {
|
||||||
logError := device.log.Error
|
logError := device.log.Error
|
||||||
for ; ; time.Sleep(5 * time.Second) {
|
for ; ; time.Sleep(5 * time.Second) {
|
||||||
|
|
||||||
// load updated MTU
|
// load updated MTU
|
||||||
|
|
||||||
mtu, err := tun.MTU()
|
mtu, err := device.tun.MTU()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to load updated MTU of device:", err)
|
logError.Println("Failed to load updated MTU of device:", err)
|
||||||
continue
|
continue
|
||||||
|
|
10
src/index.go
10
src/index.go
|
@ -3,6 +3,7 @@ package main
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"sync"
|
"sync"
|
||||||
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
/* Index=0 is reserved for unset indecies
|
/* Index=0 is reserved for unset indecies
|
||||||
|
@ -23,14 +24,7 @@ type IndexTable struct {
|
||||||
func randUint32() (uint32, error) {
|
func randUint32() (uint32, error) {
|
||||||
var buff [4]byte
|
var buff [4]byte
|
||||||
_, err := rand.Read(buff[:])
|
_, err := rand.Read(buff[:])
|
||||||
id := uint32(buff[0])
|
return *((*uint32)(unsafe.Pointer(&buff))), err
|
||||||
id <<= 8
|
|
||||||
id |= uint32(buff[1])
|
|
||||||
id <<= 8
|
|
||||||
id |= uint32(buff[2])
|
|
||||||
id <<= 8
|
|
||||||
id |= uint32(buff[3])
|
|
||||||
return id, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *IndexTable) Init() {
|
func (table *IndexTable) Init() {
|
||||||
|
|
15
src/macs.go
15
src/macs.go
|
@ -3,7 +3,6 @@ package main
|
||||||
import (
|
import (
|
||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"errors"
|
|
||||||
"golang.org/x/crypto/blake2s"
|
"golang.org/x/crypto/blake2s"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -15,14 +14,14 @@ type MACStateDevice struct {
|
||||||
refreshed time.Time
|
refreshed time.Time
|
||||||
secret [blake2s.Size]byte
|
secret [blake2s.Size]byte
|
||||||
keyMAC1 [blake2s.Size]byte
|
keyMAC1 [blake2s.Size]byte
|
||||||
keyMAC2 [blake2s.Size]byte
|
keyMAC2 [blake2s.Size]byte // TODO: Change to more descriptive size constant, rename to something.
|
||||||
}
|
}
|
||||||
|
|
||||||
type MACStatePeer struct {
|
type MACStatePeer struct {
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
cookieSet time.Time
|
cookieSet time.Time
|
||||||
cookie [blake2s.Size128]byte
|
cookie [blake2s.Size128]byte
|
||||||
lastMAC1 [blake2s.Size128]byte
|
lastMAC1 [blake2s.Size128]byte // TODO: Check if set
|
||||||
keyMAC1 [blake2s.Size]byte
|
keyMAC1 [blake2s.Size]byte
|
||||||
keyMAC2 [blake2s.Size]byte
|
keyMAC2 [blake2s.Size]byte
|
||||||
}
|
}
|
||||||
|
@ -83,7 +82,7 @@ func (state *MACStateDevice) CheckMAC2(msg []byte, addr *net.UDPAddr) bool {
|
||||||
port := [2]byte{byte(addr.Port >> 8), byte(addr.Port)}
|
port := [2]byte{byte(addr.Port >> 8), byte(addr.Port)}
|
||||||
mac, _ := blake2s.New128(state.secret[:])
|
mac, _ := blake2s.New128(state.secret[:])
|
||||||
mac.Write(addr.IP)
|
mac.Write(addr.IP)
|
||||||
mac.Write(port[:])
|
mac.Write(port[:]) // TODO: Be faster and more platform dependent?
|
||||||
mac.Sum(cookie[:0])
|
mac.Sum(cookie[:0])
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -130,7 +129,7 @@ func (device *Device) CreateMessageCookieReply(
|
||||||
port := [2]byte{byte(addr.Port >> 8), byte(addr.Port)}
|
port := [2]byte{byte(addr.Port >> 8), byte(addr.Port)}
|
||||||
mac, _ := blake2s.New128(state.secret[:])
|
mac, _ := blake2s.New128(state.secret[:])
|
||||||
mac.Write(addr.IP)
|
mac.Write(addr.IP)
|
||||||
mac.Write(port[:])
|
mac.Write(port[:]) // TODO: Do whatever we did above
|
||||||
mac.Sum(cookie[:0])
|
mac.Sum(cookie[:0])
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -196,6 +195,7 @@ func (device *Device) ConsumeMessageCookieReply(msg *MessageCookieReply) bool {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
state.cookieSet = time.Now()
|
state.cookieSet = time.Now()
|
||||||
state.cookie = cookie
|
state.cookie = cookie
|
||||||
return true
|
return true
|
||||||
|
@ -229,10 +229,6 @@ func (state *MACStatePeer) Init(pk NoisePublicKey) {
|
||||||
func (state *MACStatePeer) AddMacs(msg []byte) {
|
func (state *MACStatePeer) AddMacs(msg []byte) {
|
||||||
size := len(msg)
|
size := len(msg)
|
||||||
|
|
||||||
if size < blake2s.Size128*2 {
|
|
||||||
panic(errors.New("bug: message too short"))
|
|
||||||
}
|
|
||||||
|
|
||||||
startMac1 := size - (blake2s.Size128 * 2)
|
startMac1 := size - (blake2s.Size128 * 2)
|
||||||
startMac2 := size - blake2s.Size128
|
startMac2 := size - blake2s.Size128
|
||||||
|
|
||||||
|
@ -250,6 +246,7 @@ func (state *MACStatePeer) AddMacs(msg []byte) {
|
||||||
mac.Sum(mac1[:0])
|
mac.Sum(mac1[:0])
|
||||||
}()
|
}()
|
||||||
copy(state.lastMAC1[:], mac1)
|
copy(state.lastMAC1[:], mac1)
|
||||||
|
// TODO: Set lastMac flag
|
||||||
|
|
||||||
// set mac2
|
// set mac2
|
||||||
|
|
||||||
|
|
|
@ -47,6 +47,14 @@ func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byt
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isZero(val []byte) bool {
|
||||||
|
var acc byte
|
||||||
|
for _, b := range val {
|
||||||
|
acc |= b
|
||||||
|
}
|
||||||
|
return acc == 0
|
||||||
|
}
|
||||||
|
|
||||||
/* curve25519 wrappers */
|
/* curve25519 wrappers */
|
||||||
|
|
||||||
func newPrivateKey() (sk NoisePrivateKey, err error) {
|
func newPrivateKey() (sk NoisePrivateKey, err error) {
|
||||||
|
|
|
@ -135,6 +135,10 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
|
||||||
handshake.mutex.Lock()
|
handshake.mutex.Lock()
|
||||||
defer handshake.mutex.Unlock()
|
defer handshake.mutex.Unlock()
|
||||||
|
|
||||||
|
if isZero(handshake.precomputedStaticStatic[:]) {
|
||||||
|
return nil, errors.New("Static shared secret is zero")
|
||||||
|
}
|
||||||
|
|
||||||
// create ephemeral key
|
// create ephemeral key
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
@ -226,7 +230,11 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
||||||
if peer == nil {
|
if peer == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
handshake := &peer.handshake
|
handshake := &peer.handshake
|
||||||
|
if isZero(handshake.precomputedStaticStatic[:]) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// verify identity
|
// verify identity
|
||||||
|
|
||||||
|
@ -472,6 +480,7 @@ func (peer *Peer) NewKeyPair() *KeyPair {
|
||||||
func() {
|
func() {
|
||||||
kp.mutex.Lock()
|
kp.mutex.Lock()
|
||||||
defer kp.mutex.Unlock()
|
defer kp.mutex.Unlock()
|
||||||
|
// TODO: Adapt kernel behavior noise.c:161
|
||||||
if isInitiator {
|
if isInitiator {
|
||||||
if kp.previous != nil {
|
if kp.previous != nil {
|
||||||
kp.previous.send = nil
|
kp.previous.send = nil
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/subtle"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
|
@ -31,12 +32,12 @@ func loadExactHex(dst []byte, src string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (key NoisePrivateKey) IsZero() bool {
|
func (key NoisePrivateKey) IsZero() bool {
|
||||||
for _, b := range key[:] {
|
var zero NoisePrivateKey
|
||||||
if b != 0 {
|
return key.Equals(zero)
|
||||||
return false
|
}
|
||||||
}
|
|
||||||
}
|
func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool {
|
||||||
return true
|
return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
|
||||||
}
|
}
|
||||||
|
|
||||||
func (key *NoisePrivateKey) FromHex(src string) error {
|
func (key *NoisePrivateKey) FromHex(src string) error {
|
||||||
|
@ -55,6 +56,15 @@ func (key NoisePublicKey) ToHex() string {
|
||||||
return hex.EncodeToString(key[:])
|
return hex.EncodeToString(key[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (key NoisePublicKey) IsZero() bool {
|
||||||
|
var zero NoisePublicKey
|
||||||
|
return key.Equals(zero)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (key NoisePublicKey) Equals(tar NoisePublicKey) bool {
|
||||||
|
return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
|
||||||
|
}
|
||||||
|
|
||||||
func (key *NoiseSymmetricKey) FromHex(src string) error {
|
func (key *NoiseSymmetricKey) FromHex(src string) error {
|
||||||
return loadExactHex(key[:], src)
|
return loadExactHex(key[:], src)
|
||||||
}
|
}
|
||||||
|
|
|
@ -73,6 +73,8 @@ func (device *Device) addToHandshakeQueue(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Routine determining the busy state of the interface
|
/* Routine determining the busy state of the interface
|
||||||
|
*
|
||||||
|
* TODO: Under load for some time
|
||||||
*/
|
*/
|
||||||
func (device *Device) RoutineBusyMonitor() {
|
func (device *Device) RoutineBusyMonitor() {
|
||||||
samples := 0
|
samples := 0
|
||||||
|
@ -131,6 +133,7 @@ func (device *Device) RoutineReceiveIncomming() {
|
||||||
buffer = device.GetMessageBuffer()
|
buffer = device.GetMessageBuffer()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Take writelock to sleep
|
||||||
device.net.mutex.RLock()
|
device.net.mutex.RLock()
|
||||||
conn := device.net.conn
|
conn := device.net.conn
|
||||||
device.net.mutex.RUnlock()
|
device.net.mutex.RUnlock()
|
||||||
|
@ -139,6 +142,7 @@ func (device *Device) RoutineReceiveIncomming() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Wait for new conn or message
|
||||||
conn.SetReadDeadline(time.Now().Add(time.Second))
|
conn.SetReadDeadline(time.Now().Add(time.Second))
|
||||||
|
|
||||||
size, raddr, err := conn.ReadFromUDP(buffer[:])
|
size, raddr, err := conn.ReadFromUDP(buffer[:])
|
||||||
|
@ -156,6 +160,8 @@ func (device *Device) RoutineReceiveIncomming() {
|
||||||
|
|
||||||
case MessageInitiationType, MessageResponseType:
|
case MessageInitiationType, MessageResponseType:
|
||||||
|
|
||||||
|
// TODO: Check size early
|
||||||
|
|
||||||
// add to handshake queue
|
// add to handshake queue
|
||||||
|
|
||||||
device.addToHandshakeQueue(
|
device.addToHandshakeQueue(
|
||||||
|
@ -171,6 +177,8 @@ func (device *Device) RoutineReceiveIncomming() {
|
||||||
|
|
||||||
case MessageCookieReplyType:
|
case MessageCookieReplyType:
|
||||||
|
|
||||||
|
// TODO: Queue all the things
|
||||||
|
|
||||||
// verify and update peer cookie state
|
// verify and update peer cookie state
|
||||||
|
|
||||||
if len(packet) != MessageCookieReplySize {
|
if len(packet) != MessageCookieReplySize {
|
||||||
|
@ -250,7 +258,7 @@ func (device *Device) RoutineDecryption() {
|
||||||
// check if dropped
|
// check if dropped
|
||||||
|
|
||||||
if elem.IsDropped() {
|
if elem.IsDropped() {
|
||||||
elem.mutex.Unlock()
|
elem.mutex.Unlock() // TODO: Make consistent with send
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -318,6 +326,7 @@ func (device *Device) RoutineHandshake() {
|
||||||
logError.Println("Failed to create cookie reply:", err)
|
logError.Println("Failed to create cookie reply:", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// TODO: Use temp
|
||||||
writer := bytes.NewBuffer(elem.packet[:0])
|
writer := bytes.NewBuffer(elem.packet[:0])
|
||||||
binary.Write(writer, binary.LittleEndian, reply)
|
binary.Write(writer, binary.LittleEndian, reply)
|
||||||
elem.packet = writer.Bytes()
|
elem.packet = writer.Bytes()
|
||||||
|
@ -330,6 +339,8 @@ func (device *Device) RoutineHandshake() {
|
||||||
|
|
||||||
// ratelimit
|
// ratelimit
|
||||||
|
|
||||||
|
// TODO: Only ratelimit when busy
|
||||||
|
|
||||||
if !device.ratelimiter.Allow(elem.source.IP) {
|
if !device.ratelimiter.Allow(elem.source.IP) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -364,9 +375,14 @@ func (device *Device) RoutineHandshake() {
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
peer.TimerPacketReceived()
|
|
||||||
|
// update timers
|
||||||
|
|
||||||
|
peer.TimerAnyAuthenticatedPacketTraversal()
|
||||||
|
peer.TimerAnyAuthenticatedPacketReceived()
|
||||||
|
|
||||||
// update endpoint
|
// update endpoint
|
||||||
|
// TODO: Add a race condition \s
|
||||||
|
|
||||||
peer.mutex.Lock()
|
peer.mutex.Lock()
|
||||||
peer.endpoint = elem.source
|
peer.endpoint = elem.source
|
||||||
|
@ -381,6 +397,7 @@ func (device *Device) RoutineHandshake() {
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.TimerEphemeralKeyCreated()
|
peer.TimerEphemeralKeyCreated()
|
||||||
|
peer.NewKeyPair()
|
||||||
|
|
||||||
logDebug.Println("Creating response message for", peer.String())
|
logDebug.Println("Creating response message for", peer.String())
|
||||||
|
|
||||||
|
@ -392,8 +409,7 @@ func (device *Device) RoutineHandshake() {
|
||||||
// send response
|
// send response
|
||||||
|
|
||||||
peer.SendBuffer(packet)
|
peer.SendBuffer(packet)
|
||||||
peer.TimerPacketSent()
|
peer.TimerAnyAuthenticatedPacketTraversal()
|
||||||
peer.NewKeyPair()
|
|
||||||
|
|
||||||
case MessageResponseType:
|
case MessageResponseType:
|
||||||
|
|
||||||
|
@ -423,8 +439,14 @@ func (device *Device) RoutineHandshake() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.TimerPacketReceived()
|
// update timers
|
||||||
|
|
||||||
|
peer.TimerAnyAuthenticatedPacketTraversal()
|
||||||
|
peer.TimerAnyAuthenticatedPacketReceived()
|
||||||
peer.TimerHandshakeComplete()
|
peer.TimerHandshakeComplete()
|
||||||
|
|
||||||
|
// derive key-pair
|
||||||
|
|
||||||
peer.NewKeyPair()
|
peer.NewKeyPair()
|
||||||
peer.SendKeepAlive()
|
peer.SendKeepAlive()
|
||||||
|
|
||||||
|
@ -467,8 +489,8 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.TimerPacketReceived()
|
peer.TimerAnyAuthenticatedPacketTraversal()
|
||||||
peer.TimerTransportReceived()
|
peer.TimerAnyAuthenticatedPacketReceived()
|
||||||
peer.KeepKeyFreshReceiving()
|
peer.KeepKeyFreshReceiving()
|
||||||
|
|
||||||
// check if using new key-pair
|
// check if using new key-pair
|
||||||
|
@ -504,6 +526,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||||
|
|
||||||
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
|
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
|
||||||
length := binary.BigEndian.Uint16(field)
|
length := binary.BigEndian.Uint16(field)
|
||||||
|
// TODO: check length of packet & NOT TOO SMALL either
|
||||||
elem.packet = elem.packet[:length]
|
elem.packet = elem.packet[:length]
|
||||||
|
|
||||||
// verify IPv4 source
|
// verify IPv4 source
|
||||||
|
@ -525,6 +548,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||||
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
|
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
|
||||||
length := binary.BigEndian.Uint16(field)
|
length := binary.BigEndian.Uint16(field)
|
||||||
length += ipv6.HeaderLen
|
length += ipv6.HeaderLen
|
||||||
|
// TODO: check length of packet
|
||||||
elem.packet = elem.packet[:length]
|
elem.packet = elem.packet[:length]
|
||||||
|
|
||||||
// verify IPv6 source
|
// verify IPv6 source
|
||||||
|
@ -542,11 +566,13 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||||
|
|
||||||
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
|
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
|
||||||
device.addToInboundQueue(device.queue.inbound, elem)
|
device.addToInboundQueue(device.queue.inbound, elem)
|
||||||
|
|
||||||
|
// TODO: move TUN write into per peer routine
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) RoutineWriteToTUN(tun TUNDevice) {
|
func (device *Device) RoutineWriteToTUN() {
|
||||||
|
|
||||||
logError := device.log.Error
|
logError := device.log.Error
|
||||||
logDebug := device.log.Debug
|
logDebug := device.log.Debug
|
||||||
|
@ -557,7 +583,7 @@ func (device *Device) RoutineWriteToTUN(tun TUNDevice) {
|
||||||
case <-device.signal.stop:
|
case <-device.signal.stop:
|
||||||
return
|
return
|
||||||
case elem := <-device.queue.inbound:
|
case elem := <-device.queue.inbound:
|
||||||
_, err := tun.Write(elem.packet)
|
_, err := device.tun.Write(elem.packet)
|
||||||
device.PutMessageBuffer(elem.buffer)
|
device.PutMessageBuffer(elem.buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to write packet to TUN device:", err)
|
logError.Println("Failed to write packet to TUN device:", err)
|
||||||
|
|
51
src/send.go
51
src/send.go
|
@ -110,17 +110,19 @@ func addToEncryptionQueue(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) SendBuffer(buffer []byte) (int, error) {
|
func (peer *Peer) SendBuffer(buffer []byte) (int, error) {
|
||||||
|
peer.device.net.mutex.RLock()
|
||||||
|
defer peer.device.net.mutex.RUnlock()
|
||||||
|
|
||||||
peer.mutex.RLock()
|
peer.mutex.RLock()
|
||||||
|
defer peer.mutex.RUnlock()
|
||||||
|
|
||||||
endpoint := peer.endpoint
|
endpoint := peer.endpoint
|
||||||
peer.mutex.RUnlock()
|
conn := peer.device.net.conn
|
||||||
|
|
||||||
if endpoint == nil {
|
if endpoint == nil {
|
||||||
return 0, ErrorNoEndpoint
|
return 0, ErrorNoEndpoint
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.device.net.mutex.RLock()
|
|
||||||
conn := peer.device.net.conn
|
|
||||||
peer.device.net.mutex.RUnlock()
|
|
||||||
if conn == nil {
|
if conn == nil {
|
||||||
return 0, ErrorNoConnection
|
return 0, ErrorNoConnection
|
||||||
}
|
}
|
||||||
|
@ -133,13 +135,13 @@ func (peer *Peer) SendBuffer(buffer []byte) (int, error) {
|
||||||
*
|
*
|
||||||
* Obs. Single instance per TUN device
|
* Obs. Single instance per TUN device
|
||||||
*/
|
*/
|
||||||
func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
|
func (device *Device) RoutineReadFromTUN() {
|
||||||
|
|
||||||
if tun == nil {
|
if device.tun == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
elem := device.NewOutboundElement()
|
var elem *QueueOutboundElement
|
||||||
|
|
||||||
logDebug := device.log.Debug
|
logDebug := device.log.Debug
|
||||||
logError := device.log.Error
|
logError := device.log.Error
|
||||||
|
@ -153,32 +155,38 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
|
||||||
elem = device.NewOutboundElement()
|
elem = device.NewOutboundElement()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: THIS!
|
||||||
elem.packet = elem.buffer[MessageTransportHeaderSize:]
|
elem.packet = elem.buffer[MessageTransportHeaderSize:]
|
||||||
size, err := tun.Read(elem.packet)
|
size, err := device.tun.Read(elem.packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
||||||
// stop process
|
|
||||||
|
|
||||||
logError.Println("Failed to read packet from TUN device:", err)
|
logError.Println("Failed to read packet from TUN device:", err)
|
||||||
device.Close()
|
device.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
elem.packet = elem.packet[:size]
|
if size == 0 {
|
||||||
if len(elem.packet) < ipv4.HeaderLen {
|
|
||||||
logError.Println("Packet too short, length:", size)
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
println(size, err)
|
||||||
|
|
||||||
|
elem.packet = elem.packet[:size]
|
||||||
|
|
||||||
// lookup peer
|
// lookup peer
|
||||||
|
|
||||||
var peer *Peer
|
var peer *Peer
|
||||||
switch elem.packet[0] >> 4 {
|
switch elem.packet[0] >> 4 {
|
||||||
case ipv4.Version:
|
case ipv4.Version:
|
||||||
|
if len(elem.packet) < ipv4.HeaderLen {
|
||||||
|
continue
|
||||||
|
}
|
||||||
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
|
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
|
||||||
peer = device.routingTable.LookupIPv4(dst)
|
peer = device.routingTable.LookupIPv4(dst)
|
||||||
|
|
||||||
case ipv6.Version:
|
case ipv6.Version:
|
||||||
|
if len(elem.packet) < ipv6.HeaderLen {
|
||||||
|
continue
|
||||||
|
}
|
||||||
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
|
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
|
||||||
peer = device.routingTable.LookupIPv6(dst)
|
peer = device.routingTable.LookupIPv6(dst)
|
||||||
|
|
||||||
|
@ -190,10 +198,15 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check if known endpoint
|
||||||
|
|
||||||
|
peer.mutex.RLock()
|
||||||
if peer.endpoint == nil {
|
if peer.endpoint == nil {
|
||||||
|
peer.mutex.RUnlock()
|
||||||
logDebug.Println("No known endpoint for peer", peer.String())
|
logDebug.Println("No known endpoint for peer", peer.String())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
peer.mutex.RUnlock()
|
||||||
|
|
||||||
// insert into nonce/pre-handshake queue
|
// insert into nonce/pre-handshake queue
|
||||||
|
|
||||||
|
@ -334,8 +347,12 @@ func (device *Device) RoutineEncryption() {
|
||||||
// pad content to MTU size
|
// pad content to MTU size
|
||||||
|
|
||||||
mtu := int(atomic.LoadInt32(&device.mtu))
|
mtu := int(atomic.LoadInt32(&device.mtu))
|
||||||
for i := len(elem.packet); i < mtu; i++ {
|
pad := len(elem.packet) % PaddingMultiple
|
||||||
elem.packet = append(elem.packet, 0)
|
if pad > 0 {
|
||||||
|
for i := 0; i < PaddingMultiple-pad && len(elem.packet) < mtu; i++ {
|
||||||
|
elem.packet = append(elem.packet, 0)
|
||||||
|
}
|
||||||
|
// TODO: How good is this code
|
||||||
}
|
}
|
||||||
|
|
||||||
// encrypt content (append to header)
|
// encrypt content (append to header)
|
||||||
|
@ -390,7 +407,7 @@ func (peer *Peer) RoutineSequentialSender() {
|
||||||
|
|
||||||
// update timers
|
// update timers
|
||||||
|
|
||||||
peer.TimerPacketSent()
|
peer.TimerAnyAuthenticatedPacketTraversal()
|
||||||
if len(elem.packet) != MessageKeepaliveSize {
|
if len(elem.packet) != MessageKeepaliveSize {
|
||||||
peer.TimerDataSent()
|
peer.TimerDataSent()
|
||||||
}
|
}
|
||||||
|
|
|
@ -60,10 +60,8 @@ func (peer *Peer) SendKeepAlive() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Authenticated data packet send
|
/* Event:
|
||||||
* Always called together with peer.EventPacketSend
|
* Sent non-empty (authenticated) transport message
|
||||||
*
|
|
||||||
* - Start new handshake timer
|
|
||||||
*/
|
*/
|
||||||
func (peer *Peer) TimerDataSent() {
|
func (peer *Peer) TimerDataSent() {
|
||||||
timerStop(peer.timer.keepalivePassive)
|
timerStop(peer.timer.keepalivePassive)
|
||||||
|
@ -75,8 +73,6 @@ func (peer *Peer) TimerDataSent() {
|
||||||
|
|
||||||
/* Event:
|
/* Event:
|
||||||
* Received non-empty (authenticated) transport message
|
* Received non-empty (authenticated) transport message
|
||||||
*
|
|
||||||
* - Start passive keep-alive timer
|
|
||||||
*/
|
*/
|
||||||
func (peer *Peer) TimerDataReceived() {
|
func (peer *Peer) TimerDataReceived() {
|
||||||
if peer.timer.pendingKeepalivePassive {
|
if peer.timer.pendingKeepalivePassive {
|
||||||
|
@ -88,17 +84,16 @@ func (peer *Peer) TimerDataReceived() {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Event:
|
/* Event:
|
||||||
* Any (authenticated) transport message received
|
* Any (authenticated) packet received
|
||||||
* (keep-alive or data)
|
|
||||||
*/
|
*/
|
||||||
func (peer *Peer) TimerTransportReceived() {
|
func (peer *Peer) TimerAnyAuthenticatedPacketReceived() {
|
||||||
timerStop(peer.timer.newHandshake)
|
timerStop(peer.timer.newHandshake)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Event:
|
/* Event:
|
||||||
* Any packet send to the peer.
|
* Any authenticated packet send / received.
|
||||||
*/
|
*/
|
||||||
func (peer *Peer) TimerPacketSent() {
|
func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {
|
||||||
interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
|
interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
|
||||||
if interval > 0 {
|
if interval > 0 {
|
||||||
duration := time.Duration(interval) * time.Second
|
duration := time.Duration(interval) * time.Second
|
||||||
|
@ -106,13 +101,6 @@ func (peer *Peer) TimerPacketSent() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Event:
|
|
||||||
* Any authenticated packet received from peer
|
|
||||||
*/
|
|
||||||
func (peer *Peer) TimerPacketReceived() {
|
|
||||||
peer.TimerPacketSent()
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Called after succesfully completing a handshake.
|
/* Called after succesfully completing a handshake.
|
||||||
* i.e. after:
|
* i.e. after:
|
||||||
*
|
*
|
||||||
|
@ -129,7 +117,9 @@ func (peer *Peer) TimerHandshakeComplete() {
|
||||||
peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
|
peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Called whenever an ephemeral key is generated
|
/* Event:
|
||||||
|
* An ephemeral key is generated
|
||||||
|
*
|
||||||
* i.e after:
|
* i.e after:
|
||||||
*
|
*
|
||||||
* CreateMessageInitiation
|
* CreateMessageInitiation
|
||||||
|
@ -257,7 +247,6 @@ func (peer *Peer) RoutineHandshakeInitiator() {
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-peer.signal.handshakeBegin:
|
case <-peer.signal.handshakeBegin:
|
||||||
signalSend(peer.signal.handshakeBegin)
|
|
||||||
case <-peer.signal.stop:
|
case <-peer.signal.stop:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -303,7 +292,6 @@ func (peer *Peer) RoutineHandshakeInitiator() {
|
||||||
binary.Write(writer, binary.LittleEndian, msg)
|
binary.Write(writer, binary.LittleEndian, msg)
|
||||||
packet := writer.Bytes()
|
packet := writer.Bytes()
|
||||||
peer.mac.AddMacs(packet)
|
peer.mac.AddMacs(packet)
|
||||||
peer.TimerPacketSent()
|
|
||||||
|
|
||||||
_, err = peer.SendBuffer(packet)
|
_, err = peer.SendBuffer(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -314,6 +302,8 @@ func (peer *Peer) RoutineHandshakeInitiator() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
peer.TimerAnyAuthenticatedPacketTraversal()
|
||||||
|
|
||||||
// set timeout
|
// set timeout
|
||||||
|
|
||||||
timeout := time.NewTimer(RekeyTimeout)
|
timeout := time.NewTimer(RekeyTimeout)
|
||||||
|
@ -337,7 +327,6 @@ func (peer *Peer) RoutineHandshakeInitiator() {
|
||||||
continue
|
continue
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// allow new signal to be set
|
// allow new signal to be set
|
||||||
|
|
|
@ -32,11 +32,14 @@ type Trie struct {
|
||||||
/* Finds length of matching prefix
|
/* Finds length of matching prefix
|
||||||
* TODO: Make faster
|
* TODO: Make faster
|
||||||
*
|
*
|
||||||
* Assumption: len(ip1) == len(ip2)
|
* Assumption:
|
||||||
|
* len(ip1) == len(ip2)
|
||||||
|
* len(ip1) mod 4 = 0
|
||||||
*/
|
*/
|
||||||
func commonBits(ip1 net.IP, ip2 net.IP) uint {
|
func commonBits(ip1 []byte, ip2 []byte) uint {
|
||||||
var i uint
|
var i uint
|
||||||
size := uint(len(ip1))
|
size := uint(len(ip1)) / 4
|
||||||
|
|
||||||
for i = 0; i < size; i++ {
|
for i = 0; i < size; i++ {
|
||||||
v := ip1[i] ^ ip2[i]
|
v := ip1[i] ^ ip2[i]
|
||||||
if v != 0 {
|
if v != 0 {
|
||||||
|
|
|
@ -9,6 +9,7 @@ const DefaultMTU = 1420
|
||||||
type TUNDevice interface {
|
type TUNDevice interface {
|
||||||
Read([]byte) (int, error) // read a packet from the device (without any additional headers)
|
Read([]byte) (int, error) // read a packet from the device (without any additional headers)
|
||||||
Write([]byte) (int, error) // writes a packet to the device (without any additional headers)
|
Write([]byte) (int, error) // writes a packet to the device (without any additional headers)
|
||||||
|
IsUp() (bool, error) // is the interface up?
|
||||||
MTU() (int, error) // returns the MTU of the device
|
MTU() (int, error) // returns the MTU of the device
|
||||||
Name() string // returns the current name
|
Name() string // returns the current name
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
@ -19,6 +20,11 @@ type NativeTun struct {
|
||||||
name string
|
name string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (tun *NativeTun) IsUp() (bool, error) {
|
||||||
|
inter, err := net.InterfaceByName(tun.name)
|
||||||
|
return inter.Flags&net.FlagUp != 0, err
|
||||||
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) Name() string {
|
func (tun *NativeTun) Name() string {
|
||||||
return tun.name
|
return tun.name
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,13 +11,12 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ipcErrorIO = int64(unix.EIO)
|
ipcErrorIO = -int64(unix.EIO)
|
||||||
ipcErrorNoPeer = int64(unix.EPROTO)
|
ipcErrorNotDefined = -int64(unix.ENODEV)
|
||||||
ipcErrorNoKeyValue = int64(unix.EPROTO)
|
ipcErrorProtocol = -int64(unix.EPROTO)
|
||||||
ipcErrorInvalidKey = int64(unix.EPROTO)
|
ipcErrorInvalid = -int64(unix.EINVAL)
|
||||||
ipcErrorInvalidValue = int64(unix.EPROTO)
|
socketDirectory = "/var/run/wireguard"
|
||||||
socketDirectory = "/var/run/wireguard"
|
socketName = "%s.sock"
|
||||||
socketName = "%s.sock"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
/* TODO:
|
/* TODO:
|
||||||
|
|
Loading…
Reference in a new issue