More consistent use of signal struct
This commit is contained in:
parent
cb09125dc4
commit
eaca1ee1f7
|
@ -37,7 +37,7 @@ type Device struct {
|
||||||
handshake chan QueueHandshakeElement
|
handshake chan QueueHandshakeElement
|
||||||
}
|
}
|
||||||
signal struct {
|
signal struct {
|
||||||
stop chan struct{}
|
stop Signal
|
||||||
}
|
}
|
||||||
underLoadUntil atomic.Value
|
underLoadUntil atomic.Value
|
||||||
ratelimiter Ratelimiter
|
ratelimiter Ratelimiter
|
||||||
|
@ -129,7 +129,6 @@ func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
|
||||||
|
|
||||||
func NewDevice(tun TUNDevice, logger *Logger) *Device {
|
func NewDevice(tun TUNDevice, logger *Logger) *Device {
|
||||||
device := new(Device)
|
device := new(Device)
|
||||||
|
|
||||||
device.mutex.Lock()
|
device.mutex.Lock()
|
||||||
defer device.mutex.Unlock()
|
defer device.mutex.Unlock()
|
||||||
|
|
||||||
|
@ -160,7 +159,7 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
|
||||||
|
|
||||||
// prepare signals
|
// prepare signals
|
||||||
|
|
||||||
device.signal.stop = make(chan struct{})
|
device.signal.stop = NewSignal()
|
||||||
|
|
||||||
// prepare net
|
// prepare net
|
||||||
|
|
||||||
|
@ -174,9 +173,11 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
|
||||||
go device.RoutineDecryption()
|
go device.RoutineDecryption()
|
||||||
go device.RoutineHandshake()
|
go device.RoutineHandshake()
|
||||||
}
|
}
|
||||||
|
|
||||||
go device.RoutineReadFromTUN()
|
go device.RoutineReadFromTUN()
|
||||||
go device.RoutineTUNEventReader()
|
go device.RoutineTUNEventReader()
|
||||||
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
|
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
|
||||||
|
|
||||||
return device
|
return device
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -210,11 +211,11 @@ func (device *Device) Close() {
|
||||||
}
|
}
|
||||||
device.log.Info.Println("Closing device")
|
device.log.Info.Println("Closing device")
|
||||||
device.RemoveAllPeers()
|
device.RemoveAllPeers()
|
||||||
close(device.signal.stop)
|
device.signal.stop.Broadcast()
|
||||||
closeBind(device)
|
|
||||||
device.tun.device.Close()
|
device.tun.device.Close()
|
||||||
|
closeBind(device)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) WaitChannel() chan struct{} {
|
func (device *Device) Wait() chan struct{} {
|
||||||
return device.signal.stop
|
return device.signal.stop.Wait()
|
||||||
}
|
}
|
||||||
|
|
11
src/main.go
11
src/main.go
|
@ -8,6 +8,10 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import _ "net/http/pprof"
|
||||||
|
import "net/http"
|
||||||
|
import "log"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ExitSetupSuccess = 0
|
ExitSetupSuccess = 0
|
||||||
ExitSetupFailed = 1
|
ExitSetupFailed = 1
|
||||||
|
@ -25,6 +29,10 @@ func printUsage() {
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
log.Println(http.ListenAndServe("localhost:6060", nil))
|
||||||
|
}()
|
||||||
|
|
||||||
// parse arguments
|
// parse arguments
|
||||||
|
|
||||||
var foreground bool
|
var foreground bool
|
||||||
|
@ -160,7 +168,6 @@ func main() {
|
||||||
|
|
||||||
errs := make(chan error)
|
errs := make(chan error)
|
||||||
term := make(chan os.Signal)
|
term := make(chan os.Signal)
|
||||||
wait := device.WaitChannel()
|
|
||||||
|
|
||||||
uapi, err := UAPIListen(interfaceName, fileUAPI)
|
uapi, err := UAPIListen(interfaceName, fileUAPI)
|
||||||
|
|
||||||
|
@ -183,9 +190,9 @@ func main() {
|
||||||
signal.Notify(term, os.Interrupt)
|
signal.Notify(term, os.Interrupt)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-wait:
|
|
||||||
case <-term:
|
case <-term:
|
||||||
case <-errs:
|
case <-errs:
|
||||||
|
case <-device.Wait():
|
||||||
}
|
}
|
||||||
|
|
||||||
// clean up
|
// clean up
|
||||||
|
|
37
src/misc.go
37
src/misc.go
|
@ -2,12 +2,10 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
/* We use int32 as atomic bools
|
/* Atomic Boolean */
|
||||||
* (since booleans are not natively supported by sync/atomic)
|
|
||||||
*/
|
|
||||||
const (
|
const (
|
||||||
AtomicFalse = int32(iota)
|
AtomicFalse = int32(iota)
|
||||||
AtomicTrue
|
AtomicTrue
|
||||||
|
@ -37,6 +35,8 @@ func (a *AtomicBool) Set(val bool) {
|
||||||
atomic.StoreInt32(&a.flag, flag)
|
atomic.StoreInt32(&a.flag, flag)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Integer manipulation */
|
||||||
|
|
||||||
func toInt32(n uint32) int32 {
|
func toInt32(n uint32) int32 {
|
||||||
mask := uint32(1 << 31)
|
mask := uint32(1 << 31)
|
||||||
return int32(-(n & mask) + (n & ^mask))
|
return int32(-(n & mask) + (n & ^mask))
|
||||||
|
@ -55,32 +55,3 @@ func minUint64(a uint64, b uint64) uint64 {
|
||||||
}
|
}
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
func signalSend(c chan struct{}) {
|
|
||||||
select {
|
|
||||||
case c <- struct{}{}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func signalClear(c chan struct{}) {
|
|
||||||
select {
|
|
||||||
case <-c:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func timerStop(timer *time.Timer) {
|
|
||||||
if !timer.Stop() {
|
|
||||||
select {
|
|
||||||
case <-timer.C:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewStoppedTimer() *time.Timer {
|
|
||||||
timer := time.NewTimer(time.Hour)
|
|
||||||
timerStop(timer)
|
|
||||||
return timer
|
|
||||||
}
|
|
||||||
|
|
|
@ -66,11 +66,11 @@ func (rate *Ratelimiter) GarbageCollectEntries() {
|
||||||
rate.mutex.Unlock()
|
rate.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rate *Ratelimiter) RoutineGarbageCollector(stop chan struct{}) {
|
func (rate *Ratelimiter) RoutineGarbageCollector(stop Signal) {
|
||||||
timer := time.NewTimer(time.Second)
|
timer := time.NewTimer(time.Second)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-stop:
|
case <-stop.Wait():
|
||||||
return
|
return
|
||||||
case <-timer.C:
|
case <-timer.C:
|
||||||
rate.GarbageCollectEntries()
|
rate.GarbageCollectEntries()
|
||||||
|
|
|
@ -93,6 +93,11 @@ func (device *Device) addToHandshakeQueue(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Receives incoming datagrams for the device
|
||||||
|
*
|
||||||
|
* Every time the bind is updated a new routine is started for
|
||||||
|
* IPv4 and IPv6 (separately)
|
||||||
|
*/
|
||||||
func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
|
func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
|
||||||
|
|
||||||
logDebug := device.log.Debug
|
logDebug := device.log.Debug
|
||||||
|
@ -182,6 +187,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
|
||||||
device.addToDecryptionQueue(device.queue.decryption, elem)
|
device.addToDecryptionQueue(device.queue.decryption, elem)
|
||||||
device.addToInboundQueue(peer.queue.inbound, elem)
|
device.addToInboundQueue(peer.queue.inbound, elem)
|
||||||
buffer = device.GetMessageBuffer()
|
buffer = device.GetMessageBuffer()
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
// otherwise it is a fixed size & handshake related packet
|
// otherwise it is a fixed size & handshake related packet
|
||||||
|
@ -220,7 +226,7 @@ func (device *Device) RoutineDecryption() {
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-device.signal.stop:
|
case <-device.signal.stop.Wait():
|
||||||
logDebug.Println("Routine, decryption worker, stopped")
|
logDebug.Println("Routine, decryption worker, stopped")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -256,7 +262,7 @@ func (device *Device) RoutineDecryption() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Handles incomming packets related to handshake
|
/* Handles incoming packets related to handshake
|
||||||
*/
|
*/
|
||||||
func (device *Device) RoutineHandshake() {
|
func (device *Device) RoutineHandshake() {
|
||||||
|
|
||||||
|
@ -271,7 +277,7 @@ func (device *Device) RoutineHandshake() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case elem = <-device.queue.handshake:
|
case elem = <-device.queue.handshake:
|
||||||
case <-device.signal.stop:
|
case <-device.signal.stop.Wait():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -356,7 +362,7 @@ func (device *Device) RoutineHandshake() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// handle handshake initation/response content
|
// handle handshake initiation/response content
|
||||||
|
|
||||||
switch elem.msgType {
|
switch elem.msgType {
|
||||||
case MessageInitiationType:
|
case MessageInitiationType:
|
||||||
|
@ -376,7 +382,7 @@ func (device *Device) RoutineHandshake() {
|
||||||
peer := device.ConsumeMessageInitiation(&msg)
|
peer := device.ConsumeMessageInitiation(&msg)
|
||||||
if peer == nil {
|
if peer == nil {
|
||||||
logInfo.Println(
|
logInfo.Println(
|
||||||
"Recieved invalid initiation message from",
|
"Received invalid initiation message from",
|
||||||
elem.endpoint.DstToString(),
|
elem.endpoint.DstToString(),
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
@ -449,7 +455,7 @@ func (device *Device) RoutineHandshake() {
|
||||||
peer.endpoint = elem.endpoint
|
peer.endpoint = elem.endpoint
|
||||||
peer.mutex.Unlock()
|
peer.mutex.Unlock()
|
||||||
|
|
||||||
logDebug.Println("Received handshake initation from", peer)
|
logDebug.Println("Received handshake initiation from", peer)
|
||||||
|
|
||||||
peer.TimerEphemeralKeyCreated()
|
peer.TimerEphemeralKeyCreated()
|
||||||
|
|
||||||
|
@ -556,7 +562,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||||
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
|
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
|
||||||
if device.routingTable.LookupIPv4(src) != peer {
|
if device.routingTable.LookupIPv4(src) != peer {
|
||||||
logInfo.Println(
|
logInfo.Println(
|
||||||
"IPv4 packet with unallowed source address from",
|
"IPv4 packet with disallowed source address from",
|
||||||
peer.String(),
|
peer.String(),
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
@ -584,7 +590,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||||
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
|
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
|
||||||
if device.routingTable.LookupIPv6(src) != peer {
|
if device.routingTable.LookupIPv6(src) != peer {
|
||||||
logInfo.Println(
|
logInfo.Println(
|
||||||
"IPv6 packet with unallowed source address from",
|
"IPv6 packet with disallowed source address from",
|
||||||
peer.String(),
|
peer.String(),
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
23
src/send.go
23
src/send.go
|
@ -11,7 +11,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
/* Handles outbound flow
|
/* Outbound flow
|
||||||
*
|
*
|
||||||
* 1. TUN queue
|
* 1. TUN queue
|
||||||
* 2. Routing (sequential)
|
* 2. Routing (sequential)
|
||||||
|
@ -19,17 +19,22 @@ import (
|
||||||
* 4. Encryption (parallel)
|
* 4. Encryption (parallel)
|
||||||
* 5. Transmission (sequential)
|
* 5. Transmission (sequential)
|
||||||
*
|
*
|
||||||
* The order of packets (per peer) is maintained.
|
* The functions in this file occur (roughly) in the order in
|
||||||
* The functions in this file occure (roughly) in the order packets are processed.
|
* which the packets are processed.
|
||||||
*/
|
*
|
||||||
|
* Locking, Producers and Consumers
|
||||||
/* The sequential consumers will attempt to take the lock,
|
*
|
||||||
|
* The order of packets (per peer) must be maintained,
|
||||||
|
* but encryption of packets happen out-of-order:
|
||||||
|
*
|
||||||
|
* The sequential consumers will attempt to take the lock,
|
||||||
* workers release lock when they have completed work (encryption) on the packet.
|
* workers release lock when they have completed work (encryption) on the packet.
|
||||||
*
|
*
|
||||||
* If the element is inserted into the "encryption queue",
|
* If the element is inserted into the "encryption queue",
|
||||||
* the content is preceeded by enough "junk" to contain the transport header
|
* the content is preceded by enough "junk" to contain the transport header
|
||||||
* (to allow the construction of transport messages in-place)
|
* (to allow the construction of transport messages in-place)
|
||||||
*/
|
*/
|
||||||
|
|
||||||
type QueueOutboundElement struct {
|
type QueueOutboundElement struct {
|
||||||
dropped int32
|
dropped int32
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
|
@ -155,7 +160,7 @@ func (device *Device) RoutineReadFromTUN() {
|
||||||
peer = device.routingTable.LookupIPv6(dst)
|
peer = device.routingTable.LookupIPv6(dst)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
logDebug.Println("Receieved packet with unknown IP version")
|
logDebug.Println("Received packet with unknown IP version")
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer == nil {
|
if peer == nil {
|
||||||
|
@ -249,7 +254,7 @@ func (device *Device) RoutineEncryption() {
|
||||||
// fetch next element
|
// fetch next element
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-device.signal.stop:
|
case <-device.signal.stop.Wait():
|
||||||
logDebug.Println("Routine, encryption worker, stopped")
|
logDebug.Println("Routine, encryption worker, stopped")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,8 @@ func (s *Signal) Enable() {
|
||||||
s.enabled.Set(true)
|
s.enabled.Set(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Unblock exactly one listener
|
||||||
|
*/
|
||||||
func (s *Signal) Send() {
|
func (s *Signal) Send() {
|
||||||
if s.enabled.Get() {
|
if s.enabled.Get() {
|
||||||
select {
|
select {
|
||||||
|
@ -29,6 +31,8 @@ func (s *Signal) Send() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Clear the signal if already fired
|
||||||
|
*/
|
||||||
func (s Signal) Clear() {
|
func (s Signal) Clear() {
|
||||||
select {
|
select {
|
||||||
case <-s.C:
|
case <-s.C:
|
||||||
|
@ -36,10 +40,14 @@ func (s Signal) Clear() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Unblocks all listeners (forever)
|
||||||
|
*/
|
||||||
func (s Signal) Broadcast() {
|
func (s Signal) Broadcast() {
|
||||||
close(s.C) // unblocks all selectors
|
close(s.C)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Wait for the signal
|
||||||
|
*/
|
||||||
func (s Signal) Wait() chan struct{} {
|
func (s Signal) Wait() chan struct{} {
|
||||||
return s.C
|
return s.C
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,7 +27,7 @@ func (peer *Peer) KeepKeyFreshSending() {
|
||||||
|
|
||||||
/* Called when a new authenticated message has been received
|
/* Called when a new authenticated message has been received
|
||||||
*
|
*
|
||||||
* NOTE: Not thread safe (called by sequential receiver)
|
* NOTE: Not thread safe, but called by sequential receiver!
|
||||||
*/
|
*/
|
||||||
func (peer *Peer) KeepKeyFreshReceiving() {
|
func (peer *Peer) KeepKeyFreshReceiving() {
|
||||||
if peer.timer.sendLastMinuteHandshake {
|
if peer.timer.sendLastMinuteHandshake {
|
||||||
|
|
12
src/trie.go
12
src/trie.go
|
@ -11,10 +11,8 @@ import (
|
||||||
* same way as those created by the "net" functions.
|
* same way as those created by the "net" functions.
|
||||||
* Here the IPs are slices of either 4 or 16 byte (not always 16)
|
* Here the IPs are slices of either 4 or 16 byte (not always 16)
|
||||||
*
|
*
|
||||||
* Syncronization done seperatly
|
* Synchronization done separately
|
||||||
* See: routing.go
|
* See: routing.go
|
||||||
*
|
|
||||||
* TODO: Better commenting
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
type Trie struct {
|
type Trie struct {
|
||||||
|
@ -30,7 +28,11 @@ type Trie struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Finds length of matching prefix
|
/* Finds length of matching prefix
|
||||||
* TODO: Make faster
|
*
|
||||||
|
* TODO: Only use during insertion (xor + prefix mask for lookup)
|
||||||
|
* Check out
|
||||||
|
* prefix_matches(struct allowedips_node *node, const u8 *key, u8 bits)
|
||||||
|
* https://git.zx2c4.com/WireGuard/commit/?h=jd/precomputed-prefix-match
|
||||||
*
|
*
|
||||||
* Assumption:
|
* Assumption:
|
||||||
* len(ip1) == len(ip2)
|
* len(ip1) == len(ip2)
|
||||||
|
@ -88,7 +90,7 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
|
||||||
return node
|
return node
|
||||||
}
|
}
|
||||||
|
|
||||||
// walk recursivly
|
// walk recursively
|
||||||
|
|
||||||
node.child[0] = node.child[0].RemovePeer(p)
|
node.child[0] = node.child[0].RemovePeer(p)
|
||||||
node.child[1] = node.child[1].RemovePeer(p)
|
node.child[1] = node.child[1].RemovePeer(p)
|
||||||
|
|
Loading…
Reference in a new issue