More pooling

This commit is contained in:
Jason A. Donenfeld 2018-09-22 06:29:02 +02:00
parent cf81a28dd3
commit 833597b585
4 changed files with 148 additions and 56 deletions

View file

@ -19,8 +19,6 @@ const (
DeviceRoutineNumberAdditional = 2 DeviceRoutineNumberAdditional = 2
) )
var preallocatedBuffers = 0
type Device struct { type Device struct {
isUp AtomicBool // device is (going) up isUp AtomicBool // device is (going) up
isClosed AtomicBool // device is closed? (acting as guard) isClosed AtomicBool // device is closed? (acting as guard)
@ -68,8 +66,12 @@ type Device struct {
} }
pool struct { pool struct {
messageBuffers *sync.Pool messageBufferPool *sync.Pool
reuseChan chan interface{} messageBufferReuseChan chan *[MaxMessageSize]byte
inboundElementPool *sync.Pool
inboundElementReuseChan chan *QueueInboundElement
outboundElementPool *sync.Pool
outboundElementReuseChan chan *QueueOutboundElement
} }
queue struct { queue struct {
@ -245,22 +247,6 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
return nil return nil
} }
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
if preallocatedBuffers == 0 {
return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
} else {
return (<-device.pool.reuseChan).(*[MaxMessageSize]byte)
}
}
func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
if preallocatedBuffers == 0 {
device.pool.messageBuffers.Put(msg)
} else {
device.pool.reuseChan <- msg
}
}
func NewDevice(tunDevice tun.TUNDevice, logger *Logger) *Device { func NewDevice(tunDevice tun.TUNDevice, logger *Logger) *Device {
device := new(Device) device := new(Device)
@ -285,18 +271,7 @@ func NewDevice(tunDevice tun.TUNDevice, logger *Logger) *Device {
device.indexTable.Init() device.indexTable.Init()
device.allowedips.Reset() device.allowedips.Reset()
if preallocatedBuffers == 0 { device.PopulatePools()
device.pool.messageBuffers = &sync.Pool{
New: func() interface{} {
return new([MaxMessageSize]byte)
},
}
} else {
device.pool.reuseChan = make(chan interface{}, preallocatedBuffers)
for i := 0; i < preallocatedBuffers; i += 1 {
device.pool.reuseChan <- new([MaxMessageSize]byte)
}
}
// create queues // create queues

91
pools.go Normal file
View file

@ -0,0 +1,91 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/
package main
import "sync"
var preallocatedBuffers = 0
func (device *Device) PopulatePools() {
if preallocatedBuffers == 0 {
device.pool.messageBufferPool = &sync.Pool{
New: func() interface{} {
return new([MaxMessageSize]byte)
},
}
device.pool.inboundElementPool = &sync.Pool{
New: func() interface{} {
return new(QueueInboundElement)
},
}
device.pool.outboundElementPool = &sync.Pool{
New: func() interface{} {
return new(QueueOutboundElement)
},
}
} else {
device.pool.messageBufferReuseChan = make(chan *[MaxMessageSize]byte, preallocatedBuffers)
for i := 0; i < preallocatedBuffers; i += 1 {
device.pool.messageBufferReuseChan <- new([MaxMessageSize]byte)
}
device.pool.inboundElementReuseChan = make(chan *QueueInboundElement, preallocatedBuffers)
for i := 0; i < preallocatedBuffers; i += 1 {
device.pool.inboundElementReuseChan <- new(QueueInboundElement)
}
device.pool.outboundElementReuseChan = make(chan *QueueOutboundElement, preallocatedBuffers)
for i := 0; i < preallocatedBuffers; i += 1 {
device.pool.outboundElementReuseChan <- new(QueueOutboundElement)
}
}
}
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
if preallocatedBuffers == 0 {
return device.pool.messageBufferPool.Get().(*[MaxMessageSize]byte)
} else {
return <-device.pool.messageBufferReuseChan
}
}
func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
if preallocatedBuffers == 0 {
device.pool.messageBufferPool.Put(msg)
} else {
device.pool.messageBufferReuseChan <- msg
}
}
func (device *Device) GetInboundElement() *QueueInboundElement {
if preallocatedBuffers == 0 {
return device.pool.inboundElementPool.Get().(*QueueInboundElement)
} else {
return <-device.pool.inboundElementReuseChan
}
}
func (device *Device) PutInboundElement(msg *QueueInboundElement) {
if preallocatedBuffers == 0 {
device.pool.inboundElementPool.Put(msg)
} else {
device.pool.inboundElementReuseChan <- msg
}
}
func (device *Device) GetOutboundElement() *QueueOutboundElement {
if preallocatedBuffers == 0 {
return device.pool.outboundElementPool.Get().(*QueueOutboundElement)
} else {
return <-device.pool.outboundElementReuseChan
}
}
func (device *Device) PutOutboundElement(msg *QueueOutboundElement) {
if preallocatedBuffers == 0 {
device.pool.outboundElementPool.Put(msg)
} else {
device.pool.outboundElementReuseChan <- msg
}
}

View file

@ -55,6 +55,7 @@ func (device *Device) addToInboundAndDecryptionQueues(inboundQueue chan *QueueIn
return false return false
} }
default: default:
device.PutInboundElement(element)
return false return false
} }
} }
@ -168,15 +169,15 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
} }
// create work element // create work element
peer := value.peer peer := value.peer
elem := &QueueInboundElement{ elem := device.GetInboundElement()
packet: packet, elem.packet = packet
buffer: buffer, elem.buffer = buffer
keypair: keypair, elem.keypair = keypair
dropped: AtomicFalse, elem.dropped = AtomicFalse
endpoint: endpoint, elem.endpoint = endpoint
} elem.counter = 0
elem.mutex = sync.Mutex{}
elem.mutex.Lock() elem.mutex.Lock()
// add to decryption queues // add to decryption queues
@ -246,6 +247,7 @@ func (device *Device) RoutineDecryption() {
// check if dropped // check if dropped
if elem.IsDropped() { if elem.IsDropped() {
device.PutInboundElement(elem)
continue continue
} }
@ -280,7 +282,6 @@ func (device *Device) RoutineDecryption() {
elem.Drop() elem.Drop()
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
elem.buffer = nil elem.buffer = nil
elem.mutex.Unlock()
} }
elem.mutex.Unlock() elem.mutex.Unlock()
} }
@ -487,12 +488,16 @@ func (peer *Peer) RoutineSequentialReceiver() {
logDebug := device.log.Debug logDebug := device.log.Debug
var elem *QueueInboundElement var elem *QueueInboundElement
var ok bool
defer func() { defer func() {
logDebug.Println(peer, "- Routine: sequential receiver - stopped") logDebug.Println(peer, "- Routine: sequential receiver - stopped")
peer.routines.stopping.Done() peer.routines.stopping.Done()
if elem != nil && elem.buffer != nil { if elem != nil {
device.PutMessageBuffer(elem.buffer) if elem.buffer != nil {
device.PutMessageBuffer(elem.buffer)
}
device.PutInboundElement(elem)
} }
}() }()
@ -501,8 +506,11 @@ func (peer *Peer) RoutineSequentialReceiver() {
peer.routines.starting.Done() peer.routines.starting.Done()
for { for {
if elem != nil && elem.buffer != nil { if elem != nil {
device.PutMessageBuffer(elem.buffer) if elem.buffer != nil {
device.PutMessageBuffer(elem.buffer)
}
device.PutInboundElement(elem)
} }
select { select {
@ -510,7 +518,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
case <-peer.routines.stop: case <-peer.routines.stop:
return return
case elem, ok := <-peer.queue.inbound: case elem, ok = <-peer.queue.inbound:
if !ok { if !ok {
return return
@ -621,9 +629,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
offset := MessageTransportOffsetContent offset := MessageTransportOffsetContent
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
_, err := device.tun.device.Write( _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset)
elem.buffer[:offset+len(elem.packet)],
offset)
if err != nil { if err != nil {
logError.Println("Failed to write packet to TUN device:", err) logError.Println("Failed to write packet to TUN device:", err)
} }

34
send.go
View file

@ -52,10 +52,14 @@ type QueueOutboundElement struct {
} }
func (device *Device) NewOutboundElement() *QueueOutboundElement { func (device *Device) NewOutboundElement() *QueueOutboundElement {
return &QueueOutboundElement{ elem := device.GetOutboundElement()
dropped: AtomicFalse, elem.dropped = AtomicFalse
buffer: device.GetMessageBuffer(), elem.buffer = device.GetMessageBuffer()
} elem.mutex = sync.Mutex{}
elem.nonce = 0
elem.keypair = nil
elem.peer = nil
return elem
} }
func (elem *QueueOutboundElement) Drop() { func (elem *QueueOutboundElement) Drop() {
@ -75,6 +79,7 @@ func addToNonceQueue(queue chan *QueueOutboundElement, element *QueueOutboundEle
select { select {
case old := <-queue: case old := <-queue:
device.PutMessageBuffer(old.buffer) device.PutMessageBuffer(old.buffer)
device.PutOutboundElement(old)
default: default:
} }
} }
@ -94,6 +99,7 @@ func addToOutboundAndEncryptionQueues(outboundQueue chan *QueueOutboundElement,
} }
default: default:
element.peer.device.PutMessageBuffer(element.buffer) element.peer.device.PutMessageBuffer(element.buffer)
element.peer.device.PutOutboundElement(element)
} }
} }
@ -111,6 +117,7 @@ func (peer *Peer) SendKeepalive() bool {
return true return true
default: default:
peer.device.PutMessageBuffer(elem.buffer) peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
return false return false
} }
} }
@ -236,8 +243,6 @@ func (peer *Peer) keepKeyFreshSending() {
*/ */
func (device *Device) RoutineReadFromTUN() { func (device *Device) RoutineReadFromTUN() {
elem := device.NewOutboundElement()
logDebug := device.log.Debug logDebug := device.log.Debug
logError := device.log.Error logError := device.log.Error
@ -249,7 +254,14 @@ func (device *Device) RoutineReadFromTUN() {
logDebug.Println("Routine: TUN reader - started") logDebug.Println("Routine: TUN reader - started")
device.state.starting.Done() device.state.starting.Done()
var elem *QueueOutboundElement
for { for {
if elem != nil {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
elem = device.NewOutboundElement()
// read packet // read packet
@ -262,6 +274,7 @@ func (device *Device) RoutineReadFromTUN() {
device.Close() device.Close()
} }
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
return return
} }
@ -304,7 +317,7 @@ func (device *Device) RoutineReadFromTUN() {
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
} }
addToNonceQueue(peer.queue.nonce, elem, device) addToNonceQueue(peer.queue.nonce, elem, device)
elem = device.NewOutboundElement() elem = nil
} }
} }
} }
@ -339,6 +352,7 @@ func (peer *Peer) RoutineNonce() {
select { select {
case elem := <-peer.queue.nonce: case elem := <-peer.queue.nonce:
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
default: default:
return return
} }
@ -399,11 +413,13 @@ func (peer *Peer) RoutineNonce() {
case <-peer.signals.flushNonceQueue: case <-peer.signals.flushNonceQueue:
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
flush() flush()
goto NextPacket goto NextPacket
case <-peer.routines.stop: case <-peer.routines.stop:
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
return return
} }
} }
@ -419,6 +435,7 @@ func (peer *Peer) RoutineNonce() {
if elem.nonce >= RejectAfterMessages { if elem.nonce >= RejectAfterMessages {
atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages) atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages)
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
goto NextPacket goto NextPacket
} }
@ -468,6 +485,7 @@ func (device *Device) RoutineEncryption() {
// check if dropped // check if dropped
if elem.IsDropped() { if elem.IsDropped() {
device.PutOutboundElement(elem)
continue continue
} }
@ -544,6 +562,7 @@ func (peer *Peer) RoutineSequentialSender() {
elem.mutex.Lock() elem.mutex.Lock()
if elem.IsDropped() { if elem.IsDropped() {
device.PutOutboundElement(elem)
continue continue
} }
@ -555,6 +574,7 @@ func (peer *Peer) RoutineSequentialSender() {
length := uint64(len(elem.packet)) length := uint64(len(elem.packet))
err := peer.SendBuffer(elem.packet) err := peer.SendBuffer(elem.packet)
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
if err != nil { if err != nil {
logError.Println(peer, "- Failed to send data packet", err) logError.Println(peer, "- Failed to send data packet", err)
continue continue