wireguard-go/device/device.go
Josh Bleecher Snyder 4eab21a7b7 device: make RoutineReadFromTUN keep encryption queue alive
RoutineReadFromTUN can trigger a call to SendStagedPackets.
SendStagedPackets attempts to protect against sending
on the encryption queue by checking peer.isRunning and device.isClosed.
However, those are subject to TOCTOU bugs.

If that happens, we get this:

goroutine 1254 [running]:
golang.zx2c4.com/wireguard/device.(*Peer).SendStagedPackets(0xc000798300)
        .../wireguard-go/device/send.go:321 +0x125
golang.zx2c4.com/wireguard/device.(*Device).RoutineReadFromTUN(0xc000014780)
        .../wireguard-go/device/send.go:271 +0x21c
created by golang.zx2c4.com/wireguard/device.NewDevice
        .../wireguard-go/device/device.go:315 +0x298

Fix this with a simple, big hammer: Keep the encryption queue
alive as long as it might be written to.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-02-09 09:53:00 -08:00

520 lines
13 KiB
Go

/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"runtime"
"sync"
"sync/atomic"
"time"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ratelimiter"
"golang.zx2c4.com/wireguard/rwcancel"
"golang.zx2c4.com/wireguard/tun"
)
// Device holds the state of one WireGuard device: its lifecycle state,
// UDP bind, static identity, peer table, shared pools and queues, and the
// TUN device handed to NewDevice.
type Device struct {
	state struct {
		// state holds the device's state. It is accessed atomically.
		// Use the device.deviceState method to read it.
		// device.deviceState does not acquire the mutex, so it captures only a snapshot.
		// During state transitions, the state variable is updated before the device itself.
		// The state is thus either the current state of the device or
		// the intended future state of the device.
		// For example, while executing a call to Up, state will be deviceStateUp.
		// There is no guarantee that the intended future state of the device
		// will become the actual state; Up can fail.
		// The device can also change state multiple times between time of check and time of use.
		// Unsynchronized uses of state must therefore be advisory/best-effort only.
		state uint32 // actually a deviceState, but typed uint32 for convenience
		// stopping blocks until all inputs to Device have been closed.
		// NewDevice adds one count for RoutineReadFromTUN; Close waits on it.
		stopping sync.WaitGroup
		// The embedded Mutex (referred to as device.state.mu in other comments)
		// protects state changes; changeState and Close hold it.
		sync.Mutex
	}
	net struct {
		// stopping is incremented by BindUpdate before it starts the
		// RoutineReceiveIncoming goroutines and is waited on by unsafeCloseBind.
		stopping sync.WaitGroup
		// The embedded RWMutex is taken by Bind, BindSetMark, BindUpdate and BindClose.
		sync.RWMutex
		bind conn.Bind // bind interface
		// netlinkCancel is set from startRouteListener in BindUpdate and
		// cancelled in unsafeCloseBind.
		netlinkCancel *rwcancel.RWCancel
		port uint16 // listening port
		fwmark uint32 // mark value (0 = disabled)
	}
	staticIdentity struct {
		// The embedded RWMutex is write-locked by SetPrivateKey while both keys are replaced.
		sync.RWMutex
		privateKey NoisePrivateKey
		publicKey NoisePublicKey // derived from privateKey in SetPrivateKey
	}
	peers struct {
		empty AtomicBool // empty reports whether len(keyMap) == 0
		sync.RWMutex // protects keyMap
		keyMap map[NoisePublicKey]*Peer
	}
	allowedips AllowedIPs
	indexTable IndexTable
	cookieChecker CookieChecker // re-initialized with the new public key in SetPrivateKey
	rate struct {
		// underLoadUntil is a UnixNano timestamp accessed atomically; see IsUnderLoad.
		underLoadUntil int64
		limiter ratelimiter.Ratelimiter
	}
	// pool is not assigned in this file; NOTE(review): presumably filled in
	// by PopulatePools, which NewDevice calls — confirm.
	pool struct {
		messageBuffers *WaitPool
		inboundElements *WaitPool
		outboundElements *WaitPool
	}
	// queue holds the shared work queues, created once in NewDevice.
	queue struct {
		encryption *outboundQueue
		decryption *inboundQueue
		handshake *handshakeQueue
	}
	tun struct {
		device tun.Device
		// mtu is taken from device.MTU() in NewDevice, or DefaultMTU if that call fails.
		mtu int32
	}
	// ipcMutex is not used in this file.
	// NOTE(review): presumably serializes IPC/configuration operations — confirm.
	ipcMutex sync.RWMutex
	// closed is closed at the end of Close and is the channel returned by Wait.
	closed chan struct{}
	log *Logger
}
// deviceState represents the state of a Device.
// There are three states: down, up, closed.
// Transitions:
//
//   down -----+
//     ↑↓      ↓
//     up -> closed
//
// Down and up can alternate; either may move to closed. Closed is terminal:
// changeState ignores any request once the device is closed.
type deviceState uint32

