More consistent use of signal struct

This commit is contained in:
Mathias Hall-Andersen 2017-12-01 23:37:26 +01:00
parent cb09125dc4
commit eaca1ee1f7
9 changed files with 68 additions and 68 deletions

View file

@ -37,7 +37,7 @@ type Device struct {
handshake chan QueueHandshakeElement handshake chan QueueHandshakeElement
} }
signal struct { signal struct {
stop chan struct{} stop Signal
} }
underLoadUntil atomic.Value underLoadUntil atomic.Value
ratelimiter Ratelimiter ratelimiter Ratelimiter
@ -129,7 +129,6 @@ func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
func NewDevice(tun TUNDevice, logger *Logger) *Device { func NewDevice(tun TUNDevice, logger *Logger) *Device {
device := new(Device) device := new(Device)
device.mutex.Lock() device.mutex.Lock()
defer device.mutex.Unlock() defer device.mutex.Unlock()
@ -160,7 +159,7 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
// prepare signals // prepare signals
device.signal.stop = make(chan struct{}) device.signal.stop = NewSignal()
// prepare net // prepare net
@ -174,9 +173,11 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
go device.RoutineDecryption() go device.RoutineDecryption()
go device.RoutineHandshake() go device.RoutineHandshake()
} }
go device.RoutineReadFromTUN() go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader() go device.RoutineTUNEventReader()
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop) go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
return device return device
} }
@ -210,11 +211,11 @@ func (device *Device) Close() {
} }
device.log.Info.Println("Closing device") device.log.Info.Println("Closing device")
device.RemoveAllPeers() device.RemoveAllPeers()
close(device.signal.stop) device.signal.stop.Broadcast()
closeBind(device)
device.tun.device.Close() device.tun.device.Close()
closeBind(device)
} }
func (device *Device) WaitChannel() chan struct{} { func (device *Device) Wait() chan struct{} {
return device.signal.stop return device.signal.stop.Wait()
} }

View file

