c9e4a859ae
In each case, the starting waitgroup did nothing but ensure that the goroutine has launched. Nothing downstream depends on the order in which goroutines launch, and if the Go runtime scheduler is so broken that goroutines don't get launched reasonably promptly, we have much deeper problems. Given all that, simplify the code. Passed a race-enabled stress test 25,000 times without failure. Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
537 lines
10 KiB
Go
537 lines
10 KiB
Go
/* SPDX-License-Identifier: MIT
|
|
*
|
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
|
*/
|
|
|
|
package device
|
|
|
|
import (
|
|
"runtime"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"golang.org/x/net/ipv4"
|
|
"golang.org/x/net/ipv6"
|
|
"golang.zx2c4.com/wireguard/conn"
|
|
"golang.zx2c4.com/wireguard/ratelimiter"
|
|
"golang.zx2c4.com/wireguard/rwcancel"
|
|
"golang.zx2c4.com/wireguard/tun"
|
|
)
|
|
|
|
type Device struct {
|
|
isUp AtomicBool // device is (going) up
|
|
isClosed AtomicBool // device is closed? (acting as guard)
|
|
log *Logger
|
|
|
|
// synchronized resources (locks acquired in order)
|
|
|
|
state struct {
|
|
stopping sync.WaitGroup
|
|
sync.Mutex
|
|
changing AtomicBool
|
|
current bool
|
|
}
|
|
|
|
net struct {
|
|
stopping sync.WaitGroup
|
|
sync.RWMutex
|
|
bind conn.Bind // bind interface
|
|
netlinkCancel *rwcancel.RWCancel
|
|
port uint16 // listening port
|
|
fwmark uint32 // mark value (0 = disabled)
|
|
}
|
|
|
|
staticIdentity struct {
|
|
sync.RWMutex
|
|
privateKey NoisePrivateKey
|
|
publicKey NoisePublicKey
|
|
}
|
|
|
|
peers struct {
|
|
sync.RWMutex
|
|
keyMap map[NoisePublicKey]*Peer
|
|
}
|
|
|
|
// unprotected / "self-synchronising resources"
|
|
|
|
allowedips AllowedIPs
|
|
indexTable IndexTable
|
|
cookieChecker CookieChecker
|
|
|
|
rate struct {
|
|
underLoadUntil atomic.Value
|
|
limiter ratelimiter.Ratelimiter
|
|
}
|
|
|
|
pool struct {
|
|
messageBufferPool *sync.Pool
|
|
messageBufferReuseChan chan *[MaxMessageSize]byte
|
|
inboundElementPool *sync.Pool
|
|
inboundElementReuseChan chan *QueueInboundElement
|
|
outboundElementPool *sync.Pool
|
|
outboundElementReuseChan chan *QueueOutboundElement
|
|
}
|
|
|
|
queue struct {
|
|
encryption chan *QueueOutboundElement
|
|
decryption chan *QueueInboundElement
|
|
handshake chan QueueHandshakeElement
|
|
}
|
|
|
|
signals struct {
|
|
stop chan struct{}
|
|
}
|
|
|
|
tun struct {
|
|
device tun.Device
|
|
mtu int32
|
|
}
|
|
}
|
|
|
|
/* Converts the peer into a "zombie", which remains in the peer map,
|
|
* but processes no packets and does not exists in the routing table.
|
|
*
|
|
* Must hold device.peers.Mutex
|
|
*/
|
|
func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
|
|
|
|
// stop routing and processing of packets
|
|
|
|
device.allowedips.RemoveByPeer(peer)
|
|
peer.Stop()
|
|
|
|
// remove from peer map
|
|
|
|
delete(device.peers.keyMap, key)
|
|
}
|
|
|
|
func deviceUpdateState(device *Device) {
|
|
|
|
// check if state already being updated (guard)
|
|
|
|
if device.state.changing.Swap(true) {
|
|
return
|
|
}
|
|
|
|
// compare to current state of device
|
|
|
|
device.state.Lock()
|
|
|
|
newIsUp := device.isUp.Get()
|
|
|
|
if newIsUp == device.state.current {
|
|
device.state.changing.Set(false)
|
|
device.state.Unlock()
|
|
return
|
|
}
|
|
|
|
// change state of device
|
|
|
|
switch newIsUp {
|
|
case true:
|
|
if err := device.BindUpdate(); err != nil {
|
|
device.log.Error.Printf("Unable to update bind: %v\n", err)
|
|
device.isUp.Set(false)
|
|
break
|
|
}
|
|
device.peers.RLock()
|
|
for _, peer := range device.peers.keyMap {
|
|
peer.Start()
|
|
if peer.persistentKeepaliveInterval > 0 {
|
|
peer.SendKeepalive()
|
|
}
|
|
}
|
|
device.peers.RUnlock()
|
|
|
|
case false:
|
|
device.BindClose()
|
|
device.peers.RLock()
|
|
for _, peer := range device.peers.keyMap {
|
|
peer.Stop()
|
|
}
|
|
device.peers.RUnlock()
|
|
}
|
|
|
|
// update state variables
|
|
|
|
device.state.current = newIsUp
|
|
device.state.changing.Set(false)
|
|
device.state.Unlock()
|
|
|
|
// check for state change in the mean time
|
|
|
|
deviceUpdateState(device)
|
|
}
|
|
|
|
func (device *Device) Up() {
|
|
|
|
// closed device cannot be brought up
|
|
|
|
if device.isClosed.Get() {
|
|
return
|
|
}
|
|
|
|
device.isUp.Set(true)
|
|
deviceUpdateState(device)
|
|
}
|
|
|
|
func (device *Device) Down() {
|
|
device.isUp.Set(false)
|
|
deviceUpdateState(device)
|
|
}
|
|
|
|
func (device *Device) IsUnderLoad() bool {
|
|
|
|
// check if currently under load
|
|
|
|
now := time.Now()
|
|
underLoad := len(device.queue.handshake) >= UnderLoadQueueSize
|
|
if underLoad {
|
|
device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime))
|
|
return true
|
|
}
|
|
|
|
// check if recently under load
|
|
|
|
until := device.rate.underLoadUntil.Load().(time.Time)
|
|
return until.After(now)
|
|
}
|
|
|
|
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
|
// lock required resources
|
|
|
|
device.staticIdentity.Lock()
|
|
defer device.staticIdentity.Unlock()
|
|
|
|
if sk.Equals(device.staticIdentity.privateKey) {
|
|
return nil
|
|
}
|
|
|
|
device.peers.Lock()
|
|
defer device.peers.Unlock()
|
|
|
|
lockedPeers := make([]*Peer, 0, len(device.peers.keyMap))
|
|
for _, peer := range device.peers.keyMap {
|
|
peer.handshake.mutex.RLock()
|
|
lockedPeers = append(lockedPeers, peer)
|
|
}
|
|
|
|
// remove peers with matching public keys
|
|
|
|
publicKey := sk.publicKey()
|
|
for key, peer := range device.peers.keyMap {
|
|
if peer.handshake.remoteStatic.Equals(publicKey) {
|
|
unsafeRemovePeer(device, peer, key)
|
|
}
|
|
}
|
|
|
|
// update key material
|
|
|
|
device.staticIdentity.privateKey = sk
|
|
device.staticIdentity.publicKey = publicKey
|
|
device.cookieChecker.Init(publicKey)
|
|
|
|
// do static-static DH pre-computations
|
|
|
|
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
|
|
for _, peer := range device.peers.keyMap {
|
|
handshake := &peer.handshake
|
|
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
|
|
expiredPeers = append(expiredPeers, peer)
|
|
}
|
|
|
|
for _, peer := range lockedPeers {
|
|
peer.handshake.mutex.RUnlock()
|
|
}
|
|
for _, peer := range expiredPeers {
|
|
peer.ExpireCurrentKeypairs()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
|
|
device := new(Device)
|
|
|
|
device.isUp.Set(false)
|
|
device.isClosed.Set(false)
|
|
|
|
device.log = logger
|
|
|
|
device.tun.device = tunDevice
|
|
mtu, err := device.tun.device.MTU()
|
|
if err != nil {
|
|
logger.Error.Println("Trouble determining MTU, assuming default:", err)
|
|
mtu = DefaultMTU
|
|
}
|
|
device.tun.mtu = int32(mtu)
|
|
|
|
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
|
|
|
device.rate.limiter.Init()
|
|
device.rate.underLoadUntil.Store(time.Time{})
|
|
|
|
device.indexTable.Init()
|
|
device.allowedips.Reset()
|
|
|
|
device.PopulatePools()
|
|
|
|
// create queues
|
|
|
|
device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
|
|
device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
|
|
device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
|
|
|
|
// prepare signals
|
|
|
|
device.signals.stop = make(chan struct{})
|
|
|
|
// prepare net
|
|
|
|
device.net.port = 0
|
|
device.net.bind = nil
|
|
|
|
// start workers
|
|
|
|
cpus := runtime.NumCPU()
|
|
device.state.stopping.Wait()
|
|
for i := 0; i < cpus; i += 1 {
|
|
device.state.stopping.Add(3)
|
|
go device.RoutineEncryption()
|
|
go device.RoutineDecryption()
|
|
go device.RoutineHandshake()
|
|
}
|
|
|
|
device.state.stopping.Add(2)
|
|
go device.RoutineReadFromTUN()
|
|
go device.RoutineTUNEventReader()
|
|
|
|
return device
|
|
}
|
|
|
|
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
|
|
device.peers.RLock()
|
|
defer device.peers.RUnlock()
|
|
|
|
return device.peers.keyMap[pk]
|
|
}
|
|
|
|
func (device *Device) RemovePeer(key NoisePublicKey) {
|
|
device.peers.Lock()
|
|
defer device.peers.Unlock()
|
|
// stop peer and remove from routing
|
|
|
|
peer, ok := device.peers.keyMap[key]
|
|
if ok {
|
|
unsafeRemovePeer(device, peer, key)
|
|
}
|
|
}
|
|
|
|
func (device *Device) RemoveAllPeers() {
|
|
device.peers.Lock()
|
|
defer device.peers.Unlock()
|
|
|
|
for key, peer := range device.peers.keyMap {
|
|
unsafeRemovePeer(device, peer, key)
|
|
}
|
|
|
|
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
|
}
|
|
|
|
func (device *Device) FlushPacketQueues() {
|
|
for {
|
|
select {
|
|
case elem, ok := <-device.queue.decryption:
|
|
if ok {
|
|
elem.Drop()
|
|
}
|
|
case elem, ok := <-device.queue.encryption:
|
|
if ok {
|
|
elem.Drop()
|
|
}
|
|
case <-device.queue.handshake:
|
|
default:
|
|
return
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
func (device *Device) Close() {
|
|
if device.isClosed.Swap(true) {
|
|
return
|
|
}
|
|
|
|
device.log.Info.Println("Device closing")
|
|
device.state.changing.Set(true)
|
|
device.state.Lock()
|
|
defer device.state.Unlock()
|
|
|
|
device.tun.device.Close()
|
|
device.BindClose()
|
|
|
|
device.isUp.Set(false)
|
|
|
|
close(device.signals.stop)
|
|
device.state.stopping.Wait()
|
|
|
|
device.RemoveAllPeers()
|
|
|
|
device.FlushPacketQueues()
|
|
|
|
device.rate.limiter.Close()
|
|
|
|
device.state.changing.Set(false)
|
|
device.log.Info.Println("Interface closed")
|
|
}
|
|
|
|
func (device *Device) Wait() chan struct{} {
|
|
return device.signals.stop
|
|
}
|
|
|
|
func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
|
|
if device.isClosed.Get() {
|
|
return
|
|
}
|
|
|
|
device.peers.RLock()
|
|
for _, peer := range device.peers.keyMap {
|
|
peer.keypairs.RLock()
|
|
sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now())
|
|
peer.keypairs.RUnlock()
|
|
if sendKeepalive {
|
|
peer.SendKeepalive()
|
|
}
|
|
}
|
|
device.peers.RUnlock()
|
|
}
|
|
|
|
func unsafeCloseBind(device *Device) error {
|
|
var err error
|
|
netc := &device.net
|
|
if netc.netlinkCancel != nil {
|
|
netc.netlinkCancel.Cancel()
|
|
}
|
|
if netc.bind != nil {
|
|
err = netc.bind.Close()
|
|
netc.bind = nil
|
|
}
|
|
netc.stopping.Wait()
|
|
return err
|
|
}
|
|
|
|
func (device *Device) Bind() conn.Bind {
|
|
device.net.Lock()
|
|
defer device.net.Unlock()
|
|
return device.net.bind
|
|
}
|
|
|
|
func (device *Device) BindSetMark(mark uint32) error {
|
|
|
|
device.net.Lock()
|
|
defer device.net.Unlock()
|
|
|
|
// check if modified
|
|
|
|
if device.net.fwmark == mark {
|
|
return nil
|
|
}
|
|
|
|
// update fwmark on existing bind
|
|
|
|
device.net.fwmark = mark
|
|
if device.isUp.Get() && device.net.bind != nil {
|
|
if err := device.net.bind.SetMark(mark); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// clear cached source addresses
|
|
|
|
device.peers.RLock()
|
|
for _, peer := range device.peers.keyMap {
|
|
peer.Lock()
|
|
defer peer.Unlock()
|
|
if peer.endpoint != nil {
|
|
peer.endpoint.ClearSrc()
|
|
}
|
|
}
|
|
device.peers.RUnlock()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (device *Device) BindUpdate() error {
|
|
|
|
device.net.Lock()
|
|
defer device.net.Unlock()
|
|
|
|
// close existing sockets
|
|
|
|
if err := unsafeCloseBind(device); err != nil {
|
|
return err
|
|
}
|
|
|
|
// open new sockets
|
|
|
|
if device.isUp.Get() {
|
|
|
|
// bind to new port
|
|
|
|
var err error
|
|
netc := &device.net
|
|
netc.bind, netc.port, err = conn.CreateBind(netc.port)
|
|
if err != nil {
|
|
netc.bind = nil
|
|
netc.port = 0
|
|
return err
|
|
}
|
|
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
|
|
if err != nil {
|
|
netc.bind.Close()
|
|
netc.bind = nil
|
|
netc.port = 0
|
|
return err
|
|
}
|
|
|
|
// set fwmark
|
|
|
|
if netc.fwmark != 0 {
|
|
err = netc.bind.SetMark(netc.fwmark)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// clear cached source addresses
|
|
|
|
device.peers.RLock()
|
|
for _, peer := range device.peers.keyMap {
|
|
peer.Lock()
|
|
defer peer.Unlock()
|
|
if peer.endpoint != nil {
|
|
peer.endpoint.ClearSrc()
|
|
}
|
|
}
|
|
device.peers.RUnlock()
|
|
|
|
// start receiving routines
|
|
|
|
device.net.stopping.Add(2)
|
|
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
|
|
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
|
|
|
|
device.log.Debug.Println("UDP bind has been updated")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (device *Device) BindClose() error {
|
|
device.net.Lock()
|
|
err := unsafeCloseBind(device)
|
|
device.net.Unlock()
|
|
return err
|
|
}
|