//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState

const (
	// deviceStateDown is the zero value and the state NewDevice starts in.
	deviceStateDown deviceState = iota
	// deviceStateUp means the device is up, or is attempting to come up.
	deviceStateUp
	// deviceStateClosed means the device is closed, or is closing.
	deviceStateClosed
)
// deviceState atomically loads device.state.state and returns it typed as
// a deviceState. The result is only a snapshot taken without the state
// mutex; see the comments on device.state.state for how to interpret it.
func (device *Device) deviceState() deviceState {
	raw := atomic.LoadUint32(&device.state.state)
	return deviceState(raw)
}
// isClosed reports whether the state snapshot is deviceStateClosed, i.e.
// the device is closed or is in the process of closing.
// See the device.state.state comments for how to interpret this value.
func (device *Device) isClosed() bool {
	snapshot := device.deviceState()
	return snapshot == deviceStateClosed
}
// isUp reports whether the state snapshot is deviceStateUp, i.e. the
// device is up or is attempting to come up.
// See the device.state.state comments for how to interpret this value.
func (device *Device) isUp() bool {
	snapshot := device.deviceState()
	return snapshot == deviceStateUp
}
// removePeerLocked detaches peer from device: it removes the peer's
// allowed-IP routes, stops the peer, deletes it from the key map under key,
// and refreshes the cached "no peers" flag.
//
// The caller must hold device.peers.Lock().
func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) {
	// Stop routing to the peer first, then stop its packet processing.
	device.allowedips.RemoveByPeer(peer)
	peer.Stop()

	// Forget the peer and record whether any peers remain.
	peers := &device.peers
	delete(peers.keyMap, key)
	peers.empty.Set(len(peers.keyMap) == 0)
}
// changeState attempts to move the device into the state want.
//
// It holds the state mutex for the whole transition. A closed device never
// changes state again; a request for the current state is a silent no-op.
// When bringing the device up fails, the device is brought back down.
// The outcome of every attempted transition is logged.
func (device *Device) changeState(want deviceState) {
	device.state.Lock()
	defer device.state.Unlock()

	old := device.deviceState()
	if old == deviceStateClosed {
		// once closed, always closed
		device.log.Verbosef("Interface closed, ignored requested state %s", want)
		return
	}
	if want == old {
		return
	}

	// bringDown publishes the down state before tearing the device down,
	// matching the "state is updated before the device" rule.
	bringDown := func() {
		atomic.StoreUint32(&device.state.state, uint32(deviceStateDown))
		device.downLocked()
	}

	switch want {
	case deviceStateUp:
		atomic.StoreUint32(&device.state.state, uint32(deviceStateUp))
		if !device.upLocked() {
			// up failed; bring the device all the way back down
			bringDown()
		}
	case deviceStateDown:
		bringDown()
	}
	device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState())
}
// upLocked attempts to bring the device up and reports whether it succeeded.
// It refreshes the bind and then, under the peers read lock, starts every
// peer, sending a keepalive to each peer whose persistent keepalive
// interval is non-zero. A bind failure is logged and reported as false.
//
// The caller must hold device.state.mu and is responsible for updating device.state.state.
func (device *Device) upLocked() bool {
	err := device.BindUpdate()
	if err != nil {
		device.log.Errorf("Unable to update bind: %v", err)
		return false
	}

	device.peers.RLock()
	for _, p := range device.peers.keyMap {
		p.Start()
		if atomic.LoadUint32(&p.persistentKeepaliveInterval) == 0 {
			continue
		}
		p.SendKeepalive()
	}
	device.peers.RUnlock()

	return true
}
// downLocked brings the device down: it closes the bind (logging, but not
// returning, any error) and then stops every peer under the peers read lock.
//
// The caller must hold device.state.mu and is responsible for updating device.state.state.
func (device *Device) downLocked() {
	if err := device.BindClose(); err != nil {
		device.log.Errorf("Bind close failed: %v", err)
	}

	device.peers.RLock()
	for _, p := range device.peers.keyMap {
		p.Stop()
	}
	device.peers.RUnlock()
}
// Up requests a transition to the up state.
// It returns nothing: changeState logs the outcome, and if bringing the
// device up fails the device is brought back down.
func (device *Device) Up() {
	device.changeState(deviceStateUp)
}
// Down requests a transition to the down state.
// The request is ignored (and logged) if the device is already closed.
func (device *Device) Down() {
	device.changeState(deviceStateDown)
}
// IsUnderLoad reports whether the device is, or recently was, under load.
//
// The device is under load while the handshake queue holds at least
// UnderLoadQueueSize entries. Each such observation pushes the
// "under load until" deadline to now + UnderLoadAfterTime, and the device
// keeps reporting true until that deadline has passed.
func (device *Device) IsUnderLoad() bool {
	now := time.Now()

	// Currently under load: extend the deadline.
	if queued := len(device.queue.handshake.c); queued >= UnderLoadQueueSize {
		deadline := now.Add(UnderLoadAfterTime).UnixNano()
		atomic.StoreInt64(&device.rate.underLoadUntil, deadline)
		return true
	}

	// Recently under load: the stored deadline is still in the future.
	return now.UnixNano() < atomic.LoadInt64(&device.rate.underLoadUntil)
}
// SetPrivateKey installs sk as the device's static private key.
//
// If sk equals the current private key nothing happens. Otherwise it:
//   - removes every peer whose remote static key equals the public key derived from sk,
//   - stores the new private/public key pair and re-initializes the cookie checker,
//   - recomputes the static-static shared secret for every remaining peer,
//   - expires the current keypairs of every remaining peer.
//
// Every path visible here returns nil.
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
	// lock required resources
	device.staticIdentity.Lock()
	defer device.staticIdentity.Unlock()
	if sk.Equals(device.staticIdentity.privateKey) {
		return nil
	}
	device.peers.Lock()
	defer device.peers.Unlock()
	// Read-lock the handshake state of every peer and remember which peers
	// were locked, so that exactly those are unlocked again further down
	// (including peers that get removed from keyMap in between).
	lockedPeers := make([]*Peer, 0, len(device.peers.keyMap))
	for _, peer := range device.peers.keyMap {
		peer.handshake.mutex.RLock()
		lockedPeers = append(lockedPeers, peer)
	}
	// remove peers with matching public keys
	publicKey := sk.publicKey()
	for key, peer := range device.peers.keyMap {
		if peer.handshake.remoteStatic.Equals(publicKey) {
			// The read lock is dropped around the removal and re-taken
			// afterwards so the unlock loop below stays balanced.
			// NOTE(review): presumably peer.Stop needs the handshake mutex — confirm.
			peer.handshake.mutex.RUnlock()
			removePeerLocked(device, peer, key)
			peer.handshake.mutex.RLock()
		}
	}
	// update key material
	device.staticIdentity.privateKey = sk
	device.staticIdentity.publicKey = publicKey
	device.cookieChecker.Init(publicKey)
	// do static-static DH pre-computations
	// NOTE(review): precomputedStaticStatic is written while only the read
	// lock on handshake.mutex is held — confirm this is safe against other readers.
	expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
	for _, peer := range device.peers.keyMap {
		handshake := &peer.handshake
		handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
		expiredPeers = append(expiredPeers, peer)
	}
	// Release the handshake read locks before expiring keypairs.
	for _, peer := range lockedPeers {
		peer.handshake.mutex.RUnlock()
	}
	for _, peer := range expiredPeers {
		peer.ExpireCurrentKeypairs()
	}
	return nil
}
// NewDevice creates a Device attached to tunDevice that logs through logger.
//
// The device starts in the down state. Its MTU is taken from tunDevice, or
// DefaultMTU if the MTU cannot be determined (the error is logged). The
// peer map, rate limiter, index table, pools and the three shared queues
// are initialized, and the worker goroutines are started before returning:
// one encryption, decryption and handshake worker per CPU, plus the TUN
// reader and the TUN event reader.
//
// NOTE(review): only RoutineReadFromTUN gets WaitGroup counts added here;
// confirm that the per-CPU workers do not call Done on a counter that was
// never incremented for them.
func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
	device := &Device{
		closed: make(chan struct{}),
		log:    logger,
	}
	device.state.state = uint32(deviceStateDown)
	device.tun.device = tunDevice

	mtu, err := tunDevice.MTU()
	if err != nil {
		logger.Errorf("Trouble determining MTU, assuming default: %v", err)
		mtu = DefaultMTU
	}
	device.tun.mtu = int32(mtu)

	device.peers.keyMap = make(map[NoisePublicKey]*Peer)
	device.rate.limiter.Init()
	device.indexTable.Init()
	device.PopulatePools()

	// create queues
	queues := &device.queue
	queues.handshake = newHandshakeQueue()
	queues.encryption = newOutboundQueue()
	queues.decryption = newInboundQueue()

	// prepare net
	device.net.port = 0
	device.net.bind = nil

	// start workers: one of each kind per CPU
	device.state.stopping.Wait()
	for remaining := runtime.NumCPU(); remaining > 0; remaining-- {
		go device.RoutineEncryption()
		go device.RoutineDecryption()
		go device.RoutineHandshake()
	}

	// RoutineReadFromTUN is an input to the device and a writer to the
	// encryption queue, so it is counted in both wait groups.
	device.state.stopping.Add(1)
	device.queue.encryption.wg.Add(1)
	go device.RoutineReadFromTUN()
	go device.RoutineTUNEventReader()

	return device
}
// LookupPeer returns the peer registered under public key pk, or nil if
// there is none. The lookup is done under the peers read lock.
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
	device.peers.RLock()
	found := device.peers.keyMap[pk]
	device.peers.RUnlock()
	return found
}
// RemovePeer stops the peer registered under key and removes it from
// routing and from the peer map. It does nothing if no such peer exists.
func (device *Device) RemovePeer(key NoisePublicKey) {
	device.peers.Lock()
	defer device.peers.Unlock()

	// stop peer and remove from routing
	if target, found := device.peers.keyMap[key]; found {
		removePeerLocked(device, target, key)
	}
}
// RemoveAllPeers stops and removes every peer while holding the peers
// write lock, then replaces the peer map with a fresh, empty one.
func (device *Device) RemoveAllPeers() {
	device.peers.Lock()
	defer device.peers.Unlock()

	for pk, p := range device.peers.keyMap {
		removePeerLocked(device, p, pk)
	}
	device.peers.keyMap = make(map[NoisePublicKey]*Peer)
}
// Close shuts the device down for good.
//
// It holds the state mutex throughout. A second call returns immediately
// because the state is already deviceStateClosed. The shutdown order is:
// publish the closed state, close the TUN device, bring the device down
// (close the bind, stop peers), remove all peers, release this device's
// references on the three queues, wait for the device's inputs to stop,
// close the rate limiter, and finally close the channel returned by Wait.
func (device *Device) Close() {
	device.state.Lock()
	defer device.state.Unlock()
	if device.isClosed() {
		return
	}
	// Publish the closed state first; changeState refuses transitions from now on.
	atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed))
	device.log.Verbosef("Device closing")
	// Any error returned by closing the TUN device is discarded here.
	device.tun.device.Close()
	device.downLocked()
	// Remove peers before closing queues,
	// because peers assume that queues are active.
	device.RemoveAllPeers()
	// We kept a reference to the encryption and decryption queues,
	// in case we started any new peers that might write to them.
	// No new peers are coming; we are done with these queues.
	// NOTE(review): each Done presumably balances a count taken when the
	// queue was constructed — confirm in the queue constructors.
	device.queue.encryption.wg.Done()
	device.queue.decryption.wg.Done()
	device.queue.handshake.wg.Done()
	// Wait for the inputs counted in state.stopping (RoutineReadFromTUN, added in NewDevice).
	device.state.stopping.Wait()
	device.rate.limiter.Close()
	device.log.Verbosef("Device closed")
	// Signal everyone receiving on the channel returned by Wait.
	close(device.closed)
}
// Wait returns the channel that Close closes once shutdown has finished.
// Receiving from it blocks until then.
func (device *Device) Wait() chan struct{} {
	return device.closed
}
// SendKeepalivesToPeersWithCurrentKeypair sends a keepalive to every peer
// that has a current keypair whose creation time plus RejectAfterTime is
// not before now. It does nothing unless the device is up.
func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
	if !device.isUp() {
		return
	}

	device.peers.RLock()
	for _, p := range device.peers.keyMap {
		// Inspect the keypair under its own read lock, but send outside it.
		p.keypairs.RLock()
		current := p.keypairs.current
		usable := current != nil && !current.created.Add(RejectAfterTime).Before(time.Now())
		p.keypairs.RUnlock()

		if !usable {
			continue
		}
		p.SendKeepalive()
	}
	device.peers.RUnlock()
}
// unsafeCloseBind tears down the device's current bind without taking any
// lock itself; its callers in this file (BindUpdate, BindClose) hold
// device.net.Lock().
//
// It cancels the netlink listener if one is set, closes and clears the
// bind if one is set, and then waits on device.net.stopping. The returned
// error is the one from closing the bind, or nil if there was no bind.
func unsafeCloseBind(device *Device) error {
	if cancel := device.net.netlinkCancel; cancel != nil {
		cancel.Cancel()
	}

	var closeErr error
	if bind := device.net.bind; bind != nil {
		closeErr = bind.Close()
		device.net.bind = nil
	}

	device.net.stopping.Wait()
	return closeErr
}
// Bind returns the device's current bind, read under the net lock.
// The result is nil when no bind is open.
func (device *Device) Bind() conn.Bind {
	device.net.Lock()
	current := device.net.bind
	device.net.Unlock()
	return current
}
// BindSetMark sets the firewall mark used by the device's bind.
//
// If mark equals the stored value nothing happens. Otherwise the new value
// is stored and, when the device is up and has a bind, applied to that
// bind; an error from the bind is returned (the stored value is not rolled
// back). On success the cached source address of every peer endpoint is
// cleared.
func (device *Device) BindSetMark(mark uint32) error {
	device.net.Lock()
	defer device.net.Unlock()

	// check if modified
	if device.net.fwmark == mark {
		return nil
	}

	// update fwmark on existing bind
	device.net.fwmark = mark
	if device.isUp() && device.net.bind != nil {
		if err := device.net.bind.SetMark(mark); err != nil {
			return err
		}
	}

	// clear cached source addresses
	device.peers.RLock()
	for _, peer := range device.peers.keyMap {
		// Unlock inside the iteration: a defer here would keep every
		// peer locked until the function returns.
		peer.Lock()
		if peer.endpoint != nil {
			peer.endpoint.ClearSrc()
		}
		peer.Unlock()
	}
	device.peers.RUnlock()

	return nil
}
// BindUpdate closes the device's current bind and, if the device is up,
// opens a new one.
//
// With the net lock held it closes the existing sockets, and returns early
// (nil) if the device is not up. Otherwise it creates a new bind on the
// stored port, starts the route listener, applies the firewall mark when
// one is set, clears every peer's cached source address, and starts one
// RoutineReceiveIncoming goroutine each for IPv4 and IPv6.
//
// Any error from closing, binding, starting the route listener or setting
// the mark is returned. On a bind or route-listener failure the bind and
// port are reset.
func (device *Device) BindUpdate() error {
	device.net.Lock()
	defer device.net.Unlock()

	// close existing sockets
	if err := unsafeCloseBind(device); err != nil {
		return err
	}

	// open new sockets
	if !device.isUp() {
		return nil
	}

	// bind to new port
	var err error
	netc := &device.net
	netc.bind, netc.port, err = conn.CreateBind(netc.port)
	if err != nil {
		netc.bind = nil
		netc.port = 0
		return err
	}
	netc.netlinkCancel, err = device.startRouteListener(netc.bind)
	if err != nil {
		// Best-effort cleanup; the listener error is the one reported.
		netc.bind.Close()
		netc.bind = nil
		netc.port = 0
		return err
	}

	// set fwmark
	if netc.fwmark != 0 {
		err = netc.bind.SetMark(netc.fwmark)
		if err != nil {
			return err
		}
	}

	// clear cached source addresses
	device.peers.RLock()
	for _, peer := range device.peers.keyMap {
		// Unlock inside the iteration: a defer here would keep every
		// peer locked until the function returns, i.e. also while the
		// receive goroutines below are being started.
		peer.Lock()
		if peer.endpoint != nil {
			peer.endpoint.ClearSrc()
		}
		peer.Unlock()
	}
	device.peers.RUnlock()

	// start receiving routines
	device.net.stopping.Add(2)
	device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
	device.queue.handshake.wg.Add(2)  // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
	go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
	go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)

	device.log.Verbosef("UDP bind has been updated")
	return nil
}
// BindClose closes the device's current bind while holding the net lock
// and returns the error, if any, from closing it.
func (device *Device) BindClose() error {
	netc := &device.net
	netc.Lock()
	closeErr := unsafeCloseBind(device)
	netc.Unlock()
	return closeErr
}