device: use new model queues for handshakes
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
parent
9263014ed3
commit
beb25cc4fd
|
@ -13,6 +13,7 @@ import (
|
||||||
|
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
"golang.zx2c4.com/wireguard/ratelimiter"
|
"golang.zx2c4.com/wireguard/ratelimiter"
|
||||||
"golang.zx2c4.com/wireguard/rwcancel"
|
"golang.zx2c4.com/wireguard/rwcancel"
|
||||||
|
@ -77,11 +78,7 @@ type Device struct {
|
||||||
queue struct {
|
queue struct {
|
||||||
encryption *outboundQueue
|
encryption *outboundQueue
|
||||||
decryption *inboundQueue
|
decryption *inboundQueue
|
||||||
handshake chan QueueHandshakeElement
|
handshake *handshakeQueue
|
||||||
}
|
|
||||||
|
|
||||||
signals struct {
|
|
||||||
stop chan struct{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tun struct {
|
tun struct {
|
||||||
|
@ -90,6 +87,7 @@ type Device struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
ipcMutex sync.RWMutex
|
ipcMutex sync.RWMutex
|
||||||
|
closed chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// An outboundQueue is a channel of QueueOutboundElements awaiting encryption.
|
// An outboundQueue is a channel of QueueOutboundElements awaiting encryption.
|
||||||
|
@ -135,6 +133,24 @@ func newInboundQueue() *inboundQueue {
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A handshakeQueue is similar to an outboundQueue; see those docs.
|
||||||
|
type handshakeQueue struct {
|
||||||
|
c chan QueueHandshakeElement
|
||||||
|
wg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHandshakeQueue() *handshakeQueue {
|
||||||
|
q := &handshakeQueue{
|
||||||
|
c: make(chan QueueHandshakeElement, QueueHandshakeSize),
|
||||||
|
}
|
||||||
|
q.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
q.wg.Wait()
|
||||||
|
close(q.c)
|
||||||
|
}()
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
/* Converts the peer into a "zombie", which remains in the peer map,
|
/* Converts the peer into a "zombie", which remains in the peer map,
|
||||||
* but processes no packets and does not exists in the routing table.
|
* but processes no packets and does not exists in the routing table.
|
||||||
*
|
*
|
||||||
|
@ -233,7 +249,7 @@ func (device *Device) IsUnderLoad() bool {
|
||||||
// check if currently under load
|
// check if currently under load
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
underLoad := len(device.queue.handshake) >= UnderLoadQueueSize
|
underLoad := len(device.queue.handshake.c) >= UnderLoadQueueSize
|
||||||
if underLoad {
|
if underLoad {
|
||||||
device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime))
|
device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime))
|
||||||
return true
|
return true
|
||||||
|
@ -302,6 +318,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
||||||
|
|
||||||
func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
|
func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
|
||||||
device := new(Device)
|
device := new(Device)
|
||||||
|
device.closed = make(chan struct{})
|
||||||
device.log = logger
|
device.log = logger
|
||||||
device.tun.device = tunDevice
|
device.tun.device = tunDevice
|
||||||
mtu, err := device.tun.device.MTU()
|
mtu, err := device.tun.device.MTU()
|
||||||
|
@ -322,14 +339,10 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
|
||||||
|
|
||||||
// create queues
|
// create queues
|
||||||
|
|
||||||
device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
|
device.queue.handshake = newHandshakeQueue()
|
||||||
device.queue.encryption = newOutboundQueue()
|
device.queue.encryption = newOutboundQueue()
|
||||||
device.queue.decryption = newInboundQueue()
|
device.queue.decryption = newInboundQueue()
|
||||||
|
|
||||||
// prepare signals
|
|
||||||
|
|
||||||
device.signals.stop = make(chan struct{})
|
|
||||||
|
|
||||||
// prepare net
|
// prepare net
|
||||||
|
|
||||||
device.net.port = 0
|
device.net.port = 0
|
||||||
|
@ -382,18 +395,6 @@ func (device *Device) RemoveAllPeers() {
|
||||||
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) FlushPacketQueues() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case elem := <-device.queue.handshake:
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) Close() {
|
func (device *Device) Close() {
|
||||||
if device.isClosed.Swap(true) {
|
if device.isClosed.Swap(true) {
|
||||||
return
|
return
|
||||||
|
@ -414,21 +415,20 @@ func (device *Device) Close() {
|
||||||
// No new peers are coming; we are done with these queues.
|
// No new peers are coming; we are done with these queues.
|
||||||
device.queue.encryption.wg.Done()
|
device.queue.encryption.wg.Done()
|
||||||
device.queue.decryption.wg.Done()
|
device.queue.decryption.wg.Done()
|
||||||
close(device.signals.stop)
|
device.queue.handshake.wg.Done()
|
||||||
device.state.stopping.Wait()
|
device.state.stopping.Wait()
|
||||||
|
|
||||||
device.RemoveAllPeers()
|
device.RemoveAllPeers()
|
||||||
|
|
||||||
device.FlushPacketQueues()
|
|
||||||
|
|
||||||
device.rate.limiter.Close()
|
device.rate.limiter.Close()
|
||||||
|
|
||||||
device.state.changing.Set(false)
|
device.state.changing.Set(false)
|
||||||
device.log.Verbosef("Interface closed")
|
device.log.Verbosef("Interface closed")
|
||||||
|
close(device.closed)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) Wait() chan struct{} {
|
func (device *Device) Wait() chan struct{} {
|
||||||
return device.signals.stop
|
return device.closed
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
|
func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
|
||||||
|
@ -561,6 +561,7 @@ func (device *Device) BindUpdate() error {
|
||||||
|
|
||||||
device.net.stopping.Add(2)
|
device.net.stopping.Add(2)
|
||||||
device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
|
device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
|
||||||
|
device.queue.handshake.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
|
||||||
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
|
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
|
||||||
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
|
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
|
||||||
|
|
||||||
|
|
|
@ -48,15 +48,6 @@ func (elem *QueueInboundElement) clearPointers() {
|
||||||
elem.endpoint = nil
|
elem.endpoint = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, elem QueueHandshakeElement) bool {
|
|
||||||
select {
|
|
||||||
case queue <- elem:
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Called when a new authenticated message has been received
|
/* Called when a new authenticated message has been received
|
||||||
*
|
*
|
||||||
* NOTE: Not thread safe, but called by sequential receiver!
|
* NOTE: Not thread safe, but called by sequential receiver!
|
||||||
|
@ -81,6 +72,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
|
||||||
defer func() {
|
defer func() {
|
||||||
device.log.Verbosef("Routine: receive incoming IPv%d - stopped", IP)
|
device.log.Verbosef("Routine: receive incoming IPv%d - stopped", IP)
|
||||||
device.queue.decryption.wg.Done()
|
device.queue.decryption.wg.Done()
|
||||||
|
device.queue.handshake.wg.Done()
|
||||||
device.net.stopping.Done()
|
device.net.stopping.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -202,16 +194,15 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if okay {
|
if okay {
|
||||||
if (device.addToHandshakeQueue(
|
select {
|
||||||
device.queue.handshake,
|
case device.queue.handshake.c <- QueueHandshakeElement{
|
||||||
QueueHandshakeElement{
|
|
||||||
msgType: msgType,
|
msgType: msgType,
|
||||||
buffer: buffer,
|
buffer: buffer,
|
||||||
packet: packet,
|
packet: packet,
|
||||||
endpoint: endpoint,
|
endpoint: endpoint,
|
||||||
},
|
}:
|
||||||
)) {
|
|
||||||
buffer = device.GetMessageBuffer()
|
buffer = device.GetMessageBuffer()
|
||||||
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -251,34 +242,13 @@ func (device *Device) RoutineDecryption() {
|
||||||
/* Handles incoming packets related to handshake
|
/* Handles incoming packets related to handshake
|
||||||
*/
|
*/
|
||||||
func (device *Device) RoutineHandshake() {
|
func (device *Device) RoutineHandshake() {
|
||||||
var elem QueueHandshakeElement
|
|
||||||
var ok bool
|
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
device.log.Verbosef("Routine: handshake worker - stopped")
|
device.log.Verbosef("Routine: handshake worker - stopped")
|
||||||
device.state.stopping.Done()
|
device.state.stopping.Done()
|
||||||
if elem.buffer != nil {
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
device.log.Verbosef("Routine: handshake worker - started")
|
device.log.Verbosef("Routine: handshake worker - started")
|
||||||
|
|
||||||
for {
|
for elem := range device.queue.handshake.c {
|
||||||
if elem.buffer != nil {
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
elem.buffer = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case elem, ok = <-device.queue.handshake:
|
|
||||||
case <-device.signals.stop:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// handle cookie fields and ratelimiting
|
// handle cookie fields and ratelimiting
|
||||||
|
|
||||||
|
@ -293,7 +263,7 @@ func (device *Device) RoutineHandshake() {
|
||||||
err := binary.Read(reader, binary.LittleEndian, &reply)
|
err := binary.Read(reader, binary.LittleEndian, &reply)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
device.log.Verbosef("Failed to decode cookie reply")
|
device.log.Verbosef("Failed to decode cookie reply")
|
||||||
return
|
goto skip
|
||||||
}
|
}
|
||||||
|
|
||||||
// lookup peer from index
|
// lookup peer from index
|
||||||
|
@ -301,7 +271,7 @@ func (device *Device) RoutineHandshake() {
|
||||||
entry := device.indexTable.Lookup(reply.Receiver)
|
entry := device.indexTable.Lookup(reply.Receiver)
|
||||||
|
|
||||||
if entry.peer == nil {
|
if entry.peer == nil {
|
||||||
continue
|
goto skip
|
||||||
}
|
}
|
||||||
|
|
||||||
// consume reply
|
// consume reply
|
||||||
|
@ -313,7 +283,7 @@ func (device *Device) RoutineHandshake() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
continue
|
goto skip
|
||||||
|
|
||||||
case MessageInitiationType, MessageResponseType:
|
case MessageInitiationType, MessageResponseType:
|
||||||
|
|
||||||
|
@ -321,7 +291,7 @@ func (device *Device) RoutineHandshake() {
|
||||||
|
|
||||||
if !device.cookieChecker.CheckMAC1(elem.packet) {
|
if !device.cookieChecker.CheckMAC1(elem.packet) {
|
||||||
device.log.Verbosef("Received packet with invalid mac1")
|
device.log.Verbosef("Received packet with invalid mac1")
|
||||||
continue
|
goto skip
|
||||||
}
|
}
|
||||||
|
|
||||||
// endpoints destination address is the source of the datagram
|
// endpoints destination address is the source of the datagram
|
||||||
|
@ -332,19 +302,19 @@ func (device *Device) RoutineHandshake() {
|
||||||
|
|
||||||
if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
|
if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
|
||||||
device.SendHandshakeCookie(&elem)
|
device.SendHandshakeCookie(&elem)
|
||||||
continue
|
goto skip
|
||||||
}
|
}
|
||||||
|
|
||||||
// check ratelimiter
|
// check ratelimiter
|
||||||
|
|
||||||
if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
|
if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
|
||||||
continue
|
goto skip
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
device.log.Errorf("Invalid packet ended up in the handshake queue")
|
device.log.Errorf("Invalid packet ended up in the handshake queue")
|
||||||
continue
|
goto skip
|
||||||
}
|
}
|
||||||
|
|
||||||
// handle handshake initiation/response content
|
// handle handshake initiation/response content
|
||||||
|
@ -359,7 +329,7 @@ func (device *Device) RoutineHandshake() {
|
||||||
err := binary.Read(reader, binary.LittleEndian, &msg)
|
err := binary.Read(reader, binary.LittleEndian, &msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
device.log.Errorf("Failed to decode initiation message")
|
device.log.Errorf("Failed to decode initiation message")
|
||||||
continue
|
goto skip
|
||||||
}
|
}
|
||||||
|
|
||||||
// consume initiation
|
// consume initiation
|
||||||
|
@ -367,7 +337,7 @@ func (device *Device) RoutineHandshake() {
|
||||||
peer := device.ConsumeMessageInitiation(&msg)
|
peer := device.ConsumeMessageInitiation(&msg)
|
||||||
if peer == nil {
|
if peer == nil {
|
||||||
device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
|
device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
|
||||||
continue
|
goto skip
|
||||||
}
|
}
|
||||||
|
|
||||||
// update timers
|
// update timers
|
||||||
|
@ -392,7 +362,7 @@ func (device *Device) RoutineHandshake() {
|
||||||
err := binary.Read(reader, binary.LittleEndian, &msg)
|
err := binary.Read(reader, binary.LittleEndian, &msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
device.log.Errorf("Failed to decode response message")
|
device.log.Errorf("Failed to decode response message")
|
||||||
continue
|
goto skip
|
||||||
}
|
}
|
||||||
|
|
||||||
// consume response
|
// consume response
|
||||||
|
@ -400,7 +370,7 @@ func (device *Device) RoutineHandshake() {
|
||||||
peer := device.ConsumeMessageResponse(&msg)
|
peer := device.ConsumeMessageResponse(&msg)
|
||||||
if peer == nil {
|
if peer == nil {
|
||||||
device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
|
device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
|
||||||
continue
|
goto skip
|
||||||
}
|
}
|
||||||
|
|
||||||
// update endpoint
|
// update endpoint
|
||||||
|
@ -420,13 +390,15 @@ func (device *Device) RoutineHandshake() {
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
|
device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
|
||||||
continue
|
goto skip
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.timersSessionDerived()
|
peer.timersSessionDerived()
|
||||||
peer.timersHandshakeComplete()
|
peer.timersHandshakeComplete()
|
||||||
peer.SendKeepalive()
|
peer.SendKeepalive()
|
||||||
}
|
}
|
||||||
|
skip:
|
||||||
|
device.PutMessageBuffer(elem.buffer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue