Rework of entire locking system

Locking on the Device instance is now much more fined-grained,
seperating out the fields into "resources" st. most common interactions
only require a small number.
This commit is contained in:
Mathias Hall-Andersen 2018-02-02 16:40:14 +01:00
parent 1e42b14022
commit 029410b118
10 changed files with 386 additions and 239 deletions

View file

@ -65,12 +65,12 @@ func unsafeCloseBind(device *Device) error {
} }
func (device *Device) BindUpdate() error { func (device *Device) BindUpdate() error {
device.mutex.Lock()
defer device.mutex.Unlock()
netc := &device.net device.net.mutex.Lock()
netc.mutex.Lock() defer device.net.mutex.Unlock()
defer netc.mutex.Unlock()
device.peers.mutex.Lock()
defer device.peers.mutex.Unlock()
// close existing sockets // close existing sockets
@ -85,6 +85,7 @@ func (device *Device) BindUpdate() error {
// bind to new port // bind to new port
var err error var err error
netc := &device.net
netc.bind, netc.port, err = CreateBind(netc.port) netc.bind, netc.port, err = CreateBind(netc.port)
if err != nil { if err != nil {
netc.bind = nil netc.bind = nil
@ -100,12 +101,12 @@ func (device *Device) BindUpdate() error {
// clear cached source addresses // clear cached source addresses
for _, peer := range device.peers { for _, peer := range device.peers.keyMap {
peer.mutex.Lock() peer.mutex.Lock()
defer peer.mutex.Unlock()
if peer.endpoint != nil { if peer.endpoint != nil {
peer.endpoint.ClearSrc() peer.endpoint.ClearSrc()
} }
peer.mutex.Unlock()
} }
// start receiving routines // start receiving routines
@ -120,10 +121,8 @@ func (device *Device) BindUpdate() error {
} }
func (device *Device) BindClose() error { func (device *Device) BindClose() error {
device.mutex.Lock()
device.net.mutex.Lock() device.net.mutex.Lock()
err := unsafeCloseBind(device) err := unsafeCloseBind(device)
device.net.mutex.Unlock() device.net.mutex.Unlock()
device.mutex.Unlock()
return err return err
} }

View file

@ -11,44 +11,108 @@ import (
type Device struct { type Device struct {
isUp AtomicBool // device is (going) up isUp AtomicBool // device is (going) up
isClosed AtomicBool // device is closed? (acting as guard) isClosed AtomicBool // device is closed? (acting as guard)
log *Logger // collection of loggers for levels log *Logger
idCounter uint // for assigning debug ids to peers
fwMark uint32 // synchronized resources (locks acquired in order)
tun struct {
device TUNDevice
mtu int32
}
state struct { state struct {
mutex deadlock.Mutex mutex deadlock.Mutex
changing AtomicBool changing AtomicBool
current bool current bool
} }
pool struct {
messageBuffers sync.Pool
}
net struct { net struct {
mutex deadlock.RWMutex mutex deadlock.RWMutex
bind Bind // bind interface bind Bind // bind interface
port uint16 // listening port port uint16 // listening port
fwmark uint32 // mark value (0 = disabled) fwmark uint32 // mark value (0 = disabled)
} }
noise struct {
mutex deadlock.RWMutex mutex deadlock.RWMutex
privateKey NoisePrivateKey privateKey NoisePrivateKey
publicKey NoisePublicKey publicKey NoisePublicKey
routingTable RoutingTable }
routing struct {
mutex deadlock.RWMutex
table RoutingTable
}
peers struct {
mutex deadlock.RWMutex
keyMap map[NoisePublicKey]*Peer
}
// unprotected / "self-synchronising resources"
indices IndexTable indices IndexTable
mac CookieChecker
rate struct {
underLoadUntil atomic.Value
limiter Ratelimiter
}
pool struct {
messageBuffers sync.Pool
}
queue struct { queue struct {
encryption chan *QueueOutboundElement encryption chan *QueueOutboundElement
decryption chan *QueueInboundElement decryption chan *QueueInboundElement
handshake chan QueueHandshakeElement handshake chan QueueHandshakeElement
} }
signal struct { signal struct {
stop Signal stop Signal
} }
underLoadUntil atomic.Value
ratelimiter Ratelimiter tun struct {
peers map[NoisePublicKey]*Peer device TUNDevice
mac CookieChecker mtu int32
}
}
/* Converts the peer into a "zombie", which remains in the peer map,
* but processes no packets and does not exists in the routing table.
*
* Must hold:
* device.peers.mutex : exclusive lock
* device.routing : exclusive lock
*/
func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
// stop routing and processing of packets
device.routing.table.RemovePeer(peer)
peer.Stop()
// clean index table
kp := &peer.keyPairs
kp.mutex.Lock()
if kp.previous != nil {
device.indices.Delete(kp.previous.localIndex)
}
if kp.current != nil {
device.indices.Delete(kp.current.localIndex)
}
if kp.next != nil {
device.indices.Delete(kp.next.localIndex)
}
kp.previous = nil
kp.current = nil
kp.next = nil
kp.mutex.Unlock()
// remove from peer map
delete(device.peers.keyMap, key)
} }
func deviceUpdateState(device *Device) { func deviceUpdateState(device *Device) {
@ -59,56 +123,56 @@ func deviceUpdateState(device *Device) {
return return
} }
func() {
// compare to current state of device // compare to current state of device
device.state.mutex.Lock() device.state.mutex.Lock()
defer device.state.mutex.Unlock()
newIsUp := device.isUp.Get() newIsUp := device.isUp.Get()
if newIsUp == device.state.current { if newIsUp == device.state.current {
device.state.mutex.Unlock()
device.state.changing.Set(false) device.state.changing.Set(false)
return return
} }
device.state.mutex.Unlock()
// change state of device // change state of device
switch newIsUp { switch newIsUp {
case true: case true:
// start listener
if err := device.BindUpdate(); err != nil { if err := device.BindUpdate(); err != nil {
device.isUp.Set(false) device.isUp.Set(false)
break break
} }
// start every peer device.peers.mutex.Lock()
defer device.peers.mutex.Unlock()
for _, peer := range device.peers { for _, peer := range device.peers.keyMap {
peer.Start() peer.Start()
} }
case false: case false:
// stop listening
device.BindClose() device.BindClose()
// stop every peer device.peers.mutex.Lock()
defer device.peers.mutex.Unlock()
for _, peer := range device.peers { for _, peer := range device.peers.keyMap {
println("stopping peer")
peer.Stop() peer.Stop()
} }
} }
// update state variables // update state variables
// and check for state change in the mean time
device.state.current = newIsUp device.state.current = newIsUp
device.state.changing.Set(false) device.state.changing.Set(false)
}()
// check for state change in the mean time
deviceUpdateState(device) deviceUpdateState(device)
} }
@ -133,18 +197,6 @@ func (device *Device) Down() {
deviceUpdateState(device) deviceUpdateState(device)
} }
/* Warning:
* The caller must hold the device mutex (write lock)
*/
func removePeerUnsafe(device *Device, key NoisePublicKey) {
peer, ok := device.peers[key]
if !ok {
return
}
device.routingTable.RemovePeer(peer)
delete(device.peers, key)
}
func (device *Device) IsUnderLoad() bool { func (device *Device) IsUnderLoad() bool {
// check if currently under load // check if currently under load
@ -152,54 +204,66 @@ func (device *Device) IsUnderLoad() bool {
now := time.Now() now := time.Now()
underLoad := len(device.queue.handshake) >= UnderLoadQueueSize underLoad := len(device.queue.handshake) >= UnderLoadQueueSize
if underLoad { if underLoad {
device.underLoadUntil.Store(now.Add(time.Second)) device.rate.underLoadUntil.Store(now.Add(time.Second))
return true return true
} }
// check if recently under load // check if recently under load
until := device.underLoadUntil.Load().(time.Time) until := device.rate.underLoadUntil.Load().(time.Time)
return until.After(now) return until.After(now)
} }
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
device.mutex.Lock()
defer device.mutex.Unlock() // lock required resources
device.noise.mutex.Lock()
defer device.noise.mutex.Unlock()
device.routing.mutex.Lock()
defer device.routing.mutex.Unlock()
device.peers.mutex.Lock()
defer device.peers.mutex.Unlock()
for _, peer := range device.peers.keyMap {
peer.handshake.mutex.RLock()
defer peer.handshake.mutex.RUnlock()
}
// remove peers with matching public keys // remove peers with matching public keys
publicKey := sk.publicKey() publicKey := sk.publicKey()
for key, peer := range device.peers { for key, peer := range device.peers.keyMap {
h := &peer.handshake if peer.handshake.remoteStatic.Equals(publicKey) {
h.mutex.RLock() unsafeRemovePeer(device, peer, key)
if h.remoteStatic.Equals(publicKey) {
removePeerUnsafe(device, key)
} }
h.mutex.RUnlock()
} }
// update key material // update key material
device.privateKey = sk device.noise.privateKey = sk
device.publicKey = publicKey device.noise.publicKey = publicKey
device.mac.Init(publicKey) device.mac.Init(publicKey)
// do DH pre-computations // do static-static DH pre-computations
rmKey := device.privateKey.IsZero() rmKey := device.noise.privateKey.IsZero()
for key, peer := range device.peers.keyMap {
hs := &peer.handshake
for key, peer := range device.peers {
h := &peer.handshake
h.mutex.Lock()
if rmKey { if rmKey {
h.precomputedStaticStatic = [NoisePublicKeySize]byte{} hs.precomputedStaticStatic = [NoisePublicKeySize]byte{}
} else { } else {
h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic) hs.precomputedStaticStatic = device.noise.privateKey.sharedSecret(hs.remoteStatic)
if isZero(h.precomputedStaticStatic[:]) {
removePeerUnsafe(device, key)
} }
if isZero(hs.precomputedStaticStatic[:]) {
unsafeRemovePeer(device, peer, key)
} }
h.mutex.Unlock()
} }
return nil return nil
@ -215,21 +279,23 @@ func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
func NewDevice(tun TUNDevice, logger *Logger) *Device { func NewDevice(tun TUNDevice, logger *Logger) *Device {
device := new(Device) device := new(Device)
device.mutex.Lock()
defer device.mutex.Unlock()
device.isUp.Set(false) device.isUp.Set(false)
device.isClosed.Set(false) device.isClosed.Set(false)
device.log = logger device.log = logger
device.peers = make(map[NoisePublicKey]*Peer)
device.tun.device = tun device.tun.device = tun
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
// initialize anti-DoS / anti-scanning features
device.rate.limiter.Init()
device.rate.underLoadUntil.Store(time.Time{})
// initialize noise & crypt-key routine
device.indices.Init() device.indices.Init()
device.ratelimiter.Init() device.routing.table.Reset()
device.routingTable.Reset()
device.underLoadUntil.Store(time.Time{})
// setup buffer pool // setup buffer pool
@ -264,36 +330,50 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
go device.RoutineReadFromTUN() go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader() go device.RoutineTUNEventReader()
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop) go device.rate.limiter.RoutineGarbageCollector(device.signal.stop)
return device return device
} }
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
device.mutex.RLock() device.peers.mutex.RLock()
defer device.mutex.RUnlock() defer device.peers.mutex.RUnlock()
return device.peers[pk]
return device.peers.keyMap[pk]
} }
func (device *Device) RemovePeer(key NoisePublicKey) { func (device *Device) RemovePeer(key NoisePublicKey) {
device.mutex.Lock() device.noise.mutex.Lock()
defer device.mutex.Unlock() defer device.noise.mutex.Unlock()
removePeerUnsafe(device, key)
device.routing.mutex.Lock()
defer device.routing.mutex.Unlock()
device.peers.mutex.Lock()
defer device.peers.mutex.Unlock()
// stop peer and remove from routing
peer, ok := device.peers.keyMap[key]
if ok {
unsafeRemovePeer(device, peer, key)
}
} }
func (device *Device) RemoveAllPeers() { func (device *Device) RemoveAllPeers() {
device.mutex.Lock()
defer device.mutex.Unlock()
for key, peer := range device.peers { device.routing.mutex.Lock()
peer.Stop() defer device.routing.mutex.Unlock()
peer, ok := device.peers[key]
if !ok { device.peers.mutex.Lock()
return defer device.peers.mutex.Unlock()
}
device.routingTable.RemovePeer(peer) for key, peer := range device.peers.keyMap {
delete(device.peers, key) println("rm", peer.String())
unsafeRemovePeer(device, peer, key)
} }
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
} }
func (device *Device) Close() { func (device *Device) Close() {
@ -305,7 +385,6 @@ func (device *Device) Close() {
device.tun.device.Close() device.tun.device.Close()
device.BindClose() device.BindClose()
device.isUp.Set(false) device.isUp.Set(false)
println("remove")
device.RemoveAllPeers() device.RemoveAllPeers()
device.log.Info.Println("Interface closed") device.log.Info.Println("Interface closed")
} }

View file

@ -3,6 +3,7 @@ package main
import ( import (
"crypto/hmac" "crypto/hmac"
"crypto/rand" "crypto/rand"
"crypto/subtle"
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
"golang.org/x/crypto/curve25519" "golang.org/x/crypto/curve25519"
"hash" "hash"
@ -58,11 +59,11 @@ func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) {
} }
func isZero(val []byte) bool { func isZero(val []byte) bool {
var acc byte acc := 1
for _, b := range val { for _, b := range val {
acc |= b acc &= subtle.ConstantTimeByteEq(b, 0)
} }
return acc == 0 return acc == 1
} }
func setZero(arr []byte) { func setZero(arr []byte) {

View file

@ -137,6 +137,10 @@ func init() {
} }
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
device.noise.mutex.Lock()
defer device.noise.mutex.Unlock()
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
defer handshake.mutex.Unlock() defer handshake.mutex.Unlock()
@ -187,7 +191,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
ss[:], ss[:],
) )
aead, _ := chacha20poly1305.New(key[:]) aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:]) aead.Seal(msg.Static[:0], ZeroNonce[:], device.noise.publicKey[:], handshake.hash[:])
}() }()
handshake.mixHash(msg.Static[:]) handshake.mixHash(msg.Static[:])
@ -212,16 +216,19 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
} }
func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
if msg.Type != MessageInitiationType {
return nil
}
var ( var (
hash [blake2s.Size]byte hash [blake2s.Size]byte
chainKey [blake2s.Size]byte chainKey [blake2s.Size]byte
) )
mixHash(&hash, &InitialHash, device.publicKey[:]) if msg.Type != MessageInitiationType {
return nil
}
device.noise.mutex.RLock()
defer device.noise.mutex.RUnlock()
mixHash(&hash, &InitialHash, device.noise.publicKey[:])
mixHash(&hash, &hash, msg.Ephemeral[:]) mixHash(&hash, &hash, msg.Ephemeral[:])
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
@ -231,7 +238,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
var peerPK NoisePublicKey var peerPK NoisePublicKey
func() { func() {
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
ss := device.privateKey.sharedSecret(msg.Ephemeral) ss := device.noise.privateKey.sharedSecret(msg.Ephemeral)
KDF2(&chainKey, &key, chainKey[:], ss[:]) KDF2(&chainKey, &key, chainKey[:], ss[:])
aead, _ := chacha20poly1305.New(key[:]) aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
@ -407,7 +414,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
}() }()
func() { func() {
ss := device.privateKey.sharedSecret(msg.Ephemeral) ss := device.noise.privateKey.sharedSecret(msg.Ephemeral)
mixKey(&chainKey, &chainKey, ss[:]) mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:]) setZero(ss[:])
}() }()

View file

@ -14,7 +14,6 @@ const (
) )
type Peer struct { type Peer struct {
id uint
isRunning AtomicBool isRunning AtomicBool
mutex deadlock.RWMutex mutex deadlock.RWMutex
persistentKeepaliveInterval uint64 persistentKeepaliveInterval uint64
@ -22,17 +21,20 @@ type Peer struct {
handshake Handshake handshake Handshake
device *Device device *Device
endpoint Endpoint endpoint Endpoint
stats struct { stats struct {
txBytes uint64 // bytes send to peer (endpoint) txBytes uint64 // bytes send to peer (endpoint)
rxBytes uint64 // bytes received from peer rxBytes uint64 // bytes received from peer
lastHandshakeNano int64 // nano seconds since epoch lastHandshakeNano int64 // nano seconds since epoch
} }
time struct { time struct {
mutex deadlock.RWMutex mutex deadlock.RWMutex
lastSend time.Time // last send message lastSend time.Time // last send message
lastHandshake time.Time // last completed handshake lastHandshake time.Time // last completed handshake
nextKeepalive time.Time nextKeepalive time.Time
} }
signal struct { signal struct {
newKeyPair Signal // size 1, new key pair was generated newKeyPair Signal // size 1, new key pair was generated
handshakeCompleted Signal // size 1, handshake completed handshakeCompleted Signal // size 1, handshake completed
@ -41,7 +43,9 @@ type Peer struct {
messageSend Signal // size 1, message was send to peer messageSend Signal // size 1, message was send to peer
messageReceived Signal // size 1, authenticated message recv messageReceived Signal // size 1, authenticated message recv
} }
timer struct { timer struct {
// state related to WireGuard timers // state related to WireGuard timers
keepalivePersistent Timer // set for persistent keepalives keepalivePersistent Timer // set for persistent keepalives
@ -54,17 +58,20 @@ type Peer struct {
sendLastMinuteHandshake bool sendLastMinuteHandshake bool
needAnotherKeepalive bool needAnotherKeepalive bool
} }
queue struct { queue struct {
nonce chan *QueueOutboundElement // nonce / pre-handshake queue nonce chan *QueueOutboundElement // nonce / pre-handshake queue
outbound chan *QueueOutboundElement // sequential ordering of work outbound chan *QueueOutboundElement // sequential ordering of work
inbound chan *QueueInboundElement // sequential ordering of work inbound chan *QueueInboundElement // sequential ordering of work
} }
routines struct { routines struct {
mutex deadlock.Mutex // held when stopping / starting routines mutex deadlock.Mutex // held when stopping / starting routines
starting sync.WaitGroup // routines pending start starting sync.WaitGroup // routines pending start
stopping sync.WaitGroup // routines pending stop stopping sync.WaitGroup // routines pending stop
stop Signal // size 0, stop all goroutines in peer stop Signal // size 0, stop all goroutines in peer
} }
mac CookieGenerator mac CookieGenerator
} }
@ -74,8 +81,22 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
return nil, errors.New("Device closed") return nil, errors.New("Device closed")
} }
device.mutex.Lock() // lock resources
defer device.mutex.Unlock()
device.state.mutex.Lock()
defer device.state.mutex.Unlock()
device.noise.mutex.RLock()
defer device.noise.mutex.RUnlock()
device.peers.mutex.Lock()
defer device.peers.mutex.Unlock()
// check if over limit
if len(device.peers.keyMap) >= MaxPeers {
return nil, errors.New("Too many peers")
}
// create peer // create peer
@ -94,32 +115,20 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.timer.handshakeDeadline = NewTimer() peer.timer.handshakeDeadline = NewTimer()
peer.timer.handshakeTimeout = NewTimer() peer.timer.handshakeTimeout = NewTimer()
// assign id for debugging
peer.id = device.idCounter
device.idCounter += 1
// check if over limit
if len(device.peers) >= MaxPeers {
return nil, errors.New("Too many peers")
}
// map public key // map public key
_, ok := device.peers[pk] _, ok := device.peers.keyMap[pk]
if ok { if ok {
return nil, errors.New("Adding existing peer") return nil, errors.New("Adding existing peer")
} }
device.peers[pk] = peer device.peers.keyMap[pk] = peer
// precompute DH // precompute DH
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
handshake.remoteStatic = pk handshake.remoteStatic = pk
handshake.precomputedStaticStatic = handshake.precomputedStaticStatic = device.noise.privateKey.sharedSecret(pk)
device.privateKey.sharedSecret(handshake.remoteStatic)
handshake.mutex.Unlock() handshake.mutex.Unlock()
// reset endpoint // reset endpoint
@ -134,11 +143,9 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// start peer // start peer
peer.device.state.mutex.Lock()
if peer.device.isUp.Get() { if peer.device.isUp.Get() {
peer.Start() peer.Start()
} }
peer.device.state.mutex.Unlock()
return peer, nil return peer, nil
} }
@ -166,14 +173,12 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
func (peer *Peer) String() string { func (peer *Peer) String() string {
if peer.endpoint == nil { if peer.endpoint == nil {
return fmt.Sprintf( return fmt.Sprintf(
"peer(%d unknown %s)", "peer(unknown %s)",
peer.id,
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
) )
} }
return fmt.Sprintf( return fmt.Sprintf(
"peer(%d %s %s)", "peer(%s %s)",
peer.id,
peer.endpoint.DstToString(), peer.endpoint.DstToString(),
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
) )
@ -181,8 +186,12 @@ func (peer *Peer) String() string {
func (peer *Peer) Start() { func (peer *Peer) Start() {
if peer.device.isClosed.Get() {
return
}
peer.routines.mutex.Lock() peer.routines.mutex.Lock()
defer peer.routines.mutex.Lock() defer peer.routines.mutex.Unlock()
peer.device.log.Debug.Println("Starting:", peer.String()) peer.device.log.Debug.Println("Starting:", peer.String())
@ -222,7 +231,7 @@ func (peer *Peer) Start() {
func (peer *Peer) Stop() { func (peer *Peer) Stop() {
peer.routines.mutex.Lock() peer.routines.mutex.Lock()
defer peer.routines.mutex.Lock() defer peer.routines.mutex.Unlock()
peer.device.log.Debug.Println("Stopping:", peer.String()) peer.device.log.Debug.Println("Stopping:", peer.String())

View file

@ -372,7 +372,7 @@ func (device *Device) RoutineHandshake() {
// check ratelimiter // check ratelimiter
if !device.ratelimiter.Allow(elem.endpoint.DstIP()) { if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
continue continue
} }
} }
@ -495,19 +495,23 @@ func (device *Device) RoutineHandshake() {
func (peer *Peer) RoutineSequentialReceiver() { func (peer *Peer) RoutineSequentialReceiver() {
defer peer.routines.stopping.Done()
device := peer.device device := peer.device
logInfo := device.log.Info logInfo := device.log.Info
logError := device.log.Error logError := device.log.Error
logDebug := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, sequential receiver, started for peer", peer.id) logDebug.Println("Routine, sequential receiver, started for peer", peer.String())
peer.routines.starting.Done()
for { for {
select { select {
case <-peer.routines.stop.Wait(): case <-peer.routines.stop.Wait():
logDebug.Println("Routine, sequential receiver, stopped for peer", peer.id) logDebug.Println("Routine, sequential receiver, stopped for peer", peer.String())
return return
case elem := <-peer.queue.inbound: case elem := <-peer.queue.inbound:
@ -581,7 +585,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
// verify IPv4 source // verify IPv4 source
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.routingTable.LookupIPv4(src) != peer { if device.routing.table.LookupIPv4(src) != peer {
logInfo.Println( logInfo.Println(
"IPv4 packet with disallowed source address from", "IPv4 packet with disallowed source address from",
peer.String(), peer.String(),
@ -609,7 +613,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
// verify IPv6 source // verify IPv6 source
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.routingTable.LookupIPv6(src) != peer { if device.routing.table.LookupIPv6(src) != peer {
logInfo.Println( logInfo.Println(
"IPv6 packet with disallowed source address from", "IPv6 packet with disallowed source address from",
peer.String(), peer.String(),

View file

@ -151,14 +151,14 @@ func (device *Device) RoutineReadFromTUN() {
continue continue
} }
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.routingTable.LookupIPv4(dst) peer = device.routing.table.LookupIPv4(dst)
case ipv6.Version: case ipv6.Version:
if len(elem.packet) < ipv6.HeaderLen { if len(elem.packet) < ipv6.HeaderLen {
continue continue
} }
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.routingTable.LookupIPv6(dst) peer = device.routing.table.LookupIPv6(dst)
default: default:
logDebug.Println("Received packet with unknown IP version") logDebug.Println("Received packet with unknown IP version")
@ -187,10 +187,14 @@ func (device *Device) RoutineReadFromTUN() {
func (peer *Peer) RoutineNonce() { func (peer *Peer) RoutineNonce() {
var keyPair *KeyPair var keyPair *KeyPair
defer peer.routines.stopping.Done()
device := peer.device device := peer.device
logDebug := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, nonce worker, started for peer", peer.String()) logDebug.Println("Routine, nonce worker, started for peer", peer.String())
peer.routines.starting.Done()
for { for {
NextPacket: NextPacket:
select { select {

View file

@ -303,7 +303,7 @@ func (peer *Peer) RoutineTimerHandler() {
err := peer.sendNewHandshake() err := peer.sendNewHandshake()
if err != nil { if err != nil {
logInfo.Println( logInfo.Println(
"Failed to send handshake to peer:", peer.String()) "Failed to send handshake to peer:", peer.String(), "(", err, ")")
} }
case <-peer.timer.handshakeDeadline.Wait(): case <-peer.timer.handshakeDeadline.Wait():
@ -326,7 +326,7 @@ func (peer *Peer) RoutineTimerHandler() {
err := peer.sendNewHandshake() err := peer.sendNewHandshake()
if err != nil { if err != nil {
logInfo.Println( logInfo.Println(
"Failed to send handshake to peer:", peer.String()) "Failed to send handshake to peer:", peer.String(), "(", err, ")")
} }
peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)

View file

@ -313,7 +313,7 @@ func CreateTUNFromFile(name string, fd *os.File) (TUNDevice, error) {
} }
go device.RoutineNetlinkListener() go device.RoutineNetlinkListener()
go device.RoutineHackListener() // cross namespace // go device.RoutineHackListener() // cross namespace
// set default MTU // set default MTU
@ -369,7 +369,7 @@ func CreateTUN(name string) (TUNDevice, error) {
} }
go device.RoutineNetlinkListener() go device.RoutineNetlinkListener()
go device.RoutineHackListener() // cross namespace // go device.RoutineHackListener() // cross namespace
// set default MTU // set default MTU

View file

@ -25,18 +25,35 @@ func (s *IPCError) ErrorCode() int64 {
func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// create lines device.log.Debug.Println("UAPI: Processing get operation")
device.mutex.RLock() // create lines
device.net.mutex.RLock()
lines := make([]string, 0, 100) lines := make([]string, 0, 100)
send := func(line string) { send := func(line string) {
lines = append(lines, line) lines = append(lines, line)
} }
if !device.privateKey.IsZero() { func() {
send("private_key=" + device.privateKey.ToHex())
// lock required resources
device.net.mutex.RLock()
defer device.net.mutex.RUnlock()
device.noise.mutex.RLock()
defer device.noise.mutex.RUnlock()
device.routing.mutex.RLock()
defer device.routing.mutex.RUnlock()
device.peers.mutex.Lock()
defer device.peers.mutex.Unlock()
// serialize device related values
if !device.noise.privateKey.IsZero() {
send("private_key=" + device.noise.privateKey.ToHex())
} }
if device.net.port != 0 { if device.net.port != 0 {
@ -47,10 +64,12 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
send(fmt.Sprintf("fwmark=%d", device.net.fwmark)) send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
} }
for _, peer := range device.peers { // serialize each peer state
func() {
for _, peer := range device.peers.keyMap {
peer.mutex.RLock() peer.mutex.RLock()
defer peer.mutex.RUnlock() defer peer.mutex.RUnlock()
send("public_key=" + peer.handshake.remoteStatic.ToHex()) send("public_key=" + peer.handshake.remoteStatic.ToHex())
send("preshared_key=" + peer.handshake.presharedKey.ToHex()) send("preshared_key=" + peer.handshake.presharedKey.ToHex())
if peer.endpoint != nil { if peer.endpoint != nil {
@ -69,16 +88,14 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
atomic.LoadUint64(&peer.persistentKeepaliveInterval), atomic.LoadUint64(&peer.persistentKeepaliveInterval),
)) ))
for _, ip := range device.routingTable.AllowedIPs(peer) { for _, ip := range device.routing.table.AllowedIPs(peer) {
send("allowed_ip=" + ip.String()) send("allowed_ip=" + ip.String())
} }
}()
} }
}()
device.net.mutex.RUnlock() // send lines (does not require resource locks)
device.mutex.RUnlock()
// send lines
for _, line := range lines { for _, line := range lines {
_, err := socket.WriteString(line + "\n") _, err := socket.WriteString(line + "\n")
@ -94,7 +111,6 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
scanner := bufio.NewScanner(socket) scanner := bufio.NewScanner(socket)
logInfo := device.log.Info
logError := device.log.Error logError := device.log.Error
logDebug := device.log.Debug logDebug := device.log.Debug
@ -130,6 +146,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logError.Println("Failed to set private_key:", err) logError.Println("Failed to set private_key:", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
logDebug.Println("UAPI: Updating device private key")
device.SetPrivateKey(sk) device.SetPrivateKey(sk)
case "listen_port": case "listen_port":
@ -144,6 +161,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// update port and rebind // update port and rebind
logDebug.Println("UAPI: Updating listen port")
device.net.mutex.Lock() device.net.mutex.Lock()
device.net.port = uint16(port) device.net.port = uint16(port)
device.net.mutex.Unlock() device.net.mutex.Unlock()
@ -170,6 +189,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
logDebug.Println("UAPI: Updating fwmark")
device.net.mutex.Lock() device.net.mutex.Lock()
device.net.fwmark = uint32(fwmark) device.net.fwmark = uint32(fwmark)
device.net.mutex.Unlock() device.net.mutex.Unlock()
@ -181,6 +202,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "public_key": case "public_key":
// switch to peer configuration // switch to peer configuration
logDebug.Println("UAPI: Transition to peer configuration")
deviceConfig = false deviceConfig = false
case "replace_peers": case "replace_peers":
@ -188,6 +210,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logError.Println("Failed to set replace_peers, invalid value:", value) logError.Println("Failed to set replace_peers, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
logDebug.Println("UAPI: Removing all peers")
device.RemoveAllPeers() device.RemoveAllPeers()
default: default:
@ -203,42 +226,40 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
switch key { switch key {
case "public_key": case "public_key":
var pubKey NoisePublicKey var publicKey NoisePublicKey
err := pubKey.FromHex(value) err := publicKey.FromHex(value)
if err != nil { if err != nil {
logError.Println("Failed to get peer by public_key:", err) logError.Println("Failed to get peer by public_key:", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
// check if public key of peer equal to device // ignore peer with public key of device
device.mutex.RLock() device.noise.mutex.RLock()
if device.publicKey.Equals(pubKey) { equals := device.noise.publicKey.Equals(publicKey)
device.noise.mutex.RUnlock()
// create dummy instance (not added to device)
if equals {
peer = &Peer{} peer = &Peer{}
dummy = true dummy = true
device.mutex.RUnlock() }
logInfo.Println("Ignoring peer with public key of device")
} else {
// find peer referenced // find peer referenced
peer, _ = device.peers[pubKey] peer = device.LookupPeer(publicKey)
device.mutex.RUnlock()
if peer == nil { if peer == nil {
peer, err = device.NewPeer(pubKey) peer, err = device.NewPeer(publicKey)
if err != nil { if err != nil {
logError.Println("Failed to create new peer:", err) logError.Println("Failed to create new peer:", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
logDebug.Println("UAPI: Created new peer:", peer.String())
} }
peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
dummy = false
} peer.mutex.Lock()
peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
peer.mutex.Unlock()
case "remove": case "remove":
@ -249,7 +270,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
if !dummy { if !dummy {
logDebug.Println("Removing", peer.String()) logDebug.Println("UAPI: Removing peer:", peer.String())
device.RemovePeer(peer.handshake.remoteStatic) device.RemovePeer(peer.handshake.remoteStatic)
} }
peer = &Peer{} peer = &Peer{}
@ -259,9 +280,12 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// update PSK // update PSK
peer.mutex.Lock() logDebug.Println("UAPI: Updating pre-shared key for peer:", peer.String())
peer.handshake.mutex.Lock()
err := peer.handshake.presharedKey.FromHex(value) err := peer.handshake.presharedKey.FromHex(value)
peer.mutex.Unlock() peer.handshake.mutex.Unlock()
if err != nil { if err != nil {
logError.Println("Failed to set preshared_key:", err) logError.Println("Failed to set preshared_key:", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
@ -271,6 +295,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// set endpoint destination // set endpoint destination
logDebug.Println("UAPI: Updating endpoint for peer:", peer.String())
err := func() error { err := func() error {
peer.mutex.Lock() peer.mutex.Lock()
defer peer.mutex.Unlock() defer peer.mutex.Unlock()
@ -292,6 +318,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// update keep-alive interval // update keep-alive interval
logDebug.Println("UAPI: Updating persistent_keepalive_interval for peer:", peer.String())
secs, err := strconv.ParseUint(value, 10, 16) secs, err := strconv.ParseUint(value, 10, 16)
if err != nil { if err != nil {
logError.Println("Failed to set persistent_keepalive_interval:", err) logError.Println("Failed to set persistent_keepalive_interval:", err)
@ -316,25 +344,41 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
} }
case "replace_allowed_ips": case "replace_allowed_ips":
logDebug.Println("UAPI: Removing all allowed IPs for peer:", peer.String())
if value != "true" { if value != "true" {
logError.Println("Failed to set replace_allowed_ips, invalid value:", value) logError.Println("Failed to set replace_allowed_ips, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
if !dummy {
device.routingTable.RemovePeer(peer) if dummy {
continue
} }
device.routing.mutex.Lock()
device.routing.table.RemovePeer(peer)
device.routing.mutex.Unlock()
case "allowed_ip": case "allowed_ip":
logDebug.Println("UAPI: Adding allowed_ip to peer:", peer.String())
_, network, err := net.ParseCIDR(value) _, network, err := net.ParseCIDR(value)
if err != nil { if err != nil {
logError.Println("Failed to set allowed_ip:", err) logError.Println("Failed to set allowed_ip:", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
ones, _ := network.Mask.Size()
if !dummy { if dummy {
device.routingTable.Insert(network.IP, uint(ones), peer) continue
} }
ones, _ := network.Mask.Size()
device.routing.mutex.Lock()
device.routing.table.Insert(network.IP, uint(ones), peer)
device.routing.mutex.Unlock()
default: default:
logError.Println("Invalid UAPI key (peer configuration):", key) logError.Println("Invalid UAPI key (peer configuration):", key)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}