@ -8,6 +8,10 @@ import (
"strconv" "strconv"
) )
import _ "net/http/pprof"
import "net/http"
import "log"
const ( const (
ExitSetupSuccess = 0 ExitSetupSuccess = 0
ExitSetupFailed = 1 ExitSetupFailed = 1
@ -25,6 +29,10 @@ func printUsage() {
func main() { func main() {
go func() {
log.Println(http.ListenAndServe("localhost:6060", nil))
}()
// parse arguments // parse arguments
var foreground bool var foreground bool
@ -160,7 +168,6 @@ func main() {
errs := make(chan error) errs := make(chan error)
term := make(chan os.Signal) term := make(chan os.Signal)
wait := device.WaitChannel()
uapi, err := UAPIListen(interfaceName, fileUAPI) uapi, err := UAPIListen(interfaceName, fileUAPI)
@ -183,9 +190,9 @@ func main() {
signal.Notify(term, os.Interrupt) signal.Notify(term, os.Interrupt)
select { select {
case <-wait:
case <-term: case <-term:
case <-errs: case <-errs:
case <-device.Wait():
} }
// clean up // clean up

View file

@ -2,12 +2,10 @@ package main
import ( import (
"sync/atomic" "sync/atomic"
"time"
) )
/* We use int32 as atomic bools /* Atomic Boolean */
* (since booleans are not natively supported by sync/atomic)
*/
const ( const (
AtomicFalse = int32(iota) AtomicFalse = int32(iota)
AtomicTrue AtomicTrue
@ -37,6 +35,8 @@ func (a *AtomicBool) Set(val bool) {
atomic.StoreInt32(&a.flag, flag) atomic.StoreInt32(&a.flag, flag)
} }
/* Integer manipulation */
func toInt32(n uint32) int32 { func toInt32(n uint32) int32 {
mask := uint32(1 << 31) mask := uint32(1 << 31)
return int32(-(n & mask) + (n & ^mask)) return int32(-(n & mask) + (n & ^mask))
@ -55,32 +55,3 @@ func minUint64(a uint64, b uint64) uint64 {
} }
return a return a
} }
func signalSend(c chan struct{}) {
select {
case c <- struct{}{}:
default:
}
}
func signalClear(c chan struct{}) {
select {
case <-c:
default:
}
}
func timerStop(timer *time.Timer) {
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
}
func NewStoppedTimer() *time.Timer {
timer := time.NewTimer(time.Hour)
timerStop(timer)
return timer
}

View file

@ -66,11 +66,11 @@ func (rate *Ratelimiter) GarbageCollectEntries() {
rate.mutex.Unlock() rate.mutex.Unlock()
} }
func (rate *Ratelimiter) RoutineGarbageCollector(stop chan struct{}) { func (rate *Ratelimiter) RoutineGarbageCollector(stop Signal) {
timer := time.NewTimer(time.Second) timer := time.NewTimer(time.Second)
for { for {
select { select {
case <-stop: case <-stop.Wait():
return return
case <-timer.C: case <-timer.C:
rate.GarbageCollectEntries() rate.GarbageCollectEntries()

View file

@ -93,6 +93,11 @@ func (device *Device) addToHandshakeQueue(
} }
} }
/* Receives incoming datagrams for the device
*
* Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately)
*/
func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
logDebug := device.log.Debug logDebug := device.log.Debug
@ -182,6 +187,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
device.addToDecryptionQueue(device.queue.decryption, elem) device.addToDecryptionQueue(device.queue.decryption, elem)
device.addToInboundQueue(peer.queue.inbound, elem) device.addToInboundQueue(peer.queue.inbound, elem)
buffer = device.GetMessageBuffer() buffer = device.GetMessageBuffer()
continue continue
// otherwise it is a fixed size & handshake related packet // otherwise it is a fixed size & handshake related packet
@ -220,7 +226,7 @@ func (device *Device) RoutineDecryption() {
for { for {
select { select {
case <-device.signal.stop: case <-device.signal.stop.Wait():
logDebug.Println("Routine, decryption worker, stopped") logDebug.Println("Routine, decryption worker, stopped")
return return
@ -256,7 +262,7 @@ func (device *Device) RoutineDecryption() {
} }
} }
/* Handles incomming packets related to handshake /* Handles incoming packets related to handshake
*/ */
func (device *Device) RoutineHandshake() { func (device *Device) RoutineHandshake() {
@ -271,7 +277,7 @@ func (device *Device) RoutineHandshake() {
for { for {
select { select {
case elem = <-device.queue.handshake: case elem = <-device.queue.handshake:
case <-device.signal.stop: case <-device.signal.stop.Wait():
return return
} }
@ -356,7 +362,7 @@ func (device *Device) RoutineHandshake() {
continue continue
} }
// handle handshake initation/response content // handle handshake initiation/response content
switch elem.msgType { switch elem.msgType {
case MessageInitiationType: case MessageInitiationType:
@ -376,7 +382,7 @@ func (device *Device) RoutineHandshake() {
peer := device.ConsumeMessageInitiation(&msg) peer := device.ConsumeMessageInitiation(&msg)
if peer == nil { if peer == nil {
logInfo.Println( logInfo.Println(
"Recieved invalid initiation message from", "Received invalid initiation message from",
elem.endpoint.DstToString(), elem.endpoint.DstToString(),
) )
continue continue
@ -449,7 +455,7 @@ func (device *Device) RoutineHandshake() {
peer.endpoint = elem.endpoint peer.endpoint = elem.endpoint
peer.mutex.Unlock() peer.mutex.Unlock()
logDebug.Println("Received handshake initation from", peer) logDebug.Println("Received handshake initiation from", peer)
peer.TimerEphemeralKeyCreated() peer.TimerEphemeralKeyCreated()
@ -556,7 +562,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.routingTable.LookupIPv4(src) != peer { if device.routingTable.LookupIPv4(src) != peer {
logInfo.Println( logInfo.Println(
"IPv4 packet with unallowed source address from", "IPv4 packet with disallowed source address from",
peer.String(), peer.String(),
) )
continue continue
@ -584,7 +590,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.routingTable.LookupIPv6(src) != peer { if device.routingTable.LookupIPv6(src) != peer {
logInfo.Println( logInfo.Println(
"IPv6 packet with unallowed source address from", "IPv6 packet with disallowed source address from",
peer.String(), peer.String(),
) )
continue continue

View file

@ -11,7 +11,7 @@ import (
"time" "time"
) )
/* Handles outbound flow /* Outbound flow
* *
* 1. TUN queue * 1. TUN queue
* 2. Routing (sequential) * 2. Routing (sequential)
@ -19,17 +19,22 @@ import (
* 4. Encryption (parallel) * 4. Encryption (parallel)
* 5. Transmission (sequential) * 5. Transmission (sequential)
* *
* The order of packets (per peer) is maintained. * The functions in this file occur (roughly) in the order in
* The functions in this file occure (roughly) in the order packets are processed. * which the packets are processed.
*/ *
* Locking, Producers and Consumers
/* The sequential consumers will attempt to take the lock, *
* The order of packets (per peer) must be maintained,
* but encryption of packets happen out-of-order:
*
* The sequential consumers will attempt to take the lock,
* workers release lock when they have completed work (encryption) on the packet. * workers release lock when they have completed work (encryption) on the packet.
* *
* If the element is inserted into the "encryption queue", * If the element is inserted into the "encryption queue",
* the content is preceeded by enough "junk" to contain the transport header * the content is preceded by enough "junk" to contain the transport header
* (to allow the construction of transport messages in-place) * (to allow the construction of transport messages in-place)
*/ */
type QueueOutboundElement struct { type QueueOutboundElement struct {
dropped int32 dropped int32
mutex sync.Mutex mutex sync.Mutex
@ -155,7 +160,7 @@ func (device *Device) RoutineReadFromTUN() {
peer = device.routingTable.LookupIPv6(dst) peer = device.routingTable.LookupIPv6(dst)
default: default:
logDebug.Println("Receieved packet with unknown IP version") logDebug.Println("Received packet with unknown IP version")
} }
if peer == nil { if peer == nil {
@ -249,7 +254,7 @@ func (device *Device) RoutineEncryption() {
// fetch next element // fetch next element
select { select {
case <-device.signal.stop: case <-device.signal.stop.Wait():
logDebug.Println("Routine, encryption worker, stopped") logDebug.Println("Routine, encryption worker, stopped")
return return

View file

@ -20,6 +20,8 @@ func (s *Signal) Enable() {
s.enabled.Set(true) s.enabled.Set(true)
} }
/* Unblock exactly one listener
*/
func (s *Signal) Send() { func (s *Signal) Send() {
if s.enabled.Get() { if s.enabled.Get() {
select { select {
@ -29,6 +31,8 @@ func (s *Signal) Send() {
} }
} }
/* Clear the signal if already fired
*/
func (s Signal) Clear() { func (s Signal) Clear() {
select { select {
case <-s.C: case <-s.C:
@ -36,10 +40,14 @@ func (s Signal) Clear() {
} }
} }
/* Unblocks all listeners (forever)
*/
func (s Signal) Broadcast() { func (s Signal) Broadcast() {
close(s.C) // unblocks all selectors close(s.C)
} }
/* Wait for the signal
*/
func (s Signal) Wait() chan struct{} { func (s Signal) Wait() chan struct{} {
return s.C return s.C
} }

View file

@ -27,7 +27,7 @@ func (peer *Peer) KeepKeyFreshSending() {
/* Called when a new authenticated message has been received /* Called when a new authenticated message has been received
* *
* NOTE: Not thread safe (called by sequential receiver) * NOTE: Not thread safe, but called by sequential receiver!
*/ */
func (peer *Peer) KeepKeyFreshReceiving() { func (peer *Peer) KeepKeyFreshReceiving() {
if peer.timer.sendLastMinuteHandshake { if peer.timer.sendLastMinuteHandshake {

View file

@ -11,10 +11,8 @@ import (
* same way as those created by the "net" functions. * same way as those created by the "net" functions.
* Here the IPs are slices of either 4 or 16 byte (not always 16) * Here the IPs are slices of either 4 or 16 byte (not always 16)
* *
* Syncronization done seperatly * Synchronization done separately
* See: routing.go * See: routing.go
*
* TODO: Better commenting
*/ */
type Trie struct { type Trie struct {
@ -30,7 +28,11 @@ type Trie struct {
} }
/* Finds length of matching prefix /* Finds length of matching prefix
* TODO: Make faster *
* TODO: Only use during insertion (xor + prefix mask for lookup)
* Check out
* prefix_matches(struct allowedips_node *node, const u8 *key, u8 bits)
* https://git.zx2c4.com/WireGuard/commit/?h=jd/precomputed-prefix-match
* *
* Assumption: * Assumption:
* len(ip1) == len(ip2) * len(ip1) == len(ip2)
@ -88,7 +90,7 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
return node return node
} }
// walk recursivly // walk recursively
node.child[0] = node.child[0].RemovePeer(p) node.child[0] = node.child[0].RemovePeer(p)
node.child[1] = node.child[1].RemovePeer(p) node.child[1] = node.child[1].RemovePeer(p)