device: use new model queues for handshakes

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2021-01-29 18:24:45 +01:00
parent 9263014ed3
commit beb25cc4fd
2 changed files with 52 additions and 79 deletions

View file

@ -13,6 +13,7 @@ import (
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ratelimiter" "golang.zx2c4.com/wireguard/ratelimiter"
"golang.zx2c4.com/wireguard/rwcancel" "golang.zx2c4.com/wireguard/rwcancel"
@ -77,11 +78,7 @@ type Device struct {
queue struct { queue struct {
encryption *outboundQueue encryption *outboundQueue
decryption *inboundQueue decryption *inboundQueue
handshake chan QueueHandshakeElement handshake *handshakeQueue
}
signals struct {
stop chan struct{}
} }
tun struct { tun struct {
@ -90,6 +87,7 @@ type Device struct {
} }
ipcMutex sync.RWMutex ipcMutex sync.RWMutex
closed chan struct{}
} }
// An outboundQueue is a channel of QueueOutboundElements awaiting encryption. // An outboundQueue is a channel of QueueOutboundElements awaiting encryption.
@ -135,6 +133,24 @@ func newInboundQueue() *inboundQueue {
return q return q
} }
// A handshakeQueue is similar to an outboundQueue; see those docs.
type handshakeQueue struct {
c chan QueueHandshakeElement
wg sync.WaitGroup
}
func newHandshakeQueue() *handshakeQueue {
q := &handshakeQueue{
c: make(chan QueueHandshakeElement, QueueHandshakeSize),
}
q.wg.Add(1)
go func() {
q.wg.Wait()
close(q.c)
}()
return q
}
/* Converts the peer into a "zombie", which remains in the peer map, /* Converts the peer into a "zombie", which remains in the peer map,
* but processes no packets and does not exists in the routing table. * but processes no packets and does not exists in the routing table.
* *
@ -233,7 +249,7 @@ func (device *Device) IsUnderLoad() bool {
// check if currently under load // check if currently under load
now := time.Now() now := time.Now()
underLoad := len(device.queue.handshake) >= UnderLoadQueueSize underLoad := len(device.queue.handshake.c) >= UnderLoadQueueSize
if underLoad { if underLoad {
device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime)) device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime))
return true return true
@ -302,6 +318,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
func NewDevice(tunDevice tun.Device, logger *Logger) *Device { func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
device := new(Device) device := new(Device)
device.closed = make(chan struct{})
device.log = logger device.log = logger
device.tun.device = tunDevice device.tun.device = tunDevice
mtu, err := device.tun.device.MTU() mtu, err := device.tun.device.MTU()
@ -322,14 +339,10 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
// create queues // create queues
device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) device.queue.handshake = newHandshakeQueue()
device.queue.encryption = newOutboundQueue() device.queue.encryption = newOutboundQueue()
device.queue.decryption = newInboundQueue() device.queue.decryption = newInboundQueue()
// prepare signals
device.signals.stop = make(chan struct{})
// prepare net // prepare net
device.net.port = 0 device.net.port = 0
@ -382,18 +395,6 @@ func (device *Device) RemoveAllPeers() {
device.peers.keyMap = make(map[NoisePublicKey]*Peer) device.peers.keyMap = make(map[NoisePublicKey]*Peer)
} }
func (device *Device) FlushPacketQueues() {
for {
select {
case elem := <-device.queue.handshake:
device.PutMessageBuffer(elem.buffer)
default:
return
}
}
}
func (device *Device) Close() { func (device *Device) Close() {
if device.isClosed.Swap(true) { if device.isClosed.Swap(true) {
return return
@ -414,21 +415,20 @@ func (device *Device) Close() {
// No new peers are coming; we are done with these queues. // No new peers are coming; we are done with these queues.
device.queue.encryption.wg.Done() device.queue.encryption.wg.Done()
device.queue.decryption.wg.Done() device.queue.decryption.wg.Done()
close(device.signals.stop) device.queue.handshake.wg.Done()
device.state.stopping.Wait() device.state.stopping.Wait()
device.RemoveAllPeers() device.RemoveAllPeers()
device.FlushPacketQueues()
device.rate.limiter.Close() device.rate.limiter.Close()
device.state.changing.Set(false) device.state.changing.Set(false)
device.log.Verbosef("Interface closed") device.log.Verbosef("Interface closed")
close(device.closed)
} }
func (device *Device) Wait() chan struct{} { func (device *Device) Wait() chan struct{} {
return device.signals.stop return device.closed
} }
func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
@ -561,6 +561,7 @@ func (device *Device) BindUpdate() error {
device.net.stopping.Add(2) device.net.stopping.Add(2)
device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
device.queue.handshake.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)

View file

@ -48,15 +48,6 @@ func (elem *QueueInboundElement) clearPointers() {
elem.endpoint = nil elem.endpoint = nil
} }
func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, elem QueueHandshakeElement) bool {
select {
case queue <- elem:
return true
default:
return false
}
}
/* Called when a new authenticated message has been received /* Called when a new authenticated message has been received
* *
* NOTE: Not thread safe, but called by sequential receiver! * NOTE: Not thread safe, but called by sequential receiver!
@ -81,6 +72,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
defer func() { defer func() {
device.log.Verbosef("Routine: receive incoming IPv%d - stopped", IP) device.log.Verbosef("Routine: receive incoming IPv%d - stopped", IP)
device.queue.decryption.wg.Done() device.queue.decryption.wg.Done()
device.queue.handshake.wg.Done()
device.net.stopping.Done() device.net.stopping.Done()
}() }()
@ -202,16 +194,15 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
} }
if okay { if okay {
if (device.addToHandshakeQueue( select {
device.queue.handshake, case device.queue.handshake.c <- QueueHandshakeElement{
QueueHandshakeElement{
msgType: msgType, msgType: msgType,
buffer: buffer, buffer: buffer,
packet: packet, packet: packet,
endpoint: endpoint, endpoint: endpoint,
}, }:
)) {
buffer = device.GetMessageBuffer() buffer = device.GetMessageBuffer()
default:
} }
} }
} }
@ -251,34 +242,13 @@ func (device *Device) RoutineDecryption() {
/* Handles incoming packets related to handshake /* Handles incoming packets related to handshake
*/ */
func (device *Device) RoutineHandshake() { func (device *Device) RoutineHandshake() {
var elem QueueHandshakeElement
var ok bool
defer func() { defer func() {
device.log.Verbosef("Routine: handshake worker - stopped") device.log.Verbosef("Routine: handshake worker - stopped")
device.state.stopping.Done() device.state.stopping.Done()
if elem.buffer != nil {
device.PutMessageBuffer(elem.buffer)
}
}() }()
device.log.Verbosef("Routine: handshake worker - started") device.log.Verbosef("Routine: handshake worker - started")
for { for elem := range device.queue.handshake.c {
if elem.buffer != nil {
device.PutMessageBuffer(elem.buffer)
elem.buffer = nil
}
select {
case elem, ok = <-device.queue.handshake:
case <-device.signals.stop:
return
}
if !ok {
return
}
// handle cookie fields and ratelimiting // handle cookie fields and ratelimiting
@ -293,7 +263,7 @@ func (device *Device) RoutineHandshake() {
err := binary.Read(reader, binary.LittleEndian, &reply) err := binary.Read(reader, binary.LittleEndian, &reply)
if err != nil { if err != nil {
device.log.Verbosef("Failed to decode cookie reply") device.log.Verbosef("Failed to decode cookie reply")
return goto skip
} }
// lookup peer from index // lookup peer from index
@ -301,7 +271,7 @@ func (device *Device) RoutineHandshake() {
entry := device.indexTable.Lookup(reply.Receiver) entry := device.indexTable.Lookup(reply.Receiver)
if entry.peer == nil { if entry.peer == nil {
continue goto skip
} }
// consume reply // consume reply
@ -313,7 +283,7 @@ func (device *Device) RoutineHandshake() {
} }
} }
continue goto skip
case MessageInitiationType, MessageResponseType: case MessageInitiationType, MessageResponseType:
@ -321,7 +291,7 @@ func (device *Device) RoutineHandshake() {
if !device.cookieChecker.CheckMAC1(elem.packet) { if !device.cookieChecker.CheckMAC1(elem.packet) {
device.log.Verbosef("Received packet with invalid mac1") device.log.Verbosef("Received packet with invalid mac1")
continue goto skip
} }
// endpoints destination address is the source of the datagram // endpoints destination address is the source of the datagram
@ -332,19 +302,19 @@ func (device *Device) RoutineHandshake() {
if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) { if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
device.SendHandshakeCookie(&elem) device.SendHandshakeCookie(&elem)
continue goto skip
} }
// check ratelimiter // check ratelimiter
if !device.rate.limiter.Allow(elem.endpoint.DstIP()) { if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
continue goto skip
} }
} }
default: default:
device.log.Errorf("Invalid packet ended up in the handshake queue") device.log.Errorf("Invalid packet ended up in the handshake queue")
continue goto skip
} }
// handle handshake initiation/response content // handle handshake initiation/response content
@ -359,7 +329,7 @@ func (device *Device) RoutineHandshake() {
err := binary.Read(reader, binary.LittleEndian, &msg) err := binary.Read(reader, binary.LittleEndian, &msg)
if err != nil { if err != nil {
device.log.Errorf("Failed to decode initiation message") device.log.Errorf("Failed to decode initiation message")
continue goto skip
} }
// consume initiation // consume initiation
@ -367,7 +337,7 @@ func (device *Device) RoutineHandshake() {
peer := device.ConsumeMessageInitiation(&msg) peer := device.ConsumeMessageInitiation(&msg)
if peer == nil { if peer == nil {
device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString()) device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
continue goto skip
} }
// update timers // update timers
@ -392,7 +362,7 @@ func (device *Device) RoutineHandshake() {
err := binary.Read(reader, binary.LittleEndian, &msg) err := binary.Read(reader, binary.LittleEndian, &msg)
if err != nil { if err != nil {
device.log.Errorf("Failed to decode response message") device.log.Errorf("Failed to decode response message")
continue goto skip
} }
// consume response // consume response
@ -400,7 +370,7 @@ func (device *Device) RoutineHandshake() {
peer := device.ConsumeMessageResponse(&msg) peer := device.ConsumeMessageResponse(&msg)
if peer == nil { if peer == nil {
device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString()) device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
continue goto skip
} }
// update endpoint // update endpoint
@ -420,13 +390,15 @@ func (device *Device) RoutineHandshake() {
if err != nil { if err != nil {
device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
continue goto skip
} }
peer.timersSessionDerived() peer.timersSessionDerived()
peer.timersHandshakeComplete() peer.timersHandshakeComplete()
peer.SendKeepalive() peer.SendKeepalive()
} }
skip:
device.PutMessageBuffer(elem.buffer)
} }
} }