Number of fixes in response to code review

This version cannot complete a handshake.
The program will panic upon receiving any message on the UDP socket.
This commit is contained in:
Mathias Hall-Andersen 2017-08-07 15:25:04 +02:00
parent 8c34c4cbb3
commit cba1d6585a
12 changed files with 552 additions and 445 deletions

View file

@ -84,13 +84,47 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return nil
}
func updateUDPConn(device *Device) error {
var err error
netc := &device.net
netc.mutex.Lock()
// close existing connection
if netc.conn != nil {
netc.conn.Close()
netc.conn = nil
}
// open new existing connection
conn, err := net.ListenUDP("udp", netc.addr)
if err == nil {
netc.conn = conn
signalSend(device.signal.newUDPConn)
}
netc.mutex.Unlock()
return err
}
func closeUDPConn(device *Device) {
device.net.mutex.Lock()
device.net.conn = nil
device.net.mutex.Unlock()
println("send signal")
signalSend(device.signal.newUDPConn)
}
func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
scanner := bufio.NewScanner(socket)
logInfo := device.log.Info
logError := device.log.Error
logDebug := device.log.Debug
var peer *Peer
dummy := false
deviceConfig := true
for scanner.Scan() {
@ -135,17 +169,11 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
netc := &device.net
netc.mutex.Lock()
if netc.addr.Port != int(port) {
if netc.conn != nil {
netc.conn.Close()
}
netc.addr.Port = int(port)
netc.conn, err = net.ListenUDP("udp", netc.addr)
}
netc.mutex.Unlock()
if err != nil {
logError.Println("Failed to create UDP listener:", err)
return &IPCError{Code: ipcErrorIO}
}
updateUDPConn(device)
// TODO: Clear source address of all peers
case "fwmark":
@ -189,17 +217,30 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
device.mutex.RLock()
if device.publicKey.Equals(pubKey) {
// create dummy instance
peer = &Peer{}
dummy = true
device.mutex.RUnlock()
logError.Println("Public key of peer matches private key of device")
return &IPCError{Code: ipcErrorInvalid}
}
logInfo.Println("Ignoring peer with public key of device")
// find peer referenced
} else {
// find peer referenced
peer, _ = device.peers[pubKey]
device.mutex.RUnlock()
if peer == nil {
peer, err = device.NewPeer(pubKey)
if err != nil {
logError.Println("Failed to create new peer:", err)
return &IPCError{Code: ipcErrorInvalid}
}
}
signalSend(peer.signal.handshakeReset)
dummy = false
peer, _ = device.peers[pubKey]
device.mutex.RUnlock()
if peer == nil {
peer = device.NewPeer(pubKey)
}
case "remove":
@ -207,16 +248,17 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logError.Println("Failed to set remove, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid}
}
device.RemovePeer(peer.handshake.remoteStatic)
logDebug.Println("Removing", peer.String())
peer = nil
if !dummy {
logDebug.Println("Removing", peer.String())
device.RemovePeer(peer.handshake.remoteStatic)
}
peer = &Peer{}
dummy = true
case "preshared_key":
err := func() error {
peer.mutex.Lock()
defer peer.mutex.Unlock()
return peer.handshake.presharedKey.FromHex(value)
}()
peer.mutex.Lock()
err := peer.handshake.presharedKey.FromHex(value)
peer.mutex.Unlock()
if err != nil {
logError.Println("Failed to set preshared_key:", err)
return &IPCError{Code: ipcErrorInvalid}
@ -232,6 +274,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
peer.mutex.Lock()
peer.endpoint = addr
peer.mutex.Unlock()
signalSend(peer.signal.handshakeReset)
case "persistent_keepalive_interval":
@ -251,12 +294,11 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// send immediate keep-alive
if old == 0 && secs != 0 {
up, err := device.tun.IsUp()
if err != nil {
logError.Println("Failed to get tun device status:", err)
return &IPCError{Code: ipcErrorIO}
}
if up {
if atomic.LoadInt32(&device.isUp) == AtomicTrue && !dummy {
peer.SendKeepAlive()
}
}
@ -266,7 +308,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logError.Println("Failed to set replace_allowed_ips, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid}
}
device.routingTable.RemovePeer(peer)
if !dummy {
device.routingTable.RemovePeer(peer)
}
case "allowed_ip":
_, network, err := net.ParseCIDR(value)
@ -275,7 +319,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorInvalid}
}
ones, _ := network.Mask.Size()
device.routingTable.Insert(network.IP, uint(ones), peer)
if !dummy {
device.routingTable.Insert(network.IP, uint(ones), peer)
}
default:
logError.Println("Invalid UAPI key (peer configuration):", key)

View file

@ -7,16 +7,15 @@ import (
/* Specification constants */
const (
RekeyAfterMessages = (1 << 64) - (1 << 16) - 1
RejectAfterMessages = (1 << 64) - (1 << 4) - 1
RekeyAfterTime = time.Second * 120
RekeyAttemptTime = time.Second * 90
RekeyTimeout = time.Second * 5
RejectAfterTime = time.Second * 180
KeepaliveTimeout = time.Second * 10
CookieRefreshTime = time.Second * 120
MaxHandshakeAttemptTime = time.Second * 90
PaddingMultiple = 16
RekeyAfterMessages = (1 << 64) - (1 << 16) - 1
RejectAfterMessages = (1 << 64) - (1 << 4) - 1
RekeyAfterTime = time.Second * 120
RekeyAttemptTime = time.Second * 90
RekeyTimeout = time.Second * 5
RejectAfterTime = time.Second * 180
KeepaliveTimeout = time.Second * 10
CookieRefreshTime = time.Second * 120
PaddingMultiple = 16
)
const (
@ -33,4 +32,5 @@ const (
QueueHandshakeBusySize = QueueHandshakeSize / 8
MinMessageSize = MessageTransportSize // size of keep-alive
MaxMessageSize = ((1 << 16) - 1) + MessageTransportHeaderSize
MaxPeers = 1 << 16
)

View file

@ -7,6 +7,8 @@ import (
/* Daemonizes the process on linux
*
* This is done by spawning and releasing a copy with the --foreground flag
*
* TODO: Use env variable to spawn in background
*/
func Daemonize() error {

View file

@ -1,13 +1,10 @@
package main
import (
"errors"
"fmt"
"net"
"runtime"
"sync"
"sync/atomic"
"time"
)
type Device struct {
@ -34,31 +31,45 @@ type Device struct {
queue struct {
encryption chan *QueueOutboundElement
decryption chan *QueueInboundElement
inbound chan *QueueInboundElement
handshake chan QueueHandshakeElement
}
signal struct {
stop chan struct{}
stop chan struct{} // halts all go routines
newUDPConn chan struct{} // a net.conn was set
}
underLoad int32 // used as an atomic bool
isUp int32 // atomic bool: interface is up
underLoad int32 // atomic bool: device is under load
ratelimiter Ratelimiter
peers map[NoisePublicKey]*Peer
mac MACStateDevice
}
/* Warning:
* The caller must hold the device mutex (write lock)
*/
func removePeerUnsafe(device *Device, key NoisePublicKey) {
peer, ok := device.peers[key]
if !ok {
return
}
peer.mutex.Lock()
device.routingTable.RemovePeer(peer)
delete(device.peers, key)
peer.Close()
}
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
device.mutex.Lock()
defer device.mutex.Unlock()
// check if public key is matching any peer
// remove peers with matching public keys
publicKey := sk.publicKey()
for _, peer := range device.peers {
for key, peer := range device.peers {
h := &peer.handshake
h.mutex.RLock()
if h.remoteStatic.Equals(publicKey) {
h.mutex.RUnlock()
return errors.New("Private key matches public key of peer")
removePeerUnsafe(device, key)
}
h.mutex.RUnlock()
}
@ -71,17 +82,19 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
// do DH precomputations
isZero := device.privateKey.IsZero()
rmKey := device.privateKey.IsZero()
for _, peer := range device.peers {
for key, peer := range device.peers {
h := &peer.handshake
h.mutex.Lock()
if isZero {
if rmKey {
h.precomputedStaticStatic = [NoisePublicKeySize]byte{}
} else {
h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
if isZero(h.precomputedStaticStatic[:]) {
removePeerUnsafe(device, key)
}
}
fmt.Println(h.precomputedStaticStatic)
h.mutex.Unlock()
}
@ -130,11 +143,11 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
device.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
// prepare signals
device.signal.stop = make(chan struct{})
device.signal.newUDPConn = make(chan struct{}, 1)
// start workers
@ -145,33 +158,42 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
}
go device.RoutineBusyMonitor()
go device.RoutineMTUUpdater()
go device.RoutineWriteToTUN()
go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader()
go device.RoutineReceiveIncomming()
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
return device
}
func (device *Device) RoutineMTUUpdater() {
func (device *Device) RoutineTUNEventReader() {
events := device.tun.Events()
logError := device.log.Error
for ; ; time.Sleep(5 * time.Second) {
// load updated MTU
mtu, err := device.tun.MTU()
if err != nil {
logError.Println("Failed to load updated MTU of device:", err)
continue
for event := range events {
if event&TUNEventMTUUpdate != 0 {
mtu, err := device.tun.MTU()
if err != nil {
logError.Println("Failed to load updated MTU of device:", err)
} else {
if mtu+MessageTransportSize > MaxMessageSize {
mtu = MaxMessageSize - MessageTransportSize
}
atomic.StoreInt32(&device.mtu, int32(mtu))
}
}
// upper bound of mtu
if mtu+MessageTransportSize > MaxMessageSize {
mtu = MaxMessageSize - MessageTransportSize
if event&TUNEventUp != 0 {
println("handle 1")
atomic.StoreInt32(&device.isUp, AtomicTrue)
updateUDPConn(device)
println("handle 2", device.net.conn)
}
if event&TUNEventDown != 0 {
atomic.StoreInt32(&device.isUp, AtomicFalse)
closeUDPConn(device)
}
atomic.StoreInt32(&device.mtu, int32(mtu))
}
}
@ -184,15 +206,7 @@ func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
func (device *Device) RemovePeer(key NoisePublicKey) {
device.mutex.Lock()
defer device.mutex.Unlock()
peer, ok := device.peers[key]
if !ok {
return
}
peer.mutex.Lock()
device.routingTable.RemovePeer(peer)
delete(device.peers, key)
peer.Close()
removePeerUnsafe(device, key)
}
func (device *Device) RemoveAllPeers() {

View file

@ -18,12 +18,13 @@ type MACStateDevice struct {
}
type MACStatePeer struct {
mutex sync.RWMutex
cookieSet time.Time
cookie [blake2s.Size128]byte
lastMAC1 [blake2s.Size128]byte // TODO: Check if set
keyMAC1 [blake2s.Size]byte
keyMAC2 [blake2s.Size]byte
mutex sync.RWMutex
cookieSet time.Time
cookie [blake2s.Size128]byte
lastMAC1Set bool
lastMAC1 [blake2s.Size128]byte
keyMAC1 [blake2s.Size]byte
keyMAC2 [blake2s.Size]byte
}
/* Methods for verifing MAC fields
@ -184,6 +185,10 @@ func (device *Device) ConsumeMessageCookieReply(msg *MessageCookieReply) bool {
state.mutex.Lock()
defer state.mutex.Unlock()
if !state.lastMAC1Set {
return false
}
_, err := XChaCha20Poly1305Decrypt(
cookie[:0],
&msg.Nonce,
@ -246,7 +251,7 @@ func (state *MACStatePeer) AddMacs(msg []byte) {
mac.Sum(mac1[:0])
}()
copy(state.lastMAC1[:], mac1)
// TODO: Set lastMac flag
state.lastMAC1Set = true
// set mac2

View file

@ -9,16 +9,14 @@ import (
"time"
)
const ()
type Peer struct {
id uint
mutex sync.RWMutex
endpoint *net.UDPAddr
persistentKeepaliveInterval uint64
keyPairs KeyPairs
handshake Handshake
device *Device
endpoint *net.UDPAddr
stats struct {
txBytes uint64 // bytes send to peer (endpoint)
rxBytes uint64 // bytes received from peer
@ -34,6 +32,7 @@ type Peer struct {
newKeyPair chan struct{} // (size 1) : a new key pair was generated
handshakeBegin chan struct{} // (size 1) : request that a new handshake be started ("queue handshake")
handshakeCompleted chan struct{} // (size 1) : handshake completed
handshakeReset chan struct{} // (size 1) : reset handshake negotiation state
flushNonceQueue chan struct{} // (size 1) : empty queued packets
messageSend chan struct{} // (size 1) : a message was send to the peer
messageReceived chan struct{} // (size 1) : an authenticated message was received
@ -44,6 +43,7 @@ type Peer struct {
keepalivePassive *time.Timer // set upon recieving messages
newHandshake *time.Timer // begin a new handshake (after Keepalive + RekeyTimeout)
zeroAllKeys *time.Timer // zero all key material (after RejectAfterTime*3)
handshakeDeadline *time.Timer // Current handshake must be completed
pendingKeepalivePassive bool
pendingNewHandshake bool
@ -59,7 +59,7 @@ type Peer struct {
mac MACStatePeer
}
func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// create peer
peer := new(Peer)
@ -80,11 +80,17 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
peer.id = device.idCounter
device.idCounter += 1
// check if over limit
if len(device.peers) >= MaxPeers {
return nil, errors.New("Too many peers")
}
// map public key
_, ok := device.peers[pk]
if ok {
panic(errors.New("bug: adding existing peer"))
return nil, errors.New("Adding existing peer")
}
device.peers[pk] = peer
device.mutex.Unlock()
@ -108,6 +114,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
peer.signal.stop = make(chan struct{})
peer.signal.newKeyPair = make(chan struct{}, 1)
peer.signal.handshakeBegin = make(chan struct{}, 1)
peer.signal.handshakeReset = make(chan struct{}, 1)
peer.signal.handshakeCompleted = make(chan struct{}, 1)
peer.signal.flushNonceQueue = make(chan struct{}, 1)
@ -117,7 +124,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
go peer.RoutineSequentialSender()
go peer.RoutineSequentialReceiver()
return peer
return peer, nil
}
func (peer *Peer) String() string {

View file

@ -111,113 +111,84 @@ func (device *Device) RoutineBusyMonitor() {
func (device *Device) RoutineReceiveIncomming() {
logInfo := device.log.Info
logDebug := device.log.Debug
logDebug.Println("Routine, receive incomming, started")
var buffer *[MaxMessageSize]byte
for {
// check if stopped
// wait for new conn
var conn *net.UDPConn
select {
case <-device.signal.newUDPConn:
device.net.mutex.RLock()
conn = device.net.conn
device.net.mutex.RUnlock()
case <-device.signal.stop:
return
default:
}
// read next datagram
if buffer == nil {
buffer = device.GetMessageBuffer()
}
// TODO: Take writelock to sleep
device.net.mutex.RLock()
conn := device.net.conn
device.net.mutex.RUnlock()
if conn == nil {
time.Sleep(time.Second)
continue
}
// TODO: Wait for new conn or message
conn.SetReadDeadline(time.Now().Add(time.Second))
// receive datagrams until closed
size, raddr, err := conn.ReadFromUDP(buffer[:])
if err != nil || size < MinMessageSize {
continue
}
buffer := device.GetMessageBuffer()
// handle packet
for {
packet := buffer[:size]
msgType := binary.LittleEndian.Uint32(packet[:4])
// read next datagram
size, raddr, err := conn.ReadFromUDP(buffer[:]) // TODO: This is broken
if err != nil {
break
}
if size < MinMessageSize {
continue
}
// check size of packet
packet := buffer[:size]
msgType := binary.LittleEndian.Uint32(packet[:4])
var okay bool
func() {
switch msgType {
case MessageInitiationType, MessageResponseType:
// TODO: Check size early
// add to handshake queue
device.addToHandshakeQueue(
device.queue.handshake,
QueueHandshakeElement{
msgType: msgType,
buffer: buffer,
packet: packet,
source: raddr,
},
)
buffer = nil
case MessageCookieReplyType:
// TODO: Queue all the things
// verify and update peer cookie state
if len(packet) != MessageCookieReplySize {
return
}
var reply MessageCookieReply
reader := bytes.NewReader(packet)
err := binary.Read(reader, binary.LittleEndian, &reply)
if err != nil {
logDebug.Println("Failed to decode cookie reply")
return
}
device.ConsumeMessageCookieReply(&reply)
// check if transport
case MessageTransportType:
// lookup key pair
// check size
if len(packet) < MessageTransportSize {
return
if len(packet) < MessageTransportType {
continue
}
// lookup key pair
receiver := binary.LittleEndian.Uint32(
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
)
value := device.indices.Lookup(receiver)
keyPair := value.keyPair
if keyPair == nil {
return
continue
}
// check key-pair expiry
if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
return
continue
}
// add to peer queue
// create work element
peer := value.peer
elem := &QueueInboundElement{
@ -233,11 +204,33 @@ func (device *Device) RoutineReceiveIncomming() {
device.addToInboundQueue(device.queue.decryption, elem)
device.addToInboundQueue(peer.queue.inbound, elem)
buffer = nil
continue
default:
logInfo.Println("Got unknown message from:", raddr)
// otherwise it is a handshake related packet
case MessageInitiationType:
okay = len(packet) == MessageInitiationSize
case MessageResponseType:
okay = len(packet) == MessageResponseSize
case MessageCookieReplyType:
okay = len(packet) == MessageCookieReplySize
}
}()
if okay {
device.addToHandshakeQueue(
device.queue.handshake,
QueueHandshakeElement{
msgType: msgType,
buffer: buffer,
packet: packet,
source: raddr,
},
)
buffer = device.GetMessageBuffer()
}
}
}
}
@ -306,154 +299,165 @@ func (device *Device) RoutineHandshake() {
return
}
func() {
// handle cookie fields and ratelimiting
// verify mac1
switch elem.msgType {
case MessageCookieReplyType:
// verify and update peer cookie state
var reply MessageCookieReply
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &reply)
if err != nil {
logDebug.Println("Failed to decode cookie reply")
return
}
device.ConsumeMessageCookieReply(&reply)
continue
case MessageInitiationType, MessageResponseType:
// check mac fields and ratelimit
if !device.mac.CheckMAC1(elem.packet) {
logDebug.Println("Received packet with invalid mac1")
return
}
// verify mac2
busy := atomic.LoadInt32(&device.underLoad) == AtomicTrue
if busy && !device.mac.CheckMAC2(elem.packet, elem.source) {
sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
reply, err := device.CreateMessageCookieReply(elem.packet, sender, elem.source)
if err != nil {
logError.Println("Failed to create cookie reply:", err)
return
}
// TODO: Use temp
writer := bytes.NewBuffer(elem.packet[:0])
binary.Write(writer, binary.LittleEndian, reply)
elem.packet = writer.Bytes()
_, err = device.net.conn.WriteToUDP(elem.packet, elem.source)
if err != nil {
logDebug.Println("Failed to send cookie reply:", err)
}
return
}
// ratelimit
// TODO: Only ratelimit when busy
if !device.ratelimiter.Allow(elem.source.IP) {
return
}
// handle messages
switch elem.msgType {
case MessageInitiationType:
// unmarshal
if len(elem.packet) != MessageInitiationSize {
return
}
var msg MessageInitiation
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &msg)
if err != nil {
logError.Println("Failed to decode initiation message")
return
}
// consume initiation
peer := device.ConsumeMessageInitiation(&msg)
if peer == nil {
logInfo.Println(
"Recieved invalid initiation message from",
elem.source.IP.String(),
elem.source.Port,
if busy {
if !device.mac.CheckMAC2(elem.packet, elem.source) {
sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
reply, err := device.CreateMessageCookieReply(elem.packet, sender, elem.source)
if err != nil {
logError.Println("Failed to create cookie reply:", err)
return
}
writer := bytes.NewBuffer(temp[:0])
binary.Write(writer, binary.LittleEndian, reply)
_, err = device.net.conn.WriteToUDP(
writer.Bytes(),
elem.source,
)
return
if err != nil {
logDebug.Println("Failed to send cookie reply:", err)
}
continue
}
// update timers
peer.TimerAnyAuthenticatedPacketTraversal()
peer.TimerAnyAuthenticatedPacketReceived()
// update endpoint
// TODO: Add a race condition \s
peer.mutex.Lock()
peer.endpoint = elem.source
peer.mutex.Unlock()
// create response
response, err := device.CreateMessageResponse(peer)
if err != nil {
logError.Println("Failed to create response message:", err)
return
if !device.ratelimiter.Allow(elem.source.IP) {
continue
}
peer.TimerEphemeralKeyCreated()
peer.NewKeyPair()
logDebug.Println("Creating response message for", peer.String())
writer := bytes.NewBuffer(temp[:0])
binary.Write(writer, binary.LittleEndian, response)
packet := writer.Bytes()
peer.mac.AddMacs(packet)
// send response
peer.SendBuffer(packet)
peer.TimerAnyAuthenticatedPacketTraversal()
case MessageResponseType:
// unmarshal
if len(elem.packet) != MessageResponseSize {
return
}
var msg MessageResponse
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &msg)
if err != nil {
logError.Println("Failed to decode response message")
return
}
// consume response
peer := device.ConsumeMessageResponse(&msg)
if peer == nil {
logInfo.Println(
"Recieved invalid response message from",
elem.source.IP.String(),
elem.source.Port,
)
return
}
// update timers
peer.TimerAnyAuthenticatedPacketTraversal()
peer.TimerAnyAuthenticatedPacketReceived()
peer.TimerHandshakeComplete()
// derive key-pair
peer.NewKeyPair()
peer.SendKeepAlive()
default:
logError.Println("Invalid message type in handshake queue")
}
}()
default:
logError.Println("Invalid packet ended up in the handshake queue")
continue
}
// handle handshake initation/response content
switch elem.msgType {
case MessageInitiationType:
// unmarshal
var msg MessageInitiation
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &msg)
if err != nil {
logError.Println("Failed to decode initiation message")
continue
}
// consume initiation
peer := device.ConsumeMessageInitiation(&msg)
if peer == nil {
logInfo.Println(
"Recieved invalid initiation message from",
elem.source.IP.String(),
elem.source.Port,
)
continue
}
// update timers
peer.TimerAnyAuthenticatedPacketTraversal()
peer.TimerAnyAuthenticatedPacketReceived()
// update endpoint
// TODO: Discover destination address also, only update on change
peer.mutex.Lock()
peer.endpoint = elem.source
peer.mutex.Unlock()
// create response
response, err := device.CreateMessageResponse(peer)
if err != nil {
logError.Println("Failed to create response message:", err)
continue
}
peer.TimerEphemeralKeyCreated()
peer.NewKeyPair()
logDebug.Println("Creating response message for", peer.String())
writer := bytes.NewBuffer(temp[:0])
binary.Write(writer, binary.LittleEndian, response)
packet := writer.Bytes()
peer.mac.AddMacs(packet)
// send response
_, err = peer.SendBuffer(packet)
if err == nil {
peer.TimerAnyAuthenticatedPacketTraversal()
}
case MessageResponseType:
// unmarshal
var msg MessageResponse
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &msg)
if err != nil {
logError.Println("Failed to decode response message")
continue
}
// consume response
peer := device.ConsumeMessageResponse(&msg)
if peer == nil {
logInfo.Println(
"Recieved invalid response message from",
elem.source.IP.String(),
elem.source.Port,
)
continue
}
peer.TimerEphemeralKeyCreated()
// update timers
peer.TimerAnyAuthenticatedPacketTraversal()
peer.TimerAnyAuthenticatedPacketReceived()
peer.TimerHandshakeComplete()
// derive key-pair
peer.NewKeyPair()
peer.SendKeepAlive()
}
}
}
@ -463,6 +467,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
device := peer.device
logInfo := device.log.Info
logError := device.log.Error
logDebug := device.log.Debug
logDebug.Println("Routine, sequential receiver, started for peer", peer.id)
@ -478,116 +483,104 @@ func (peer *Peer) RoutineSequentialReceiver() {
// process packet
func() {
if elem.IsDropped() {
return
if elem.IsDropped() {
continue
}
// check for replay
if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) {
continue
}
peer.TimerAnyAuthenticatedPacketTraversal()
peer.TimerAnyAuthenticatedPacketReceived()
peer.KeepKeyFreshReceiving()
// check if using new key-pair
kp := &peer.keyPairs
kp.mutex.Lock()
if kp.next == elem.keyPair {
peer.TimerHandshakeComplete()
kp.previous = kp.current
kp.current = kp.next
kp.next = nil
}
kp.mutex.Unlock()
// check for keep-alive
if len(elem.packet) == 0 {
logDebug.Println("Received keep-alive from", peer.String())
continue
}
peer.TimerDataReceived()
// verify source and strip padding
switch elem.packet[0] >> 4 {
case ipv4.Version:
// strip padding
if len(elem.packet) < ipv4.HeaderLen {
continue
}
// check for replay
if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) {
return
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field)
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
continue
}
peer.TimerAnyAuthenticatedPacketTraversal()
peer.TimerAnyAuthenticatedPacketReceived()
peer.KeepKeyFreshReceiving()
elem.packet = elem.packet[:length]
// check if using new key-pair
// verify IPv4 source
kp := &peer.keyPairs
kp.mutex.Lock()
if kp.next == elem.keyPair {
peer.TimerHandshakeComplete()
kp.previous = kp.current
kp.current = kp.next
kp.next = nil
}
kp.mutex.Unlock()
// check for keep-alive
if len(elem.packet) == 0 {
logDebug.Println("Received keep-alive from", peer.String())
return
}
peer.TimerDataReceived()
// verify source and strip padding
switch elem.packet[0] >> 4 {
case ipv4.Version:
// strip padding
if len(elem.packet) < ipv4.HeaderLen {
return
}
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field)
// TODO: check length of packet & NOT TOO SMALL either
elem.packet = elem.packet[:length]
// verify IPv4 source
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.routingTable.LookupIPv4(src) != peer {
logInfo.Println("Packet with unallowed source IP from", peer.String())
return
}
case ipv6.Version:
// strip padding
if len(elem.packet) < ipv6.HeaderLen {
return
}
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field)
length += ipv6.HeaderLen
// TODO: check length of packet
elem.packet = elem.packet[:length]
// verify IPv6 source
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.routingTable.LookupIPv6(src) != peer {
logInfo.Println("Packet with unallowed source IP from", peer.String())
return
}
default:
logInfo.Println("Packet with invalid IP version from", peer.String())
return
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.routingTable.LookupIPv4(src) != peer {
logInfo.Println("Packet with unallowed source IP from", peer.String())
continue
}
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
device.addToInboundQueue(device.queue.inbound, elem)
case ipv6.Version:
// TODO: move TUN write into per peer routine
}()
}
}
// strip padding
func (device *Device) RoutineWriteToTUN() {
logError := device.log.Error
logDebug := device.log.Debug
logDebug.Println("Routine, sequential tun writer, started")
for {
select {
case <-device.signal.stop:
return
case elem := <-device.queue.inbound:
_, err := device.tun.Write(elem.packet)
device.PutMessageBuffer(elem.buffer)
if err != nil {
logError.Println("Failed to write packet to TUN device:", err)
if len(elem.packet) < ipv6.HeaderLen {
continue
}
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field)
length += ipv6.HeaderLen
if int(length) > len(elem.packet) {
continue
}
elem.packet = elem.packet[:length]
// verify IPv6 source
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.routingTable.LookupIPv6(src) != peer {
logInfo.Println("Packet with unallowed source IP from", peer.String())
continue
}
default:
logInfo.Println("Packet with invalid IP version from", peer.String())
continue
}
// write to tun
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
_, err := device.tun.Write(elem.packet)
device.PutMessageBuffer(elem.buffer)
if err != nil {
logError.Println("Failed to write packet to TUN device:", err)
}
}
}

View file

@ -168,8 +168,6 @@ func (device *Device) RoutineReadFromTUN() {
continue
}
println(size, err)
elem.packet = elem.packet[:size]
// lookup peer
@ -210,6 +208,7 @@ func (device *Device) RoutineReadFromTUN() {
// insert into nonce/pre-handshake queue
signalSend(peer.signal.handshakeReset)
addToOutboundQueue(peer.queue.nonce, elem)
elem = nil

View file

@ -4,6 +4,7 @@ import (
"bytes"
"encoding/binary"
"golang.org/x/crypto/blake2s"
"math/rand"
"sync/atomic"
"time"
)
@ -16,12 +17,11 @@ func (peer *Peer) KeepKeyFreshSending() {
if kp == nil {
return
}
if !kp.isInitiator {
return
}
nonce := atomic.LoadUint64(&kp.sendNonce)
send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTime
if send {
if nonce > RekeyAfterMessages {
signalSend(peer.signal.handshakeBegin)
}
if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime {
signalSend(peer.signal.handshakeBegin)
}
}
@ -30,6 +30,7 @@ func (peer *Peer) KeepKeyFreshSending() {
*
*/
func (peer *Peer) KeepKeyFreshReceiving() {
// TODO: Add a guard, clear on handshake complete (clear in TimerHandshakeComplete)
kp := peer.keyPairs.Current()
if kp == nil {
return
@ -108,7 +109,6 @@ func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {
* - First transport message under the "next" key
*/
func (peer *Peer) TimerHandshakeComplete() {
timerStop(peer.timer.zeroAllKeys)
atomic.StoreInt64(
&peer.stats.lastHandshakeNano,
time.Now().UnixNano(),
@ -129,10 +129,7 @@ func (peer *Peer) TimerHandshakeComplete() {
* upon failure to complete a handshake
*/
func (peer *Peer) TimerEphemeralKeyCreated() {
if !peer.timer.pendingZeroAllKeys {
peer.timer.pendingZeroAllKeys = true
peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
}
peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
}
func (peer *Peer) RoutineTimerHandler() {
@ -154,19 +151,19 @@ func (peer *Peer) RoutineTimerHandler() {
interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
if interval > 0 {
logDebug.Println("Sending persistent keep-alive to", peer.String())
logDebug.Println("Sending keep-alive to", peer.String())
peer.SendKeepAlive()
}
case <-peer.timer.keepalivePassive.C:
logDebug.Println("Sending passive keep-alive to", peer.String())
logDebug.Println("Sending keep-alive to", peer.String())
peer.SendKeepAlive()
if peer.timer.needAnotherKeepalive {
peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
peer.timer.needAnotherKeepalive = true
peer.timer.needAnotherKeepalive = false
}
// unresponsive session
@ -189,8 +186,6 @@ func (peer *Peer) RoutineTimerHandler() {
kp := &peer.keyPairs
kp.mutex.Lock()
peer.timer.pendingZeroAllKeys = false
// unmap indecies
indices.mutex.Lock()
@ -251,40 +246,41 @@ func (peer *Peer) RoutineHandshakeInitiator() {
return
}
// wait for handshake
// set deadline
deadline := time.Now().Add(MaxHandshakeAttemptTime)
BeginHandshakes:
signalClear(peer.signal.handshakeReset)
deadline := time.NewTimer(RekeyAttemptTime)
AttemptHandshakes:
Loop:
for attempts := uint(1); ; attempts++ {
// clear completed signal
// check if deadline reached
select {
case <-peer.signal.handshakeCompleted:
case <-deadline.C:
logInfo.Println("Handshake negotiation timed out for:", peer.String())
signalSend(peer.signal.flushNonceQueue)
timerStop(peer.timer.keepalivePersistent)
break
case <-peer.signal.stop:
return
default:
}
// check if sufficient time for retry
if deadline.Before(time.Now().Add(RekeyTimeout)) {
logInfo.Println("Handshake negotiation timed out for", peer.String())
signalSend(peer.signal.flushNonceQueue)
timerStop(peer.timer.keepalivePersistent)
timerStop(peer.timer.keepalivePassive)
break Loop
}
signalClear(peer.signal.handshakeCompleted)
// create initiation message
msg, err := peer.device.CreateMessageInitiation(peer)
if err != nil {
logError.Println("Failed to create handshake initiation message:", err)
break Loop
break AttemptHandshakes
}
peer.TimerEphemeralKeyCreated()
jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
// marshal and send
@ -299,14 +295,14 @@ func (peer *Peer) RoutineHandshakeInitiator() {
"Failed to send handshake initiation message to",
peer.String(), ":", err,
)
continue
break
}
peer.TimerAnyAuthenticatedPacketTraversal()
// set timeout
// set handshake timeout
timeout := time.NewTimer(RekeyTimeout)
timeout := time.NewTimer(RekeyTimeout + jitter)
logDebug.Println(
"Handshake initiation attempt",
attempts, "sent to", peer.String(),
@ -321,15 +317,19 @@ func (peer *Peer) RoutineHandshakeInitiator() {
case <-peer.signal.handshakeCompleted:
<-timeout.C
break Loop
break AttemptHandshakes
case <-peer.signal.handshakeReset:
<-timeout.C
goto BeginHandshakes
case <-timeout.C:
// TODO: Clear source address for peer
continue
}
}
// allow new signal to be set
// clear signal set in the meantime
signalClear(peer.signal.handshakeBegin)
}

View file

@ -6,10 +6,19 @@ package main
const DefaultMTU = 1420
type TUNEvent int
const (
TUNEventUp = 1 << iota
TUNEventDown
TUNEventMTUUpdate
)
type TUNDevice interface {
Read([]byte) (int, error) // read a packet from the device (without any additional headers)
Write([]byte) (int, error) // writes a packet to the device (without any additional headers)
IsUp() (bool, error) // is the interface up?
MTU() (int, error) // returns the MTU of the device
Name() string // returns the current name
Events() chan TUNEvent // returns a constant channel of events related to the device
Close() error // stops the device and closes the event channel
}

View file

@ -16,11 +16,12 @@ import (
const CloneDevicePath = "/dev/net/tun"
type NativeTun struct {
fd *os.File
name string
fd *os.File
name string
events chan TUNEvent
}
func (tun *NativeTun) IsUp() (bool, error) {
func (tun *NativeTun) isUp() (bool, error) {
inter, err := net.InterfaceByName(tun.name)
return inter.Flags&net.FlagUp != 0, err
}
@ -111,6 +112,14 @@ func (tun *NativeTun) Read(d []byte) (int, error) {
return tun.fd.Read(d)
}
func (tun *NativeTun) Events() chan TUNEvent {
return tun.events
}
func (tun *NativeTun) Close() error {
return nil
}
func CreateTUN(name string) (TUNDevice, error) {
// open clone device
@ -146,10 +155,14 @@ func CreateTUN(name string) (TUNDevice, error) {
newName := string(ifr[:])
newName = newName[:strings.Index(newName, "\000")]
device := &NativeTun{
fd: fd,
name: newName,
fd: fd,
name: newName,
events: make(chan TUNEvent, 5),
}
// TODO: Wait for device to be upped
device.events <- TUNEventUp
// set default MTU
err = device.setMTU(DefaultMTU)

View file

@ -7,7 +7,6 @@ import (
"net"
"os"
"path"
"time"
)
const (
@ -26,9 +25,10 @@ const (
*/
type UAPIListener struct {
listener net.Listener // unix socket listener
connNew chan net.Conn
connErr chan error
listener net.Listener // unix socket listener
connNew chan net.Conn
connErr chan error
inotifyFd int
}
func (l *UAPIListener) Accept() (net.Conn, error) {
@ -106,9 +106,28 @@ func NewUAPIListener(name string) (net.Listener, error) {
// watch for deletion of socket
uapi.inotifyFd, err = unix.InotifyInit()
if err != nil {
return nil, err
}
_, err = unix.InotifyAddWatch(
uapi.inotifyFd,
socketPath,
unix.IN_ATTRIB|
unix.IN_DELETE|
unix.IN_DELETE_SELF,
)
if err != nil {
return nil, err
}
go func(l *UAPIListener) {
for ; ; time.Sleep(time.Second) {
if _, err := os.Stat(socketPath); os.IsNotExist(err) {
var buff [4096]byte
for {
unix.Read(uapi.inotifyFd, buff[:])
if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
l.connErr <- err
return
}