Rewrite timers and related state machines

This commit is contained in:
Jason A. Donenfeld 2018-05-07 22:27:03 +02:00
parent 375dcbd4ae
commit 233f079a94
14 changed files with 453 additions and 602 deletions

View file

@ -12,21 +12,18 @@ import (
/* Specification constants */ /* Specification constants */
const ( const (
RekeyAfterMessages = (1 << 64) - (1 << 16) - 1 RekeyAfterMessages = (1 << 64) - (1 << 16) - 1
RejectAfterMessages = (1 << 64) - (1 << 4) - 1 RejectAfterMessages = (1 << 64) - (1 << 4) - 1
RekeyAfterTime = time.Second * 120 RekeyAfterTime = time.Second * 120
RekeyAttemptTime = time.Second * 90 RekeyAttemptTime = time.Second * 90
RekeyTimeout = time.Second * 5 RekeyTimeout = time.Second * 5
RejectAfterTime = time.Second * 180 MaxTimerHandshakes = 90 / 5 /* RekeyAttemptTime / RekeyTimeout */
KeepaliveTimeout = time.Second * 10 RekeyTimeoutJitterMaxMs = 334
CookieRefreshTime = time.Second * 120 RejectAfterTime = time.Second * 180
HandshakeInitationRate = time.Second / 20 KeepaliveTimeout = time.Second * 10
PaddingMultiple = 16 CookieRefreshTime = time.Second * 120
) HandshakeInitationRate = time.Second / 20
PaddingMultiple = 16
const (
RekeyAfterTimeReceiving = RejectAfterTime - KeepaliveTimeout - RekeyTimeout
NewHandshakeTime = KeepaliveTimeout + RekeyTimeout // upon failure to acknowledge transport message
) )
/* Implementation specific constants */ /* Implementation specific constants */

View file

@ -74,8 +74,8 @@ type Device struct {
handshake chan QueueHandshakeElement handshake chan QueueHandshakeElement
} }
signal struct { signals struct {
stop Signal stop chan struct{}
} }
tun struct { tun struct {
@ -302,7 +302,7 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
// prepare signals // prepare signals
device.signal.stop = NewSignal() device.signals.stop = make(chan struct{}, 1)
// prepare net // prepare net
@ -400,7 +400,7 @@ func (device *Device) Close() {
device.isUp.Set(false) device.isUp.Set(false)
device.signal.stop.Broadcast() close(device.signals.stop)
device.state.stopping.Wait() device.state.stopping.Wait()
device.FlushPacketQueues() device.FlushPacketQueues()
@ -413,5 +413,5 @@ func (device *Device) Close() {
} }
func (device *Device) Wait() chan struct{} { func (device *Device) Wait() chan struct{} {
return device.signal.stop.Wait() return device.signals.stop
} }

View file

@ -1,43 +0,0 @@
package main
import (
"sync/atomic"
"time"
)
type Event struct {
guard int32
next time.Time
interval time.Duration
C chan struct{}
}
func newEvent(interval time.Duration) *Event {
return &Event{
guard: 0,
next: time.Now(),
interval: interval,
C: make(chan struct{}, 1),
}
}
func (e *Event) Clear() {
select {
case <-e.C:
default:
}
}
func (e *Event) Fire() {
if e == nil || atomic.SwapInt32(&e.guard, 1) != 0 {
return
}
if now := time.Now(); now.After(e.next) {
select {
case e.C <- struct{}{}:
default:
}
e.next = now.Add(e.interval)
}
atomic.StoreInt32(&e.guard, 0)
}

View file

@ -18,7 +18,7 @@ import (
type IndexTableEntry struct { type IndexTableEntry struct {
peer *Peer peer *Peer
handshake *Handshake handshake *Handshake
keyPair *KeyPair keyPair *Keypair
} }
type IndexTable struct { type IndexTable struct {

View file

@ -18,7 +18,7 @@ import (
* we plan to resolve this issue; whenever Go allows us to do so. * we plan to resolve this issue; whenever Go allows us to do so.
*/ */
type KeyPair struct { type Keypair struct {
sendNonce uint64 sendNonce uint64
send cipher.AEAD send cipher.AEAD
receive cipher.AEAD receive cipher.AEAD
@ -29,20 +29,20 @@ type KeyPair struct {
remoteIndex uint32 remoteIndex uint32
} }
type KeyPairs struct { type Keypairs struct {
mutex sync.RWMutex mutex sync.RWMutex
current *KeyPair current *Keypair
previous *KeyPair previous *Keypair
next *KeyPair // not yet "confirmed by transport" next *Keypair // not yet "confirmed by transport"
} }
func (kp *KeyPairs) Current() *KeyPair { func (kp *Keypairs) Current() *Keypair {
kp.mutex.RLock() kp.mutex.RLock()
defer kp.mutex.RUnlock() defer kp.mutex.RUnlock()
return kp.current return kp.current
} }
func (device *Device) DeleteKeyPair(key *KeyPair) { func (device *Device) DeleteKeypair(key *Keypair) {
if key != nil { if key != nil {
device.indices.Delete(key.localIndex) device.indices.Delete(key.localIndex)
} }

15
main.go
View file

@ -30,6 +30,8 @@ func printUsage() {
} }
func warning() { func warning() {
shouldQuit := false
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING") fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
fmt.Fprintln(os.Stderr, "W G") fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "W This is alpha software. It will very likely not G") fmt.Fprintln(os.Stderr, "W This is alpha software. It will very likely not G")
@ -37,6 +39,8 @@ func warning() {
fmt.Fprintln(os.Stderr, "W horribly wrong. You have been warned. Proceed G") fmt.Fprintln(os.Stderr, "W horribly wrong. You have been warned. Proceed G")
fmt.Fprintln(os.Stderr, "W at your own risk. G") fmt.Fprintln(os.Stderr, "W at your own risk. G")
if runtime.GOOS == "linux" { if runtime.GOOS == "linux" {
shouldQuit = os.Getenv("WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD") != "1"
fmt.Fprintln(os.Stderr, "W G") fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "W Furthermore, you are running this software on a G") fmt.Fprintln(os.Stderr, "W Furthermore, you are running this software on a G")
fmt.Fprintln(os.Stderr, "W Linux kernel, which is probably unnecessary and G") fmt.Fprintln(os.Stderr, "W Linux kernel, which is probably unnecessary and G")
@ -46,9 +50,20 @@ func warning() {
fmt.Fprintln(os.Stderr, "W program. For more information on installing the G") fmt.Fprintln(os.Stderr, "W program. For more information on installing the G")
fmt.Fprintln(os.Stderr, "W kernel module, please visit: G") fmt.Fprintln(os.Stderr, "W kernel module, please visit: G")
fmt.Fprintln(os.Stderr, "W https://www.wireguard.com/install G") fmt.Fprintln(os.Stderr, "W https://www.wireguard.com/install G")
if shouldQuit {
fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "W If you still want to use this program, against G")
fmt.Fprintln(os.Stderr, "W the sage advice here, please first export this G")
fmt.Fprintln(os.Stderr, "W environment variable: G")
fmt.Fprintln(os.Stderr, "W WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD=1 G")
}
} }
fmt.Fprintln(os.Stderr, "W G") fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING") fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
if shouldQuit {
os.Exit(1)
}
} }
func main() { func main() {

View file

@ -1,6 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2015-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
*/ */
package main package main
@ -488,7 +488,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
/* Derives a new key-pair from the current handshake state /* Derives a new key-pair from the current handshake state
* *
*/ */
func (peer *Peer) NewKeyPair() *KeyPair { func (peer *Peer) NewKeypair() *Keypair {
device := peer.device device := peer.device
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
@ -528,7 +528,7 @@ func (peer *Peer) NewKeyPair() *KeyPair {
// create AEAD instances // create AEAD instances
keyPair := new(KeyPair) keyPair := new(Keypair)
keyPair.send, _ = chacha20poly1305.New(sendKey[:]) keyPair.send, _ = chacha20poly1305.New(sendKey[:])
keyPair.receive, _ = chacha20poly1305.New(recvKey[:]) keyPair.receive, _ = chacha20poly1305.New(recvKey[:])
@ -559,24 +559,27 @@ func (peer *Peer) NewKeyPair() *KeyPair {
kp := &peer.keyPairs kp := &peer.keyPairs
kp.mutex.Lock() kp.mutex.Lock()
peer.timersSessionDerived()
previous := kp.previous
next := kp.next
current := kp.current
if isInitiator { if isInitiator {
if kp.previous != nil { if next != nil {
device.DeleteKeyPair(kp.previous) kp.next = nil
kp.previous = nil kp.previous = next
} device.DeleteKeypair(current)
if kp.next != nil {
kp.previous = kp.next
kp.next = keyPair
} else { } else {
kp.previous = kp.current kp.previous = current
kp.current = keyPair
peer.event.newKeyPair.Fire()
} }
device.DeleteKeypair(previous)
kp.current = keyPair
} else { } else {
kp.next = keyPair kp.next = keyPair
device.DeleteKeypair(next)
kp.previous = nil kp.previous = nil
device.DeleteKeypair(previous)
} }
kp.mutex.Unlock() kp.mutex.Unlock()

View file

@ -102,8 +102,8 @@ func TestNoiseHandshake(t *testing.T) {
t.Log("deriving keys") t.Log("deriving keys")
key1 := peer1.NewKeyPair() key1 := peer1.NewKeypair()
key2 := peer2.NewKeyPair() key2 := peer2.NewKeypair()
if key1 == nil { if key1 == nil {
t.Fatal("failed to dervice key-pair for peer 1") t.Fatal("failed to dervice key-pair for peer 1")

78
peer.go
View file

@ -14,14 +14,13 @@ import (
) )
const ( const (
PeerRoutineNumber = 4 PeerRoutineNumber = 3
EventInterval = 10 * time.Millisecond
) )
type Peer struct { type Peer struct {
isRunning AtomicBool isRunning AtomicBool
mutex sync.RWMutex mutex sync.RWMutex
keyPairs KeyPairs keyPairs Keypairs
handshake Handshake handshake Handshake
device *Device device *Device
endpoint Endpoint endpoint Endpoint
@ -34,34 +33,28 @@ type Peer struct {
lastHandshakeNano int64 // nano seconds since epoch lastHandshakeNano int64 // nano seconds since epoch
} }
time struct { timers struct {
mutex sync.RWMutex retransmitHandshake *Timer
lastSend time.Time // last send message sendKeepalive *Timer
lastHandshake time.Time // last completed handshake newHandshake *Timer
nextKeepalive time.Time zeroKeyMaterial *Timer
persistentKeepalive *Timer
handshakeAttempts uint
needAnotherKeepalive bool
sentLastMinuteHandshake bool
lastSentHandshake time.Time
} }
event struct { signals struct {
dataSent *Event newKeypairArrived chan struct{}
dataReceived *Event flushNonceQueue chan struct{}
anyAuthenticatedPacketReceived *Event
anyAuthenticatedPacketTraversal *Event
handshakeCompleted *Event
handshakePushDeadline *Event
handshakeBegin *Event
ephemeralKeyCreated *Event
newKeyPair *Event
flushNonceQueue *Event
}
timer struct {
sendLastMinuteHandshake AtomicBool
} }
queue struct { queue struct {
nonce chan *QueueOutboundElement // nonce / pre-handshake queue nonce chan *QueueOutboundElement // nonce / pre-handshake queue
outbound chan *QueueOutboundElement // sequential ordering of work outbound chan *QueueOutboundElement // sequential ordering of work
inbound chan *QueueInboundElement // sequential ordering of work inbound chan *QueueInboundElement // sequential ordering of work
packetInNonceQueueIsAwaitingKey bool
} }
routines struct { routines struct {
@ -188,6 +181,8 @@ func (peer *Peer) Start() {
peer.routines.starting.Wait() peer.routines.starting.Wait()
peer.routines.stopping.Wait() peer.routines.stopping.Wait()
peer.routines.stop = make(chan struct{}) peer.routines.stop = make(chan struct{})
peer.routines.starting.Add(PeerRoutineNumber)
peer.routines.stopping.Add(PeerRoutineNumber)
// prepare queues // prepare queues
@ -195,28 +190,13 @@ func (peer *Peer) Start() {
peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
// events peer.timersInit()
peer.signals.newKeypairArrived = make(chan struct{}, 1)
peer.event.dataSent = newEvent(EventInterval) peer.signals.flushNonceQueue = make(chan struct{}, 1)
peer.event.dataReceived = newEvent(EventInterval)
peer.event.anyAuthenticatedPacketReceived = newEvent(EventInterval)
peer.event.anyAuthenticatedPacketTraversal = newEvent(EventInterval)
peer.event.handshakeCompleted = newEvent(EventInterval)
peer.event.handshakePushDeadline = newEvent(EventInterval)
peer.event.handshakeBegin = newEvent(EventInterval)
peer.event.ephemeralKeyCreated = newEvent(EventInterval)
peer.event.newKeyPair = newEvent(EventInterval)
peer.event.flushNonceQueue = newEvent(EventInterval)
peer.isRunning.Set(true)
// wait for routines to start // wait for routines to start
peer.routines.starting.Add(PeerRoutineNumber)
peer.routines.stopping.Add(PeerRoutineNumber)
go peer.RoutineNonce() go peer.RoutineNonce()
go peer.RoutineTimerHandler()
go peer.RoutineSequentialSender() go peer.RoutineSequentialSender()
go peer.RoutineSequentialReceiver() go peer.RoutineSequentialReceiver()
@ -238,6 +218,8 @@ func (peer *Peer) Stop() {
device := peer.device device := peer.device
device.log.Debug.Println(peer, ": Stopping...") device.log.Debug.Println(peer, ": Stopping...")
peer.timersStop()
// stop & wait for ongoing peer routines // stop & wait for ongoing peer routines
peer.routines.starting.Wait() peer.routines.starting.Wait()
@ -255,9 +237,9 @@ func (peer *Peer) Stop() {
kp := &peer.keyPairs kp := &peer.keyPairs
kp.mutex.Lock() kp.mutex.Lock()
device.DeleteKeyPair(kp.previous) device.DeleteKeypair(kp.previous)
device.DeleteKeyPair(kp.current) device.DeleteKeypair(kp.current)
device.DeleteKeyPair(kp.next) device.DeleteKeypair(kp.next)
kp.previous = nil kp.previous = nil
kp.current = nil kp.current = nil
@ -271,4 +253,6 @@ func (peer *Peer) Stop() {
device.indices.Delete(hs.localIndex) device.indices.Delete(hs.localIndex)
hs.Clear() hs.Clear()
hs.mutex.Unlock() hs.mutex.Unlock()
peer.FlushNonceQueue()
} }

View file

@ -31,7 +31,7 @@ type QueueInboundElement struct {
buffer *[MaxMessageSize]byte buffer *[MaxMessageSize]byte
packet []byte packet []byte
counter uint64 counter uint64
keyPair *KeyPair keyPair *Keypair
endpoint Endpoint endpoint Endpoint
} }
@ -99,6 +99,21 @@ func (device *Device) addToHandshakeQueue(
} }
} }
/* Called when a new authenticated message has been received
*
* NOTE: Not thread safe, but called by sequential receiver!
*/
func (peer *Peer) keepKeyFreshReceiving() {
if peer.timers.sentLastMinuteHandshake {
return
}
kp := peer.keyPairs.Current()
if kp != nil && kp.isInitiator && time.Now().Sub(kp.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
peer.timers.sentLastMinuteHandshake = true
peer.SendHandshakeInitiation(false)
}
}
/* Receives incoming datagrams for the device /* Receives incoming datagrams for the device
* *
* Every time the bind is updated a new routine is started for * Every time the bind is updated a new routine is started for
@ -245,7 +260,7 @@ func (device *Device) RoutineDecryption() {
for { for {
select { select {
case <-device.signal.stop.Wait(): case <-device.signals.stop:
return return
case elem, ok := <-device.queue.decryption: case elem, ok := <-device.queue.decryption:
@ -317,7 +332,7 @@ func (device *Device) RoutineHandshake() {
for { for {
select { select {
case elem, ok = <-device.queue.handshake: case elem, ok = <-device.queue.handshake:
case <-device.signal.stop.Wait(): case <-device.signals.stop:
return return
} }
@ -441,8 +456,8 @@ func (device *Device) RoutineHandshake() {
// update timers // update timers
peer.event.anyAuthenticatedPacketTraversal.Fire() peer.timersAnyAuthenticatedPacketTraversal()
peer.event.anyAuthenticatedPacketReceived.Fire() peer.timersAnyAuthenticatedPacketReceived()
// update endpoint // update endpoint
@ -460,10 +475,11 @@ func (device *Device) RoutineHandshake() {
continue continue
} }
peer.TimerEphemeralKeyCreated() if peer.NewKeypair() == nil {
peer.NewKeyPair() continue
}
logDebug.Println(peer, ": Creating handshake response") logDebug.Println(peer, ": Sending handshake response")
writer := bytes.NewBuffer(temp[:0]) writer := bytes.NewBuffer(temp[:0])
binary.Write(writer, binary.LittleEndian, response) binary.Write(writer, binary.LittleEndian, response)
@ -472,9 +488,10 @@ func (device *Device) RoutineHandshake() {
// send response // send response
peer.timers.lastSentHandshake = time.Now()
err = peer.SendBuffer(packet) err = peer.SendBuffer(packet)
if err == nil { if err == nil {
peer.event.anyAuthenticatedPacketTraversal.Fire() peer.timersAnyAuthenticatedPacketTraversal()
} else { } else {
logError.Println(peer, ": Failed to send handshake response", err) logError.Println(peer, ": Failed to send handshake response", err)
} }
@ -510,18 +527,23 @@ func (device *Device) RoutineHandshake() {
logDebug.Println(peer, ": Received handshake response") logDebug.Println(peer, ": Received handshake response")
peer.TimerEphemeralKeyCreated()
// update timers // update timers
peer.event.anyAuthenticatedPacketTraversal.Fire() peer.timersAnyAuthenticatedPacketTraversal()
peer.event.anyAuthenticatedPacketReceived.Fire() peer.timersAnyAuthenticatedPacketReceived()
peer.event.handshakeCompleted.Fire()
// derive key-pair // derive key-pair
peer.NewKeyPair() if peer.NewKeypair() == nil {
peer.SendKeepAlive() continue
}
peer.timersHandshakeComplete()
peer.SendKeepalive()
select {
case peer.signals.newKeypairArrived <- struct{}{}:
default:
}
} }
} }
} }
@ -569,38 +591,41 @@ func (peer *Peer) RoutineSequentialReceiver() {
continue continue
} }
peer.event.anyAuthenticatedPacketTraversal.Fire()
peer.event.anyAuthenticatedPacketReceived.Fire()
peer.KeepKeyFreshReceiving()
// check if using new key-pair
kp := &peer.keyPairs
kp.mutex.Lock()
if kp.next == elem.keyPair {
peer.event.handshakeCompleted.Fire()
if kp.previous != nil {
device.DeleteKeyPair(kp.previous)
}
kp.previous = kp.current
kp.current = kp.next
kp.next = nil
}
kp.mutex.Unlock()
// update endpoint // update endpoint
peer.mutex.Lock() peer.mutex.Lock()
peer.endpoint = elem.endpoint peer.endpoint = elem.endpoint
peer.mutex.Unlock() peer.mutex.Unlock()
// check for keep-alive // check if using new key-pair
kp := &peer.keyPairs
kp.mutex.Lock() //TODO: make this into an RW lock to reduce contention here for the equality check which is rarely true
if kp.next == elem.keyPair {
old := kp.previous
kp.previous = kp.current
device.DeleteKeypair(old)
kp.current = kp.next
kp.next = nil
peer.timersHandshakeComplete()
select {
case peer.signals.newKeypairArrived <- struct{}{}:
default:
}
}
kp.mutex.Unlock()
peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
// check for keepalive
if len(elem.packet) == 0 { if len(elem.packet) == 0 {
logDebug.Println(peer, ": Received keep-alive") logDebug.Println(peer, ": Receiving keepalive packet")
continue continue
} }
peer.event.dataReceived.Fire() peer.timersDataReceived()
// verify source and strip padding // verify source and strip padding

134
send.go
View file

@ -6,6 +6,7 @@
package main package main
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
@ -46,21 +47,10 @@ type QueueOutboundElement struct {
buffer *[MaxMessageSize]byte // slice holding the packet data buffer *[MaxMessageSize]byte // slice holding the packet data
packet []byte // slice of "buffer" (always!) packet []byte // slice of "buffer" (always!)
nonce uint64 // nonce for encryption nonce uint64 // nonce for encryption
keyPair *KeyPair // key-pair for encryption keyPair *Keypair // key-pair for encryption
peer *Peer // related peer peer *Peer // related peer
} }
func (peer *Peer) flushNonceQueue() {
elems := len(peer.queue.nonce)
for i := 0; i < elems; i++ {
select {
case <-peer.queue.nonce:
default:
return
}
}
}
func (device *Device) NewOutboundElement() *QueueOutboundElement { func (device *Device) NewOutboundElement() *QueueOutboundElement {
return &QueueOutboundElement{ return &QueueOutboundElement{
dropped: AtomicFalse, dropped: AtomicFalse,
@ -114,6 +104,73 @@ func addToEncryptionQueue(
} }
} }
/* Queues a keepalive if no packets are queued for peer
*/
func (peer *Peer) SendKeepalive() bool {
if len(peer.queue.nonce) != 0 || peer.queue.packetInNonceQueueIsAwaitingKey {
return false
}
elem := peer.device.NewOutboundElement()
elem.packet = nil
select {
case peer.queue.nonce <- elem:
peer.device.log.Debug.Println(peer, ": Sending keepalive packet")
return true
default:
return false
}
}
/* Sends a new handshake initiation message to the peer (endpoint)
*/
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
if !isRetry {
peer.timers.handshakeAttempts = 0
}
if time.Now().Sub(peer.timers.lastSentHandshake) < RekeyTimeout {
return nil
}
peer.timers.lastSentHandshake = time.Now() //TODO: locking for this variable?
// create initiation message
msg, err := peer.device.CreateMessageInitiation(peer)
if err != nil {
return err
}
peer.device.log.Debug.Println(peer, ": Sending handshake initiation")
// marshal handshake message
var buff [MessageInitiationSize]byte
writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes()
peer.mac.AddMacs(packet)
// send to endpoint
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersHandshakeInitiated()
return peer.SendBuffer(packet)
}
/* Called when a new authenticated message has been send
*
*/
func (peer *Peer) keepKeyFreshSending() {
kp := peer.keyPairs.Current()
if kp == nil {
return
}
nonce := atomic.LoadUint64(&kp.sendNonce)
if nonce > RekeyAfterMessages || (kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime) {
peer.SendHandshakeInitiation(false)
}
}
/* Reads packets from the TUN and inserts /* Reads packets from the TUN and inserts
* into nonce queue for peer * into nonce queue for peer
* *
@ -180,13 +237,22 @@ func (device *Device) RoutineReadFromTUN() {
// insert into nonce/pre-handshake queue // insert into nonce/pre-handshake queue
if peer.isRunning.Get() { if peer.isRunning.Get() {
peer.event.handshakePushDeadline.Fire() if peer.queue.packetInNonceQueueIsAwaitingKey {
peer.SendHandshakeInitiation(false)
}
addToOutboundQueue(peer.queue.nonce, elem) addToOutboundQueue(peer.queue.nonce, elem)
elem = device.NewOutboundElement() elem = device.NewOutboundElement()
} }
} }
} }
func (peer *Peer) FlushNonceQueue() {
select {
case peer.signals.flushNonceQueue <- struct{}{}:
default:
}
}
/* Queues packets when there is no handshake. /* Queues packets when there is no handshake.
* Then assigns nonces to packets sequentially * Then assigns nonces to packets sequentially
* and creates "work" structs for workers * and creates "work" structs for workers
@ -194,13 +260,14 @@ func (device *Device) RoutineReadFromTUN() {
* Obs. A single instance per peer * Obs. A single instance per peer
*/ */
func (peer *Peer) RoutineNonce() { func (peer *Peer) RoutineNonce() {
var keyPair *KeyPair var keyPair *Keypair
device := peer.device device := peer.device
logDebug := device.log.Debug logDebug := device.log.Debug
defer func() { defer func() {
logDebug.Println(peer, ": Routine: nonce worker - stopped") logDebug.Println(peer, ": Routine: nonce worker - stopped")
peer.queue.packetInNonceQueueIsAwaitingKey = false
peer.routines.stopping.Done() peer.routines.stopping.Done()
}() }()
@ -209,8 +276,7 @@ func (peer *Peer) RoutineNonce() {
for { for {
NextPacket: NextPacket:
peer.queue.packetInNonceQueueIsAwaitingKey = false
peer.event.flushNonceQueue.Clear()
select { select {
case <-peer.routines.stop: case <-peer.routines.stop:
@ -225,34 +291,48 @@ func (peer *Peer) RoutineNonce() {
// wait for key pair // wait for key pair
for { for {
peer.event.newKeyPair.Clear()
keyPair = peer.keyPairs.Current() keyPair = peer.keyPairs.Current()
if keyPair != nil && keyPair.sendNonce < RejectAfterMessages { if keyPair != nil && keyPair.sendNonce < RejectAfterMessages {
if time.Now().Sub(keyPair.created) < RejectAfterTime { if time.Now().Sub(keyPair.created) < RejectAfterTime {
break break
} }
} }
peer.queue.packetInNonceQueueIsAwaitingKey = true
peer.event.handshakeBegin.Fire() select {
case <-peer.signals.newKeypairArrived:
default:
}
peer.SendHandshakeInitiation(false)
logDebug.Println(peer, ": Awaiting key-pair") logDebug.Println(peer, ": Awaiting key-pair")
select { select {
case <-peer.event.newKeyPair.C: case <-peer.signals.newKeypairArrived:
logDebug.Println(peer, ": Obtained awaited key-pair") logDebug.Println(peer, ": Obtained awaited key-pair")
case <-peer.event.flushNonceQueue.C: case <-peer.signals.flushNonceQueue:
goto NextPacket for {
select {
case <-peer.queue.nonce:
default:
goto NextPacket
}
}
case <-peer.routines.stop: case <-peer.routines.stop:
return return
} }
} }
peer.queue.packetInNonceQueueIsAwaitingKey = false
// populate work element // populate work element
elem.peer = peer elem.peer = peer
elem.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1 elem.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1
// double check in case of race condition added by future code
if elem.nonce >= RejectAfterMessages {
goto NextPacket
}
elem.keyPair = keyPair elem.keyPair = keyPair
elem.dropped = AtomicFalse elem.dropped = AtomicFalse
elem.mutex.Lock() elem.mutex.Lock()
@ -288,7 +368,7 @@ func (device *Device) RoutineEncryption() {
// fetch next element // fetch next element
select { select {
case <-device.signal.stop.Wait(): case <-device.signals.stop:
return return
case elem, ok := <-device.queue.encryption: case elem, ok := <-device.queue.encryption:
@ -389,11 +469,11 @@ func (peer *Peer) RoutineSequentialSender() {
// update timers // update timers
peer.event.anyAuthenticatedPacketTraversal.Fire() peer.timersAnyAuthenticatedPacketTraversal()
if len(elem.packet) != MessageKeepaliveSize { if len(elem.packet) != MessageKeepaliveSize {
peer.event.dataSent.Fire() peer.timersDataSent()
} }
peer.KeepKeyFreshSending() peer.keepKeyFreshSending()
} }
} }
} }

View file

@ -1,71 +0,0 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
*/
package main
func signalSend(s chan<- struct{}) {
select {
case s <- struct{}{}:
default:
}
}
type Signal struct {
enabled AtomicBool
C chan struct{}
}
func NewSignal() (s Signal) {
s.C = make(chan struct{}, 1)
s.Enable()
return
}
func (s *Signal) Close() {
close(s.C)
}
func (s *Signal) Disable() {
s.enabled.Set(false)
s.Clear()
}
func (s *Signal) Enable() {
s.enabled.Set(true)
}
/* Unblock exactly one listener
*/
func (s *Signal) Send() {
if s.enabled.Get() {
select {
case s.C <- struct{}{}:
default:
}
}
}
/* Clear the signal if already fired
*/
func (s Signal) Clear() {
select {
case <-s.C:
default:
}
}
/* Unblocks all listeners (forever)
*/
func (s Signal) Broadcast() {
if s.enabled.Get() {
close(s.C)
}
}
/* Wait for the signal
*/
func (s Signal) Wait() chan struct{} {
return s.C
}

512
timers.go
View file

@ -1,355 +1,221 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2015-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
*
* This is based heavily on timers.c from the kernel implementation.
*/ */
package main package main
import ( import (
"bytes"
"encoding/binary"
"math/rand" "math/rand"
"sync/atomic" "sync/atomic"
"time" "time"
) )
/* NOTE: /* This Timer structure and related functions should roughly copy the interface of
* Notion of validity * the Linux kernel's struct timer_list.
*/ */
/* Called when a new authenticated message has been send type Timer struct {
* timer *time.Timer
*/ isPending bool
func (peer *Peer) KeepKeyFreshSending() {
kp := peer.keyPairs.Current()
if kp == nil {
return
}
nonce := atomic.LoadUint64(&kp.sendNonce)
if nonce > RekeyAfterMessages {
peer.event.handshakeBegin.Fire()
}
if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime {
peer.event.handshakeBegin.Fire()
}
} }
/* Called when a new authenticated message has been received func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer {
* timer := &Timer{}
* NOTE: Not thread safe, but called by sequential receiver! timer.timer = time.AfterFunc(time.Hour, func() {
*/ timer.isPending = false
func (peer *Peer) KeepKeyFreshReceiving() { expirationFunction(peer)
if peer.timer.sendLastMinuteHandshake.Get() { })
return timer.timer.Stop()
}
kp := peer.keyPairs.Current()
if kp == nil {
return
}
if !kp.isInitiator {
return
}
nonce := atomic.LoadUint64(&kp.sendNonce)
send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
if send {
// do a last minute attempt at initiating a new handshake
peer.timer.sendLastMinuteHandshake.Set(true)
peer.event.handshakeBegin.Fire()
}
}
/* Queues a keep-alive if no packets are queued for peer
*/
func (peer *Peer) SendKeepAlive() bool {
if len(peer.queue.nonce) != 0 {
return false
}
elem := peer.device.NewOutboundElement()
elem.packet = nil
select {
case peer.queue.nonce <- elem:
return true
default:
return false
}
}
/* Called after successfully completing a handshake.
* i.e. after:
*
* - Valid handshake response
* - First transport message under the "next" key
*/
// peer.device.log.Info.Println(peer, ": New handshake completed")
/* Event:
* An ephemeral key is generated
*
* i.e. after:
*
* CreateMessageInitiation
* CreateMessageResponse
*
* Action:
* Schedule the deletion of all key material
* upon failure to complete a handshake
*/
func (peer *Peer) TimerEphemeralKeyCreated() {
peer.event.ephemeralKeyCreated.Fire()
// peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
}
/* Sends a new handshake initiation message to the peer (endpoint)
*/
func (peer *Peer) sendNewHandshake() error {
// create initiation message
msg, err := peer.device.CreateMessageInitiation(peer)
if err != nil {
return err
}
// marshal handshake message
var buff [MessageInitiationSize]byte
writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes()
peer.mac.AddMacs(packet)
// send to endpoint
peer.event.anyAuthenticatedPacketTraversal.Fire()
return peer.SendBuffer(packet)
}
func newTimer() *time.Timer {
timer := time.NewTimer(time.Hour)
timer.Stop()
return timer return timer
} }
func (peer *Peer) RoutineTimerHandler() { func (timer *Timer) Mod(d time.Duration) {
timer.isPending = true
timer.timer.Reset(d)
}
device := peer.device func (timer *Timer) Del() {
timer.isPending = false
timer.timer.Stop()
}
logInfo := device.log.Info func (peer *Peer) timersActive() bool {
logDebug := device.log.Debug return peer.isRunning.Get() && peer.device != nil && peer.device.isUp.Get() && len(peer.device.peers.keyMap) > 0
}
defer func() { func expiredRetransmitHandshake(peer *Peer) {
logDebug.Println(peer, ": Routine: timer handler - stopped") if peer.timers.handshakeAttempts > MaxTimerHandshakes {
peer.routines.stopping.Done() peer.device.log.Debug.Printf("%s: Handshake did not complete after %d attempts, giving up\n", peer, MaxTimerHandshakes+2)
}()
logDebug.Println(peer, ": Routine: timer handler - started") if peer.timersActive() {
peer.timers.sendKeepalive.Del()
}
// reset all timers /* We drop all packets without a keypair and don't try again,
* if we try unsuccessfully for too long to make a handshake.
*/
peer.FlushNonceQueue()
enableHandshake := true /* We set a timer for destroying any residue that might be left
pendingHandshakeNew := false * of a partial exchange.
pendingKeepalivePassive := false */
needAnotherKeepalive := false if peer.timersActive() && !peer.timers.zeroKeyMaterial.isPending {
peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
}
} else {
peer.timers.handshakeAttempts++
peer.device.log.Debug.Printf("%s: Handshake did not complete after %d seconds, retrying (try %d)\n", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts+1)
timerKeepalivePassive := newTimer() /* We clear the endpoint address src address, in case this is the cause of trouble. */
timerHandshakeDeadline := newTimer() peer.mutex.Lock()
timerHandshakeTimeout := newTimer() if peer.endpoint != nil {
timerHandshakeNew := newTimer() peer.endpoint.ClearSrc()
timerZeroAllKeys := newTimer() }
timerKeepalivePersistent := newTimer() peer.mutex.Unlock()
interval := peer.persistentKeepaliveInterval peer.SendHandshakeInitiation(true)
if interval > 0 {
duration := time.Duration(interval) * time.Second
timerKeepalivePersistent.Reset(duration)
} }
}
// signal synchronised setup complete func expiredSendKeepalive(peer *Peer) {
peer.SendKeepalive()
peer.routines.starting.Done() if peer.timers.needAnotherKeepalive {
peer.timers.needAnotherKeepalive = false
// handle timer events if peer.timersActive() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
for {
select {
/* stopping */
case <-peer.routines.stop:
return
/* events */
case <-peer.event.dataSent.C:
timerKeepalivePassive.Stop()
if !pendingHandshakeNew {
timerHandshakeNew.Reset(NewHandshakeTime)
}
case <-peer.event.dataReceived.C:
if pendingKeepalivePassive {
needAnotherKeepalive = true
} else {
timerKeepalivePassive.Reset(KeepaliveTimeout)
}
case <-peer.event.anyAuthenticatedPacketTraversal.C:
interval := peer.persistentKeepaliveInterval
if interval > 0 {
duration := time.Duration(interval) * time.Second
timerKeepalivePersistent.Reset(duration)
}
case <-peer.event.handshakeBegin.C:
if !enableHandshake {
continue
}
logDebug.Println(peer, ": Event, Handshake Begin")
err := peer.sendNewHandshake()
// set timeout
jitter := time.Millisecond * time.Duration(rand.Int31n(334))
timerKeepalivePassive.Stop()
timerHandshakeTimeout.Reset(RekeyTimeout + jitter)
if err != nil {
logInfo.Println(peer, ": Failed to send handshake initiation", err)
} else {
logDebug.Println(peer, ": Send handshake initiation (initial)")
}
timerHandshakeDeadline.Reset(RekeyAttemptTime)
// disable further handshakes
peer.event.handshakeBegin.Clear()
enableHandshake = false
case <-peer.event.handshakeCompleted.C:
logInfo.Println(peer, ": Handshake completed")
atomic.StoreInt64(
&peer.stats.lastHandshakeNano,
time.Now().UnixNano(),
)
timerHandshakeTimeout.Stop()
timerHandshakeDeadline.Stop()
peer.timer.sendLastMinuteHandshake.Set(false)
// allow further handshakes
peer.event.handshakeBegin.Clear()
enableHandshake = true
/* timers */
case <-timerKeepalivePersistent.C:
interval := peer.persistentKeepaliveInterval
if interval > 0 {
logDebug.Println(peer, ": Send keep-alive (persistent)")
timerKeepalivePassive.Stop()
peer.SendKeepAlive()
}
case <-timerKeepalivePassive.C:
logDebug.Println(peer, ": Send keep-alive (passive)")
peer.SendKeepAlive()
if needAnotherKeepalive {
timerKeepalivePassive.Reset(KeepaliveTimeout)
needAnotherKeepalive = false
}
case <-timerZeroAllKeys.C:
logDebug.Println(peer, ": Clear all key-material (timer event)")
hs := &peer.handshake
hs.mutex.Lock()
kp := &peer.keyPairs
kp.mutex.Lock()
// remove key-pairs
if kp.previous != nil {
device.DeleteKeyPair(kp.previous)
kp.previous = nil
}
if kp.current != nil {
device.DeleteKeyPair(kp.current)
kp.current = nil
}
if kp.next != nil {
device.DeleteKeyPair(kp.next)
kp.next = nil
}
kp.mutex.Unlock()
// zero out handshake
device.indices.Delete(hs.localIndex)
hs.Clear()
hs.mutex.Unlock()
case <-timerHandshakeTimeout.C:
// allow new handshake to be send
enableHandshake = true
// clear source (in case this is causing problems)
peer.mutex.Lock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.mutex.Unlock()
// send new handshake
err := peer.sendNewHandshake()
// set timeout
jitter := time.Millisecond * time.Duration(rand.Int31n(334))
timerKeepalivePassive.Stop()
timerHandshakeTimeout.Reset(RekeyTimeout + jitter)
if err != nil {
logInfo.Println(peer, ": Failed to send handshake initiation", err)
} else {
logDebug.Println(peer, ": Send handshake initiation (subsequent)")
}
// disable further handshakes
peer.event.handshakeBegin.Clear()
enableHandshake = false
case <-timerHandshakeDeadline.C:
// clear all queued packets and stop keep-alive
logInfo.Println(peer, ": Handshake negotiation timed-out")
peer.flushNonceQueue()
peer.event.flushNonceQueue.Fire()
// renable further handshakes
peer.event.handshakeBegin.Clear()
enableHandshake = true
} }
} }
} }
func expiredNewHandshake(peer *Peer) {
peer.device.log.Debug.Printf("%s: Retrying handshake because we stopped hearing back after %d seconds\n", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
/* We clear the endpoint address src address, in case this is the cause of trouble. */
peer.mutex.Lock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.mutex.Unlock()
peer.SendHandshakeInitiation(false)
}
func expiredZeroKeyMaterial(peer *Peer) {
peer.device.log.Debug.Printf(":%s Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds()))
hs := &peer.handshake
hs.mutex.Lock()
kp := &peer.keyPairs
kp.mutex.Lock()
if kp.previous != nil {
peer.device.DeleteKeypair(kp.previous)
kp.previous = nil
}
if kp.current != nil {
peer.device.DeleteKeypair(kp.current)
kp.current = nil
}
if kp.next != nil {
peer.device.DeleteKeypair(kp.next)
kp.next = nil
}
kp.mutex.Unlock()
peer.device.indices.Delete(hs.localIndex)
hs.Clear()
hs.mutex.Unlock()
}
func expiredPersistentKeepalive(peer *Peer) {
if peer.persistentKeepaliveInterval > 0 {
if peer.timersActive() {
peer.timers.sendKeepalive.Del()
}
peer.SendKeepalive()
}
}
/* Should be called after an authenticated data packet is sent. */
func (peer *Peer) timersDataSent() {
if peer.timersActive() {
peer.timers.sendKeepalive.Del()
}
if peer.timersActive() && !peer.timers.newHandshake.isPending {
peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout)
}
}
/* Should be called after an authenticated data packet is received. */
func (peer *Peer) timersDataReceived() {
if peer.timersActive() {
if !peer.timers.sendKeepalive.isPending {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
} else {
peer.timers.needAnotherKeepalive = true
}
}
}
/* Should be called after any type of authenticated packet is received -- keepalive or data. */
func (peer *Peer) timersAnyAuthenticatedPacketReceived() {
if peer.timersActive() {
peer.timers.newHandshake.Del()
}
}
/* Should be called after a handshake initiation message is sent. */
func (peer *Peer) timersHandshakeInitiated() {
if peer.timersActive() {
peer.timers.sendKeepalive.Del()
peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs)))
}
}
/* Should be called after a handshake response message is received and processed or when getting key confirmation via the first data message. */
func (peer *Peer) timersHandshakeComplete() {
if peer.timersActive() {
peer.timers.retransmitHandshake.Del()
}
peer.timers.handshakeAttempts = 0
peer.timers.sentLastMinuteHandshake = false
atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano())
}
/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
func (peer *Peer) timersSessionDerived() {
if peer.timersActive() {
peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
}
}
/* Should be called before a packet with authentication -- data, keepalive, either handshake -- is sent, or after one is received. */
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
if peer.persistentKeepaliveInterval > 0 && peer.timersActive() {
peer.timers.persistentKeepalive.Mod(time.Duration(peer.persistentKeepaliveInterval) * time.Second)
}
}
func (peer *Peer) timersInit() {
peer.timers.retransmitHandshake = peer.NewTimer(expiredRetransmitHandshake)
peer.timers.sendKeepalive = peer.NewTimer(expiredSendKeepalive)
peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake)
peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial)
peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive)
peer.timers.handshakeAttempts = 0
peer.timers.sentLastMinuteHandshake = false
peer.timers.needAnotherKeepalive = false
peer.timers.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
}
func (peer *Peer) timersStop() {
peer.timers.retransmitHandshake.Del()
peer.timers.sendKeepalive.Del()
peer.timers.newHandshake.Del()
peer.timers.zeroKeyMaterial.Del()
peer.timers.persistentKeepalive.Del()
}

11
uapi.go
View file

@ -256,8 +256,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logDebug.Println("UAPI: Created new peer:", peer) logDebug.Println("UAPI: Created new peer:", peer)
} }
peer.event.handshakePushDeadline.Fire()
case "remove": case "remove":
// remove currently selected peer from device // remove currently selected peer from device
@ -288,8 +286,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
peer.event.handshakePushDeadline.Fire()
case "endpoint": case "endpoint":
// set endpoint destination // set endpoint destination
@ -304,7 +300,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return err return err
} }
peer.endpoint = endpoint peer.endpoint = endpoint
peer.event.handshakePushDeadline.Fire()
return nil return nil
}() }()
@ -315,7 +310,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "persistent_keepalive_interval": case "persistent_keepalive_interval":
// update keep-alive interval // update persistent keepalive interval
logDebug.Println("UAPI: Updating persistent_keepalive_interval for peer:", peer) logDebug.Println("UAPI: Updating persistent_keepalive_interval for peer:", peer)
@ -328,7 +323,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
old := peer.persistentKeepaliveInterval old := peer.persistentKeepaliveInterval
peer.persistentKeepaliveInterval = uint16(secs) peer.persistentKeepaliveInterval = uint16(secs)
// send immediate keep-alive // send immediate keepalive if we're turning it on and before it wasn't on
if old == 0 && secs != 0 { if old == 0 && secs != 0 {
if err != nil { if err != nil {
@ -336,7 +331,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorIO} return &IPCError{Code: ipcErrorIO}
} }
if device.isUp.Get() && !dummy { if device.isUp.Get() && !dummy {
peer.SendKeepAlive() peer.SendKeepalive()
} }
} }