device: use a waiting sync.Pool instead of a channel

Channels are FIFO which means we have guaranteed cache misses.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2021-02-02 18:37:49 +01:00
parent a9f80d8c58
commit 4846070322
4 changed files with 117 additions and 68 deletions

View file

@ -42,7 +42,6 @@ func TestPeerAlignment(t *testing.T) {
checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning)) checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning))
} }
// TestDeviceAlignment checks that atomically-accessed fields are // TestDeviceAlignment checks that atomically-accessed fields are
// aligned to 64-bit boundaries, as required by the atomic package. // aligned to 64-bit boundaries, as required by the atomic package.
// //

View file

@ -67,12 +67,9 @@ type Device struct {
} }
pool struct { pool struct {
messageBufferPool *sync.Pool messageBuffers *WaitPool
messageBufferReuseChan chan *[MaxMessageSize]byte inboundElements *WaitPool
inboundElementPool *sync.Pool outboundElements *WaitPool
inboundElementReuseChan chan *QueueInboundElement
outboundElementPool *sync.Pool
outboundElementReuseChan chan *QueueOutboundElement
} }
queue struct { queue struct {

View file

@ -5,87 +5,80 @@
package device package device
import "sync" import (
"sync"
"sync/atomic"
)
type WaitPool struct {
pool sync.Pool
cond sync.Cond
lock sync.Mutex
count uint32
max uint32
}
func NewWaitPool(max uint32, new func() interface{}) *WaitPool {
p := &WaitPool{pool: sync.Pool{New: new}, max: max}
p.cond = sync.Cond{L: &p.lock}
return p
}
func (p *WaitPool) Get() interface{} {
if p.max != 0 {
p.lock.Lock()
for atomic.LoadUint32(&p.count) >= p.max {
p.cond.Wait()
}
atomic.AddUint32(&p.count, 1)
p.lock.Unlock()
}
return p.pool.Get()
}
func (p *WaitPool) Put(x interface{}) {
p.pool.Put(x)
if p.max == 0 {
return
}
atomic.AddUint32(&p.count, ^uint32(0))
p.cond.Signal()
}
func (device *Device) PopulatePools() { func (device *Device) PopulatePools() {
if PreallocatedBuffersPerPool == 0 { device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} {
device.pool.messageBufferPool = &sync.Pool{
New: func() interface{} {
return new([MaxMessageSize]byte) return new([MaxMessageSize]byte)
}, })
} device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} {
device.pool.inboundElementPool = &sync.Pool{
New: func() interface{} {
return new(QueueInboundElement) return new(QueueInboundElement)
}, })
} device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} {
device.pool.outboundElementPool = &sync.Pool{
New: func() interface{} {
return new(QueueOutboundElement) return new(QueueOutboundElement)
}, })
}
} else {
device.pool.messageBufferReuseChan = make(chan *[MaxMessageSize]byte, PreallocatedBuffersPerPool)
for i := 0; i < PreallocatedBuffersPerPool; i++ {
device.pool.messageBufferReuseChan <- new([MaxMessageSize]byte)
}
device.pool.inboundElementReuseChan = make(chan *QueueInboundElement, PreallocatedBuffersPerPool)
for i := 0; i < PreallocatedBuffersPerPool; i++ {
device.pool.inboundElementReuseChan <- new(QueueInboundElement)
}
device.pool.outboundElementReuseChan = make(chan *QueueOutboundElement, PreallocatedBuffersPerPool)
for i := 0; i < PreallocatedBuffersPerPool; i++ {
device.pool.outboundElementReuseChan <- new(QueueOutboundElement)
}
}
} }
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
if PreallocatedBuffersPerPool == 0 { return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
return device.pool.messageBufferPool.Get().(*[MaxMessageSize]byte)
} else {
return <-device.pool.messageBufferReuseChan
}
} }
func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
if PreallocatedBuffersPerPool == 0 { device.pool.messageBuffers.Put(msg)
device.pool.messageBufferPool.Put(msg)
} else {
device.pool.messageBufferReuseChan <- msg
}
} }
func (device *Device) GetInboundElement() *QueueInboundElement { func (device *Device) GetInboundElement() *QueueInboundElement {
if PreallocatedBuffersPerPool == 0 { return device.pool.inboundElements.Get().(*QueueInboundElement)
return device.pool.inboundElementPool.Get().(*QueueInboundElement)
} else {
return <-device.pool.inboundElementReuseChan
}
} }
func (device *Device) PutInboundElement(elem *QueueInboundElement) { func (device *Device) PutInboundElement(elem *QueueInboundElement) {
elem.clearPointers() elem.clearPointers()
if PreallocatedBuffersPerPool == 0 { device.pool.inboundElements.Put(elem)
device.pool.inboundElementPool.Put(elem)
} else {
device.pool.inboundElementReuseChan <- elem
}
} }
func (device *Device) GetOutboundElement() *QueueOutboundElement { func (device *Device) GetOutboundElement() *QueueOutboundElement {
if PreallocatedBuffersPerPool == 0 { return device.pool.outboundElements.Get().(*QueueOutboundElement)
return device.pool.outboundElementPool.Get().(*QueueOutboundElement)
} else {
return <-device.pool.outboundElementReuseChan
}
} }
func (device *Device) PutOutboundElement(elem *QueueOutboundElement) { func (device *Device) PutOutboundElement(elem *QueueOutboundElement) {
elem.clearPointers() elem.clearPointers()
if PreallocatedBuffersPerPool == 0 { device.pool.outboundElements.Put(elem)
device.pool.outboundElementPool.Put(elem)
} else {
device.pool.outboundElementReuseChan <- elem
}
} }

60
device/pools_test.go Normal file
View file

@ -0,0 +1,60 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"math/rand"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestWaitPool(t *testing.T) {
var wg sync.WaitGroup
trials := int32(100000)
workers := runtime.NumCPU() + 2
if workers-4 <= 0 {
t.Skip("Not enough cores")
}
p := NewWaitPool(uint32(workers-4), func() interface{} { return make([]byte, 16) })
wg.Add(workers)
max := uint32(0)
updateMax := func() {
count := atomic.LoadUint32(&p.count)
if count > p.max {
t.Errorf("count (%d) > max (%d)", count, p.max)
}
for {
old := atomic.LoadUint32(&max)
if count <= old {
break
}
if atomic.CompareAndSwapUint32(&max, old, count) {
break
}
}
}
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
for atomic.AddInt32(&trials, -1) > 0 {
updateMax()
x := p.Get()
updateMax()
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
updateMax()
p.Put(x)
updateMax()
}
}()
}
wg.Wait()
if max != p.max {
t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
}
}