Terminate on interface deletion
Program now terminates when the interface is removed Increases the number of os threads (relevant for Go <1.5, not tested) More consistent commenting Improved logging (additional peer information)
This commit is contained in:
parent
8393cbff52
commit
93e3848ea7
|
@ -29,6 +29,6 @@ const (
|
||||||
QueueInboundSize = 1024
|
QueueInboundSize = 1024
|
||||||
QueueHandshakeSize = 1024
|
QueueHandshakeSize = 1024
|
||||||
QueueHandshakeBusySize = QueueHandshakeSize / 8
|
QueueHandshakeBusySize = QueueHandshakeSize / 8
|
||||||
MinMessageSize = MessageTransportSize // keep-alive
|
MinMessageSize = MessageTransportSize // size of keep-alive
|
||||||
MaxMessageSize = 4096 // TODO: make depend on the MTU?
|
MaxMessageSize = (1 << 16) - 1
|
||||||
)
|
)
|
||||||
|
|
|
@ -98,9 +98,9 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
|
||||||
}
|
}
|
||||||
|
|
||||||
go device.RoutineBusyMonitor()
|
go device.RoutineBusyMonitor()
|
||||||
|
go device.RoutineWriteToTUN(tun)
|
||||||
go device.RoutineReadFromTUN(tun)
|
go device.RoutineReadFromTUN(tun)
|
||||||
go device.RoutineReceiveIncomming()
|
go device.RoutineReceiveIncomming()
|
||||||
go device.RoutineWriteToTUN(tun)
|
|
||||||
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
|
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
|
||||||
|
|
||||||
return device
|
return device
|
||||||
|
@ -141,5 +141,8 @@ func (device *Device) RemoveAllPeers() {
|
||||||
func (device *Device) Close() {
|
func (device *Device) Close() {
|
||||||
device.RemoveAllPeers()
|
device.RemoveAllPeers()
|
||||||
close(device.signal.stop)
|
close(device.signal.stop)
|
||||||
close(device.queue.encryption)
|
}
|
||||||
|
|
||||||
|
func (device *Device) Wait() {
|
||||||
|
<-device.signal.stop
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,17 +5,13 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
IPv4version = 4
|
|
||||||
IPv4offsetTotalLength = 2
|
IPv4offsetTotalLength = 2
|
||||||
IPv4offsetSrc = 12
|
IPv4offsetSrc = 12
|
||||||
IPv4offsetDst = IPv4offsetSrc + net.IPv4len
|
IPv4offsetDst = IPv4offsetSrc + net.IPv4len
|
||||||
IPv4headerSize = 20
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
IPv6version = 6
|
|
||||||
IPv6offsetPayloadLength = 4
|
IPv6offsetPayloadLength = 4
|
||||||
IPv6offsetSrc = 8
|
IPv6offsetSrc = 8
|
||||||
IPv6offsetDst = IPv6offsetSrc + net.IPv6len
|
IPv6offsetDst = IPv6offsetSrc + net.IPv6len
|
||||||
IPv6headerSize = 40
|
|
||||||
)
|
)
|
||||||
|
|
31
src/main.go
31
src/main.go
|
@ -5,6 +5,7 @@ import (
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
)
|
)
|
||||||
|
|
||||||
/* TODO: Fix logging
|
/* TODO: Fix logging
|
||||||
|
@ -18,6 +19,10 @@ func main() {
|
||||||
}
|
}
|
||||||
deviceName := os.Args[1]
|
deviceName := os.Args[1]
|
||||||
|
|
||||||
|
// increase number of go workers (for Go <1.5)
|
||||||
|
|
||||||
|
runtime.GOMAXPROCS(runtime.NumCPU())
|
||||||
|
|
||||||
// open TUN device
|
// open TUN device
|
||||||
|
|
||||||
tun, err := CreateTUN(deviceName)
|
tun, err := CreateTUN(deviceName)
|
||||||
|
@ -31,17 +36,21 @@ func main() {
|
||||||
|
|
||||||
// start configuration lister
|
// start configuration lister
|
||||||
|
|
||||||
socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", deviceName)
|
go func() {
|
||||||
l, err := net.Listen("unix", socketPath)
|
socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", deviceName)
|
||||||
if err != nil {
|
l, err := net.Listen("unix", socketPath)
|
||||||
log.Fatal("listen error:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
conn, err := l.Accept()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("accept error:", err)
|
log.Fatal("listen error:", err)
|
||||||
}
|
}
|
||||||
go ipcHandle(device, conn)
|
|
||||||
}
|
for {
|
||||||
|
conn, err := l.Accept()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("accept error:", err)
|
||||||
|
}
|
||||||
|
go ipcHandle(device, conn)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
device.Wait()
|
||||||
}
|
}
|
||||||
|
|
19
src/peer.go
19
src/peer.go
|
@ -1,7 +1,9 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -38,9 +40,9 @@ type Peer struct {
|
||||||
/* Both keep-alive timers acts as one (see timers.go)
|
/* Both keep-alive timers acts as one (see timers.go)
|
||||||
* They are kept seperate to simplify the implementation.
|
* They are kept seperate to simplify the implementation.
|
||||||
*/
|
*/
|
||||||
keepalivePersistent *time.Timer // set for persistent keepalives
|
keepalivePersistent *time.Timer // set for persistent keepalives
|
||||||
keepaliveAcknowledgement *time.Timer // set upon recieving messages
|
keepalivePassive *time.Timer // set upon recieving messages
|
||||||
zeroAllKeys *time.Timer // zero all key material after RejectAfterTime*3
|
zeroAllKeys *time.Timer // zero all key material after RejectAfterTime*3
|
||||||
}
|
}
|
||||||
queue struct {
|
queue struct {
|
||||||
nonce chan *QueueOutboundElement // nonce / pre-handshake queue
|
nonce chan *QueueOutboundElement // nonce / pre-handshake queue
|
||||||
|
@ -63,8 +65,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
|
||||||
peer.mac.Init(pk)
|
peer.mac.Init(pk)
|
||||||
peer.device = device
|
peer.device = device
|
||||||
|
|
||||||
|
peer.timer.keepalivePassive = NewStoppedTimer()
|
||||||
peer.timer.keepalivePersistent = NewStoppedTimer()
|
peer.timer.keepalivePersistent = NewStoppedTimer()
|
||||||
peer.timer.keepaliveAcknowledgement = NewStoppedTimer()
|
|
||||||
peer.timer.zeroAllKeys = NewStoppedTimer()
|
peer.timer.zeroAllKeys = NewStoppedTimer()
|
||||||
|
|
||||||
peer.flags.keepaliveWaiting = AtomicFalse
|
peer.flags.keepaliveWaiting = AtomicFalse
|
||||||
|
@ -115,6 +117,15 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
|
||||||
return peer
|
return peer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) String() string {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"peer(%d %s %s)",
|
||||||
|
peer.id,
|
||||||
|
peer.endpoint.String(),
|
||||||
|
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
func (peer *Peer) Close() {
|
func (peer *Peer) Close() {
|
||||||
close(peer.signal.stop)
|
close(peer.signal.stop)
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,8 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
@ -362,7 +364,7 @@ func (device *Device) RoutineHandshake() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logDebug.Println("Creating response...")
|
logDebug.Println("Creating response message for", peer.String())
|
||||||
|
|
||||||
outElem := device.NewOutboundElement()
|
outElem := device.NewOutboundElement()
|
||||||
writer := bytes.NewBuffer(outElem.data[:0])
|
writer := bytes.NewBuffer(outElem.data[:0])
|
||||||
|
@ -416,6 +418,8 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||||
var elem *QueueInboundElement
|
var elem *QueueInboundElement
|
||||||
|
|
||||||
device := peer.device
|
device := peer.device
|
||||||
|
|
||||||
|
logInfo := device.log.Info
|
||||||
logDebug := device.log.Debug
|
logDebug := device.log.Debug
|
||||||
logDebug.Println("Routine, sequential receiver, started for peer", peer.id)
|
logDebug.Println("Routine, sequential receiver, started for peer", peer.id)
|
||||||
|
|
||||||
|
@ -450,7 +454,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||||
|
|
||||||
peer.KeepKeyFreshReceiving()
|
peer.KeepKeyFreshReceiving()
|
||||||
|
|
||||||
// check if confirming handshake
|
// check if using new key-pair
|
||||||
|
|
||||||
kp := &peer.keyPairs
|
kp := &peer.keyPairs
|
||||||
kp.mutex.Lock()
|
kp.mutex.Lock()
|
||||||
|
@ -465,17 +469,18 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||||
// check for keep-alive
|
// check for keep-alive
|
||||||
|
|
||||||
if len(elem.packet) == 0 {
|
if len(elem.packet) == 0 {
|
||||||
|
logDebug.Println("Received keep-alive from", peer.String())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// verify source and strip padding
|
// verify source and strip padding
|
||||||
|
|
||||||
switch elem.packet[0] >> 4 {
|
switch elem.packet[0] >> 4 {
|
||||||
case IPv4version:
|
case ipv4.Version:
|
||||||
|
|
||||||
// strip padding
|
// strip padding
|
||||||
|
|
||||||
if len(elem.packet) < IPv4headerSize {
|
if len(elem.packet) < ipv4.HeaderLen {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -487,31 +492,33 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||||
|
|
||||||
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
|
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
|
||||||
if device.routingTable.LookupIPv4(dst) != peer {
|
if device.routingTable.LookupIPv4(dst) != peer {
|
||||||
|
logInfo.Println("Packet with unallowed source IP from", peer.String())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
case IPv6version:
|
case ipv6.Version:
|
||||||
|
|
||||||
// strip padding
|
// strip padding
|
||||||
|
|
||||||
if len(elem.packet) < IPv6headerSize {
|
if len(elem.packet) < ipv6.HeaderLen {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
|
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
|
||||||
length := binary.BigEndian.Uint16(field)
|
length := binary.BigEndian.Uint16(field)
|
||||||
length += IPv6headerSize
|
length += ipv6.HeaderLen
|
||||||
elem.packet = elem.packet[:length]
|
elem.packet = elem.packet[:length]
|
||||||
|
|
||||||
// verify IPv6 source
|
// verify IPv6 source
|
||||||
|
|
||||||
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
|
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
|
||||||
if device.routingTable.LookupIPv6(dst) != peer {
|
if device.routingTable.LookupIPv6(dst) != peer {
|
||||||
|
logInfo.Println("Packet with unallowed source IP from", peer.String())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
logDebug.Println("Receieved packet with unknown IP version")
|
logInfo.Println("Packet with invalid IP version from", peer.String())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -522,6 +529,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) RoutineWriteToTUN(tun TUNDevice) {
|
func (device *Device) RoutineWriteToTUN(tun TUNDevice) {
|
||||||
|
|
||||||
logError := device.log.Error
|
logError := device.log.Error
|
||||||
logDebug := device.log.Debug
|
logDebug := device.log.Debug
|
||||||
logDebug.Println("Routine, sequential tun writer, started")
|
logDebug.Println("Routine, sequential tun writer, started")
|
||||||
|
|
69
src/send.go
69
src/send.go
|
@ -3,6 +3,8 @@ package main
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
@ -21,28 +23,26 @@ import (
|
||||||
* The functions in this file occure (roughly) in the order packets are processed.
|
* The functions in this file occure (roughly) in the order packets are processed.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
/* A work unit
|
/* The sequential consumers will attempt to take the lock,
|
||||||
*
|
* workers release lock when they have completed work (encryption) on the packet.
|
||||||
* The sequential consumers will attempt to take the lock,
|
|
||||||
* workers release lock when they have completed work on the packet.
|
|
||||||
*
|
*
|
||||||
* If the element is inserted into the "encryption queue",
|
* If the element is inserted into the "encryption queue",
|
||||||
* the content is preceeded by enough "junk" to contain the header
|
* the content is preceeded by enough "junk" to contain the transport header
|
||||||
* (to allow the construction of transport messages in-place)
|
* (to allow the construction of transport messages in-place)
|
||||||
*/
|
*/
|
||||||
type QueueOutboundElement struct {
|
type QueueOutboundElement struct {
|
||||||
dropped int32
|
dropped int32
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
data [MaxMessageSize]byte
|
data [MaxMessageSize]byte // slice holding the packet data
|
||||||
packet []byte // slice of "data" (always!)
|
packet []byte // slice of "data" (always!)
|
||||||
nonce uint64 // nonce for encryption
|
nonce uint64 // nonce for encryption
|
||||||
keyPair *KeyPair // key-pair for encryption
|
keyPair *KeyPair // key-pair for encryption
|
||||||
peer *Peer // related peer
|
peer *Peer // related peer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) FlushNonceQueue() {
|
func (peer *Peer) FlushNonceQueue() {
|
||||||
elems := len(peer.queue.nonce)
|
elems := len(peer.queue.nonce)
|
||||||
for i := 0; i < elems; i += 1 {
|
for i := 0; i < elems; i++ {
|
||||||
select {
|
select {
|
||||||
case <-peer.queue.nonce:
|
case <-peer.queue.nonce:
|
||||||
default:
|
default:
|
||||||
|
@ -111,14 +111,18 @@ func addToEncryptionQueue(
|
||||||
* Obs. Single instance per TUN device
|
* Obs. Single instance per TUN device
|
||||||
*/
|
*/
|
||||||
func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
|
func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
|
||||||
|
|
||||||
if tun == nil {
|
if tun == nil {
|
||||||
// dummy
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
elem := device.NewOutboundElement()
|
elem := device.NewOutboundElement()
|
||||||
|
|
||||||
device.log.Debug.Println("Routine, TUN Reader: started")
|
logDebug := device.log.Debug
|
||||||
|
logError := device.log.Error
|
||||||
|
|
||||||
|
logDebug.Println("Routine, TUN Reader: started")
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// read packet
|
// read packet
|
||||||
|
|
||||||
|
@ -129,12 +133,17 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
|
||||||
elem.packet = elem.data[MessageTransportHeaderSize:]
|
elem.packet = elem.data[MessageTransportHeaderSize:]
|
||||||
size, err := tun.Read(elem.packet)
|
size, err := tun.Read(elem.packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
device.log.Error.Println("Failed to read packet from TUN device:", err)
|
|
||||||
continue
|
// stop process
|
||||||
|
|
||||||
|
logError.Println("Failed to read packet from TUN device:", err)
|
||||||
|
device.Close()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
elem.packet = elem.packet[:size]
|
elem.packet = elem.packet[:size]
|
||||||
if len(elem.packet) < IPv4headerSize {
|
if len(elem.packet) < ipv4.HeaderLen {
|
||||||
device.log.Error.Println("Packet too short, length:", size)
|
logError.Println("Packet too short, length:", size)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -142,23 +151,24 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
|
||||||
|
|
||||||
var peer *Peer
|
var peer *Peer
|
||||||
switch elem.packet[0] >> 4 {
|
switch elem.packet[0] >> 4 {
|
||||||
case IPv4version:
|
case ipv4.Version:
|
||||||
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
|
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
|
||||||
peer = device.routingTable.LookupIPv4(dst)
|
peer = device.routingTable.LookupIPv4(dst)
|
||||||
|
|
||||||
case IPv6version:
|
case ipv6.Version:
|
||||||
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
|
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
|
||||||
peer = device.routingTable.LookupIPv6(dst)
|
peer = device.routingTable.LookupIPv6(dst)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
device.log.Debug.Println("Receieved packet with unknown IP version")
|
logDebug.Println("Receieved packet with unknown IP version")
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer == nil {
|
if peer == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer.endpoint == nil {
|
if peer.endpoint == nil {
|
||||||
device.log.Debug.Println("No known endpoint for peer", peer.id)
|
logDebug.Println("No known endpoint for peer", peer.String())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -184,7 +194,7 @@ func (peer *Peer) RoutineNonce() {
|
||||||
|
|
||||||
device := peer.device
|
device := peer.device
|
||||||
logDebug := device.log.Debug
|
logDebug := device.log.Debug
|
||||||
logDebug.Println("Routine, nonce worker, started for peer", peer.id)
|
logDebug.Println("Routine, nonce worker, started for peer", peer.String())
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
|
|
||||||
|
@ -216,15 +226,15 @@ func (peer *Peer) RoutineNonce() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
signalSend(peer.signal.handshakeBegin)
|
signalSend(peer.signal.handshakeBegin)
|
||||||
logDebug.Println("Waiting for key-pair, peer", peer.id)
|
logDebug.Println("Awaiting key-pair for", peer.String())
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-peer.signal.newKeyPair:
|
case <-peer.signal.newKeyPair:
|
||||||
logDebug.Println("Key-pair negotiated for peer", peer.id)
|
logDebug.Println("Key-pair negotiated for", peer.String())
|
||||||
goto NextPacket
|
goto NextPacket
|
||||||
|
|
||||||
case <-peer.signal.flushNonceQueue:
|
case <-peer.signal.flushNonceQueue:
|
||||||
logDebug.Println("Clearing queue for peer", peer.id)
|
logDebug.Println("Clearing queue for", peer.String())
|
||||||
peer.FlushNonceQueue()
|
peer.FlushNonceQueue()
|
||||||
elem = nil
|
elem = nil
|
||||||
goto NextPacket
|
goto NextPacket
|
||||||
|
@ -313,13 +323,14 @@ func (peer *Peer) RoutineSequentialSender() {
|
||||||
device := peer.device
|
device := peer.device
|
||||||
|
|
||||||
logDebug := device.log.Debug
|
logDebug := device.log.Debug
|
||||||
logDebug.Println("Routine, sequential sender, started for peer", peer.id)
|
logDebug.Println("Routine, sequential sender, started for", peer.String())
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-peer.signal.stop:
|
case <-peer.signal.stop:
|
||||||
logDebug.Println("Routine, sequential sender, stopped for peer", peer.id)
|
logDebug.Println("Routine, sequential sender, stopped for", peer.String())
|
||||||
return
|
return
|
||||||
|
|
||||||
case work := <-peer.queue.outbound:
|
case work := <-peer.queue.outbound:
|
||||||
work.mutex.Lock()
|
work.mutex.Lock()
|
||||||
if work.IsDropped() {
|
if work.IsDropped() {
|
||||||
|
@ -334,7 +345,7 @@ func (peer *Peer) RoutineSequentialSender() {
|
||||||
defer peer.mutex.RUnlock()
|
defer peer.mutex.RUnlock()
|
||||||
|
|
||||||
if peer.endpoint == nil {
|
if peer.endpoint == nil {
|
||||||
logDebug.Println("No endpoint for peer:", peer.id)
|
logDebug.Println("No endpoint for", peer.String())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -352,7 +363,7 @@ func (peer *Peer) RoutineSequentialSender() {
|
||||||
}
|
}
|
||||||
atomic.AddUint64(&peer.txBytes, uint64(len(work.packet)))
|
atomic.AddUint64(&peer.txBytes, uint64(len(work.packet)))
|
||||||
|
|
||||||
// reset keep-alive (passive keep-alives / acknowledgements)
|
// reset keep-alive
|
||||||
|
|
||||||
peer.TimerResetKeepalive()
|
peer.TimerResetKeepalive()
|
||||||
}()
|
}()
|
||||||
|
|
|
@ -50,7 +50,7 @@ func (peer *Peer) KeepKeyFreshReceiving() {
|
||||||
* - First transport message under the "next" key
|
* - First transport message under the "next" key
|
||||||
*/
|
*/
|
||||||
func (peer *Peer) EventHandshakeComplete() {
|
func (peer *Peer) EventHandshakeComplete() {
|
||||||
peer.device.log.Debug.Println("Handshake completed")
|
peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
|
||||||
peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
|
peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
|
||||||
signalSend(peer.signal.handshakeCompleted)
|
signalSend(peer.signal.handshakeCompleted)
|
||||||
}
|
}
|
||||||
|
@ -112,7 +112,7 @@ func (peer *Peer) TimerResetKeepalive() {
|
||||||
|
|
||||||
// stop acknowledgement timer
|
// stop acknowledgement timer
|
||||||
|
|
||||||
timerStop(peer.timer.keepaliveAcknowledgement)
|
timerStop(peer.timer.keepalivePassive)
|
||||||
atomic.StoreInt32(&peer.flags.keepaliveWaiting, AtomicFalse)
|
atomic.StoreInt32(&peer.flags.keepaliveWaiting, AtomicFalse)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -140,7 +140,7 @@ func (peer *Peer) RoutineTimerHandler() {
|
||||||
device := peer.device
|
device := peer.device
|
||||||
|
|
||||||
logDebug := device.log.Debug
|
logDebug := device.log.Debug
|
||||||
logDebug.Println("Routine, timer handler, started for peer", peer.id)
|
logDebug.Println("Routine, timer handler, started for peer", peer.String())
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
@ -152,14 +152,14 @@ func (peer *Peer) RoutineTimerHandler() {
|
||||||
|
|
||||||
case <-peer.timer.keepalivePersistent.C:
|
case <-peer.timer.keepalivePersistent.C:
|
||||||
|
|
||||||
logDebug.Println("Sending persistent keep-alive to peer", peer.id)
|
logDebug.Println("Sending persistent keep-alive to", peer.String())
|
||||||
|
|
||||||
peer.SendKeepAlive()
|
peer.SendKeepAlive()
|
||||||
peer.TimerResetKeepalive()
|
peer.TimerResetKeepalive()
|
||||||
|
|
||||||
case <-peer.timer.keepaliveAcknowledgement.C:
|
case <-peer.timer.keepalivePassive.C:
|
||||||
|
|
||||||
logDebug.Println("Sending passive persistent keep-alive to peer", peer.id)
|
logDebug.Println("Sending passive persistent keep-alive to", peer.String())
|
||||||
|
|
||||||
peer.SendKeepAlive()
|
peer.SendKeepAlive()
|
||||||
peer.TimerResetKeepalive()
|
peer.TimerResetKeepalive()
|
||||||
|
@ -168,7 +168,7 @@ func (peer *Peer) RoutineTimerHandler() {
|
||||||
|
|
||||||
case <-peer.timer.zeroAllKeys.C:
|
case <-peer.timer.zeroAllKeys.C:
|
||||||
|
|
||||||
logDebug.Println("Clearing all key material for peer", peer.id)
|
logDebug.Println("Clearing all key material for", peer.String())
|
||||||
|
|
||||||
// zero out key pairs
|
// zero out key pairs
|
||||||
|
|
||||||
|
@ -208,14 +208,12 @@ func (peer *Peer) RoutineHandshakeInitiator() {
|
||||||
|
|
||||||
var elem *QueueOutboundElement
|
var elem *QueueOutboundElement
|
||||||
|
|
||||||
|
logInfo := device.log.Info
|
||||||
logError := device.log.Error
|
logError := device.log.Error
|
||||||
logDebug := device.log.Debug
|
logDebug := device.log.Debug
|
||||||
logDebug.Println("Routine, handshake initator, started for peer", peer.id)
|
logDebug.Println("Routine, handshake initator, started for", peer.String())
|
||||||
|
|
||||||
for run := true; run; {
|
for {
|
||||||
var err error
|
|
||||||
var attempts uint
|
|
||||||
var deadline time.Time
|
|
||||||
|
|
||||||
// wait for signal
|
// wait for signal
|
||||||
|
|
||||||
|
@ -227,15 +225,17 @@ func (peer *Peer) RoutineHandshakeInitiator() {
|
||||||
|
|
||||||
// wait for handshake
|
// wait for handshake
|
||||||
|
|
||||||
run = func() bool {
|
func() {
|
||||||
for {
|
var err error
|
||||||
|
var deadline time.Time
|
||||||
|
for attempts := uint(1); ; attempts++ {
|
||||||
|
|
||||||
// clear completed signal
|
// clear completed signal
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-peer.signal.handshakeCompleted:
|
case <-peer.signal.handshakeCompleted:
|
||||||
case <-peer.signal.stop:
|
case <-peer.signal.stop:
|
||||||
return false
|
return
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -246,43 +246,39 @@ func (peer *Peer) RoutineHandshakeInitiator() {
|
||||||
}
|
}
|
||||||
elem, err = peer.BeginHandshakeInitiation()
|
elem, err = peer.BeginHandshakeInitiation()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to create initiation message:", err)
|
logError.Println("Failed to create initiation message", err, "for", peer.String())
|
||||||
break
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// set timeout
|
// set timeout
|
||||||
|
|
||||||
attempts += 1
|
|
||||||
if attempts == 1 {
|
if attempts == 1 {
|
||||||
deadline = time.Now().Add(MaxHandshakeAttemptTime)
|
deadline = time.Now().Add(MaxHandshakeAttemptTime)
|
||||||
}
|
}
|
||||||
timeout := time.NewTimer(RekeyTimeout)
|
timeout := time.NewTimer(RekeyTimeout)
|
||||||
logDebug.Println("Handshake initiation attempt", attempts, "queued for peer", peer.id)
|
logDebug.Println("Handshake initiation attempt", attempts, "queued for", peer.String())
|
||||||
|
|
||||||
// wait for handshake or timeout
|
// wait for handshake or timeout
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
|
||||||
case <-peer.signal.stop:
|
case <-peer.signal.stop:
|
||||||
return true
|
return
|
||||||
|
|
||||||
case <-peer.signal.handshakeCompleted:
|
case <-peer.signal.handshakeCompleted:
|
||||||
<-timeout.C
|
<-timeout.C
|
||||||
return true
|
return
|
||||||
|
|
||||||
case <-timeout.C:
|
case <-timeout.C:
|
||||||
logDebug.Println("Timeout")
|
|
||||||
|
|
||||||
// check if sufficient time for retry
|
|
||||||
|
|
||||||
if deadline.Before(time.Now().Add(RekeyTimeout)) {
|
if deadline.Before(time.Now().Add(RekeyTimeout)) {
|
||||||
|
logInfo.Println("Handshake negotiation timed out for", peer.String())
|
||||||
signalSend(peer.signal.flushNonceQueue)
|
signalSend(peer.signal.flushNonceQueue)
|
||||||
timerStop(peer.timer.keepalivePersistent)
|
timerStop(peer.timer.keepalivePersistent)
|
||||||
timerStop(peer.timer.keepaliveAcknowledgement)
|
timerStop(peer.timer.keepalivePassive)
|
||||||
return true
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
signalClear(peer.signal.handshakeBegin)
|
signalClear(peer.signal.handshakeBegin)
|
||||||
|
|
19
src/trie.go
19
src/trie.go
|
@ -23,7 +23,8 @@ type Trie struct {
|
||||||
bits []byte
|
bits []byte
|
||||||
peer *Peer
|
peer *Peer
|
||||||
|
|
||||||
// Index of "branching" bit
|
// index of "branching" bit
|
||||||
|
|
||||||
bit_at_byte uint
|
bit_at_byte uint
|
||||||
bit_at_shift uint
|
bit_at_shift uint
|
||||||
}
|
}
|
||||||
|
@ -36,7 +37,7 @@ type Trie struct {
|
||||||
func commonBits(ip1 net.IP, ip2 net.IP) uint {
|
func commonBits(ip1 net.IP, ip2 net.IP) uint {
|
||||||
var i uint
|
var i uint
|
||||||
size := uint(len(ip1))
|
size := uint(len(ip1))
|
||||||
for i = 0; i < size; i += 1 {
|
for i = 0; i < size; i++ {
|
||||||
v := ip1[i] ^ ip2[i]
|
v := ip1[i] ^ ip2[i]
|
||||||
if v != 0 {
|
if v != 0 {
|
||||||
v >>= 1
|
v >>= 1
|
||||||
|
@ -84,7 +85,7 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
|
||||||
return node
|
return node
|
||||||
}
|
}
|
||||||
|
|
||||||
// Walk recursivly
|
// walk recursivly
|
||||||
|
|
||||||
node.child[0] = node.child[0].RemovePeer(p)
|
node.child[0] = node.child[0].RemovePeer(p)
|
||||||
node.child[1] = node.child[1].RemovePeer(p)
|
node.child[1] = node.child[1].RemovePeer(p)
|
||||||
|
@ -93,7 +94,7 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
|
||||||
return node
|
return node
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove peer & merge
|
// remove peer & merge
|
||||||
|
|
||||||
node.peer = nil
|
node.peer = nil
|
||||||
if node.child[0] == nil {
|
if node.child[0] == nil {
|
||||||
|
@ -108,7 +109,7 @@ func (node *Trie) choose(ip net.IP) byte {
|
||||||
|
|
||||||
func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
|
func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
|
||||||
|
|
||||||
// At leaf
|
// at leaf
|
||||||
|
|
||||||
if node == nil {
|
if node == nil {
|
||||||
return &Trie{
|
return &Trie{
|
||||||
|
@ -120,7 +121,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Traverse deeper
|
// traverse deeper
|
||||||
|
|
||||||
common := commonBits(node.bits, ip)
|
common := commonBits(node.bits, ip)
|
||||||
if node.cidr <= cidr && common >= node.cidr {
|
if node.cidr <= cidr && common >= node.cidr {
|
||||||
|
@ -133,7 +134,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
|
||||||
return node
|
return node
|
||||||
}
|
}
|
||||||
|
|
||||||
// Split node
|
// split node
|
||||||
|
|
||||||
newNode := &Trie{
|
newNode := &Trie{
|
||||||
bits: ip,
|
bits: ip,
|
||||||
|
@ -145,7 +146,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
|
||||||
|
|
||||||
cidr = min(cidr, common)
|
cidr = min(cidr, common)
|
||||||
|
|
||||||
// Check for shorter prefix
|
// check for shorter prefix
|
||||||
|
|
||||||
if newNode.cidr == cidr {
|
if newNode.cidr == cidr {
|
||||||
bit := newNode.choose(node.bits)
|
bit := newNode.choose(node.bits)
|
||||||
|
@ -153,7 +154,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
|
||||||
return newNode
|
return newNode
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create new parent for node & newNode
|
// create new parent for node & newNode
|
||||||
|
|
||||||
parent := &Trie{
|
parent := &Trie{
|
||||||
bits: ip,
|
bits: ip,
|
||||||
|
|
Loading…
Reference in a new issue