More refactoring

This commit is contained in:
Jason A. Donenfeld 2018-05-13 23:14:43 +02:00
parent 729773fdf3
commit b56af1829d
9 changed files with 219 additions and 251 deletions

20
conn.go
View file

@ -74,9 +74,6 @@ func (device *Device) BindSetMark(mark uint32) error {
device.net.mutex.Lock() device.net.mutex.Lock()
defer device.net.mutex.Unlock() defer device.net.mutex.Unlock()
device.peers.mutex.Lock()
defer device.peers.mutex.Unlock()
// check if modified // check if modified
if device.net.fwmark == mark { if device.net.fwmark == mark {
@ -92,6 +89,18 @@ func (device *Device) BindSetMark(mark uint32) error {
} }
} }
// clear cached source addresses
device.peers.mutex.RLock()
for _, peer := range device.peers.keyMap {
peer.mutex.Lock()
defer peer.mutex.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
}
device.peers.mutex.RUnlock()
return nil return nil
} }
@ -100,9 +109,6 @@ func (device *Device) BindUpdate() error {
device.net.mutex.Lock() device.net.mutex.Lock()
defer device.net.mutex.Unlock() defer device.net.mutex.Unlock()
device.peers.mutex.Lock()
defer device.peers.mutex.Unlock()
// close existing sockets // close existing sockets
if err := unsafeCloseBind(device); err != nil { if err := unsafeCloseBind(device); err != nil {
@ -135,6 +141,7 @@ func (device *Device) BindUpdate() error {
// clear cached source addresses // clear cached source addresses
device.peers.mutex.RLock()
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.mutex.Lock() peer.mutex.Lock()
defer peer.mutex.Unlock() defer peer.mutex.Unlock()
@ -142,6 +149,7 @@ func (device *Device) BindUpdate() error {
peer.endpoint.ClearSrc() peer.endpoint.ClearSrc()
} }
} }
device.peers.mutex.RUnlock()
// start receiving routines // start receiving routines

View file

@ -38,17 +38,12 @@ type Device struct {
fwmark uint32 // mark value (0 = disabled) fwmark uint32 // mark value (0 = disabled)
} }
noise struct { staticIdentity struct {
mutex sync.RWMutex mutex sync.RWMutex
privateKey NoisePrivateKey privateKey NoisePrivateKey
publicKey NoisePublicKey publicKey NoisePublicKey
} }
routing struct {
mutex sync.RWMutex
table AllowedIPs
}
peers struct { peers struct {
mutex sync.RWMutex mutex sync.RWMutex
keyMap map[NoisePublicKey]*Peer keyMap map[NoisePublicKey]*Peer
@ -56,8 +51,9 @@ type Device struct {
// unprotected / "self-synchronising resources" // unprotected / "self-synchronising resources"
indexTable IndexTable allowedips AllowedIPs
mac CookieChecker indexTable IndexTable
cookieChecker CookieChecker
rate struct { rate struct {
underLoadUntil atomic.Value underLoadUntil atomic.Value
@ -87,15 +83,13 @@ type Device struct {
/* Converts the peer into a "zombie", which remains in the peer map, /* Converts the peer into a "zombie", which remains in the peer map,
* but processes no packets and does not exists in the routing table. * but processes no packets and does not exists in the routing table.
* *
* Must hold: * Must hold device.peers.mutex.
* device.peers.mutex : exclusive lock
* device.routing : exclusive lock
*/ */
func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
// stop routing and processing of packets // stop routing and processing of packets
device.routing.table.RemoveByPeer(peer) device.allowedips.RemoveByPeer(peer)
peer.Stop() peer.Stop()
// remove from peer map // remove from peer map
@ -131,19 +125,19 @@ func deviceUpdateState(device *Device) {
device.isUp.Set(false) device.isUp.Set(false)
break break
} }
device.peers.mutex.Lock() device.peers.mutex.RLock()
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.Start() peer.Start()
} }
device.peers.mutex.Unlock() device.peers.mutex.RUnlock()
case false: case false:
device.BindClose() device.BindClose()
device.peers.mutex.Lock() device.peers.mutex.RLock()
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.Stop() peer.Stop()
} }
device.peers.mutex.Unlock() device.peers.mutex.RUnlock()
} }
// update state variables // update state variables
@ -199,11 +193,8 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
// lock required resources // lock required resources
device.noise.mutex.Lock() device.staticIdentity.mutex.Lock()
defer device.noise.mutex.Unlock() defer device.staticIdentity.mutex.Unlock()
device.routing.mutex.Lock()
defer device.routing.mutex.Unlock()
device.peers.mutex.Lock() device.peers.mutex.Lock()
defer device.peers.mutex.Unlock() defer device.peers.mutex.Unlock()
@ -224,13 +215,13 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
// update key material // update key material
device.noise.privateKey = sk device.staticIdentity.privateKey = sk
device.noise.publicKey = publicKey device.staticIdentity.publicKey = publicKey
device.mac.Init(publicKey) device.cookieChecker.Init(publicKey)
// do static-static DH pre-computations // do static-static DH pre-computations
rmKey := device.noise.privateKey.IsZero() rmKey := device.staticIdentity.privateKey.IsZero()
for key, peer := range device.peers.keyMap { for key, peer := range device.peers.keyMap {
@ -239,7 +230,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
if rmKey { if rmKey {
hs.precomputedStaticStatic = [NoisePublicKeySize]byte{} hs.precomputedStaticStatic = [NoisePublicKeySize]byte{}
} else { } else {
hs.precomputedStaticStatic = device.noise.privateKey.sharedSecret(hs.remoteStatic) hs.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(hs.remoteStatic)
} }
if isZero(hs.precomputedStaticStatic[:]) { if isZero(hs.precomputedStaticStatic[:]) {
@ -281,10 +272,10 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
device.rate.limiter.Init() device.rate.limiter.Init()
device.rate.underLoadUntil.Store(time.Time{}) device.rate.underLoadUntil.Store(time.Time{})
// initialize noise & crypt-key routine // initialize staticIdentity & crypt-key routine
device.indexTable.Init() device.indexTable.Init()
device.routing.table.Reset() device.allowedips.Reset()
// setup buffer pool // setup buffer pool
@ -333,12 +324,6 @@ func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
} }
func (device *Device) RemovePeer(key NoisePublicKey) { func (device *Device) RemovePeer(key NoisePublicKey) {
device.noise.mutex.Lock()
defer device.noise.mutex.Unlock()
device.routing.mutex.Lock()
defer device.routing.mutex.Unlock()
device.peers.mutex.Lock() device.peers.mutex.Lock()
defer device.peers.mutex.Unlock() defer device.peers.mutex.Unlock()
@ -351,12 +336,6 @@ func (device *Device) RemovePeer(key NoisePublicKey) {
} }
func (device *Device) RemoveAllPeers() { func (device *Device) RemoveAllPeers() {
device.noise.mutex.Lock()
defer device.noise.mutex.Unlock()
device.routing.mutex.Lock()
defer device.routing.mutex.Unlock()
device.peers.mutex.Lock() device.peers.mutex.Lock()
defer device.peers.mutex.Unlock() defer device.peers.mutex.Unlock()

View file

@ -107,6 +107,7 @@ type Handshake struct {
precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret
lastTimestamp tai64n.Timestamp lastTimestamp tai64n.Timestamp
lastInitiationConsumption time.Time lastInitiationConsumption time.Time
lastSentHandshake time.Time
} }
var ( var (
@ -153,8 +154,8 @@ func init() {
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
device.noise.mutex.RLock() device.staticIdentity.mutex.RLock()
defer device.noise.mutex.RUnlock() defer device.staticIdentity.mutex.RUnlock()
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
@ -206,7 +207,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
ss[:], ss[:],
) )
aead, _ := chacha20poly1305.New(key[:]) aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Static[:0], ZeroNonce[:], device.noise.publicKey[:], handshake.hash[:]) aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
}() }()
handshake.mixHash(msg.Static[:]) handshake.mixHash(msg.Static[:])
@ -240,10 +241,10 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
return nil return nil
} }
device.noise.mutex.RLock() device.staticIdentity.mutex.RLock()
defer device.noise.mutex.RUnlock() defer device.staticIdentity.mutex.RUnlock()
mixHash(&hash, &InitialHash, device.noise.publicKey[:]) mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:])
mixHash(&hash, &hash, msg.Ephemeral[:]) mixHash(&hash, &hash, msg.Ephemeral[:])
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
@ -253,7 +254,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
var peerPK NoisePublicKey var peerPK NoisePublicKey
func() { func() {
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
ss := device.noise.privateKey.sharedSecret(msg.Ephemeral) ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
KDF2(&chainKey, &key, chainKey[:], ss[:]) KDF2(&chainKey, &key, chainKey[:], ss[:])
aead, _ := chacha20poly1305.New(key[:]) aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
@ -422,8 +423,8 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
// lock private key for reading // lock private key for reading
device.noise.mutex.RLock() device.staticIdentity.mutex.RLock()
defer device.noise.mutex.RUnlock() defer device.staticIdentity.mutex.RUnlock()
// finish 3-way DH // finish 3-way DH
@ -437,7 +438,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
}() }()
func() { func() {
ss := device.noise.privateKey.sharedSecret(msg.Ephemeral) ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
mixKey(&chainKey, &chainKey, ss[:]) mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:]) setZero(ss[:])
}() }()
@ -490,7 +491,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
/* Derives a new keypair from the current handshake state /* Derives a new keypair from the current handshake state
* *
*/ */
func (peer *Peer) DeriveNewKeypair() error { func (peer *Peer) BeginSymmetricSession() error {
device := peer.device device := peer.device
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
@ -552,50 +553,48 @@ func (peer *Peer) DeriveNewKeypair() error {
// rotate key pairs // rotate key pairs
kp := &peer.keypairs keypairs := &peer.keypairs
kp.mutex.Lock() keypairs.mutex.Lock()
defer keypairs.mutex.Unlock()
peer.timersSessionDerived() previous := keypairs.previous
next := keypairs.next
previous := kp.previous current := keypairs.current
next := kp.next
current := kp.current
if isInitiator { if isInitiator {
if next != nil { if next != nil {
kp.next = nil keypairs.next = nil
kp.previous = next keypairs.previous = next
device.DeleteKeypair(current) device.DeleteKeypair(current)
} else { } else {
kp.previous = current keypairs.previous = current
} }
device.DeleteKeypair(previous) device.DeleteKeypair(previous)
kp.current = keypair keypairs.current = keypair
} else { } else {
kp.next = keypair keypairs.next = keypair
device.DeleteKeypair(next) device.DeleteKeypair(next)
kp.previous = nil keypairs.previous = nil
device.DeleteKeypair(previous) device.DeleteKeypair(previous)
} }
kp.mutex.Unlock()
return nil return nil
} }
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
kp := &peer.keypairs keypairs := &peer.keypairs
if kp.next != receivedKeypair { if keypairs.next != receivedKeypair {
return false return false
} }
kp.mutex.Lock() keypairs.mutex.Lock()
defer kp.mutex.Unlock() defer keypairs.mutex.Unlock()
if kp.next != receivedKeypair { if keypairs.next != receivedKeypair {
return false return false
} }
old := kp.previous old := keypairs.previous
kp.previous = kp.current keypairs.previous = keypairs.current
peer.device.DeleteKeypair(old) peer.device.DeleteKeypair(old)
kp.current = kp.next keypairs.current = keypairs.next
kp.next = nil keypairs.next = nil
return true return true
} }

View file

@ -36,8 +36,8 @@ func TestNoiseHandshake(t *testing.T) {
defer dev1.Close() defer dev1.Close()
defer dev2.Close() defer dev2.Close()
peer1, _ := dev2.NewPeer(dev1.noise.privateKey.publicKey()) peer1, _ := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey())
peer2, _ := dev1.NewPeer(dev2.noise.privateKey.publicKey()) peer2, _ := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey())
assertEqual( assertEqual(
t, t,
@ -102,8 +102,8 @@ func TestNoiseHandshake(t *testing.T) {
t.Log("deriving keys") t.Log("deriving keys")
key1 := peer1.DeriveNewKeypair() key1 := peer1.BeginSymmetricSession()
key2 := peer2.DeriveNewKeypair() key2 := peer2.BeginSymmetricSession()
if key1 == nil { if key1 == nil {
t.Fatal("failed to dervice keypair for peer 1") t.Fatal("failed to dervice keypair for peer 1")

70
peer.go
View file

@ -19,7 +19,7 @@ const (
type Peer struct { type Peer struct {
isRunning AtomicBool isRunning AtomicBool
mutex sync.RWMutex mutex sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
keypairs Keypairs keypairs Keypairs
handshake Handshake handshake Handshake
device *Device device *Device
@ -42,7 +42,6 @@ type Peer struct {
handshakeAttempts uint handshakeAttempts uint
needAnotherKeepalive bool needAnotherKeepalive bool
sentLastMinuteHandshake bool sentLastMinuteHandshake bool
lastSentHandshake time.Time
} }
signals struct { signals struct {
@ -64,7 +63,7 @@ type Peer struct {
stop chan struct{} // size 0, stop all go routines in peer stop chan struct{} // size 0, stop all go routines in peer
} }
mac CookieGenerator cookieGenerator CookieGenerator
} }
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
@ -75,11 +74,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// lock resources // lock resources
device.state.mutex.Lock() device.staticIdentity.mutex.RLock()
defer device.state.mutex.Unlock() defer device.staticIdentity.mutex.RUnlock()
device.noise.mutex.RLock()
defer device.noise.mutex.RUnlock()
device.peers.mutex.Lock() device.peers.mutex.Lock()
defer device.peers.mutex.Unlock() defer device.peers.mutex.Unlock()
@ -96,7 +92,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.mutex.Lock() peer.mutex.Lock()
defer peer.mutex.Unlock() defer peer.mutex.Unlock()
peer.mac.Init(pk) peer.cookieGenerator.Init(pk)
peer.device = device peer.device = device
peer.isRunning.Set(false) peer.isRunning.Set(false)
@ -113,7 +109,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
handshake.remoteStatic = pk handshake.remoteStatic = pk
handshake.precomputedStaticStatic = device.noise.privateKey.sharedSecret(pk) handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
handshake.mutex.Unlock() handshake.mutex.Unlock()
// reset endpoint // reset endpoint
@ -191,6 +187,7 @@ func (peer *Peer) Start() {
peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
peer.timersInit() peer.timersInit()
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
peer.signals.newKeypairArrived = make(chan struct{}, 1) peer.signals.newKeypairArrived = make(chan struct{}, 1)
peer.signals.flushNonceQueue = make(chan struct{}, 1) peer.signals.flushNonceQueue = make(chan struct{}, 1)
@ -204,6 +201,32 @@ func (peer *Peer) Start() {
peer.isRunning.Set(true) peer.isRunning.Set(true)
} }
func (peer *Peer) ZeroAndFlushAll() {
device := peer.device
// clear key pairs
keypairs := &peer.keypairs
keypairs.mutex.Lock()
device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current)
device.DeleteKeypair(keypairs.next)
keypairs.previous = nil
keypairs.current = nil
keypairs.next = nil
keypairs.mutex.Unlock()
// clear handshake state
handshake := &peer.handshake
handshake.mutex.Lock()
device.indexTable.Delete(handshake.localIndex)
handshake.Clear()
handshake.mutex.Unlock()
peer.FlushNonceQueue()
}
func (peer *Peer) Stop() { func (peer *Peer) Stop() {
// prevent simultaneous start/stop operations // prevent simultaneous start/stop operations
@ -215,8 +238,7 @@ func (peer *Peer) Stop() {
return return
} }
device := peer.device peer.device.log.Debug.Println(peer, ": Stopping...")
device.log.Debug.Println(peer, ": Stopping...")
peer.timersStop() peer.timersStop()
@ -232,27 +254,5 @@ func (peer *Peer) Stop() {
close(peer.queue.outbound) close(peer.queue.outbound)
close(peer.queue.inbound) close(peer.queue.inbound)
// clear key pairs peer.ZeroAndFlushAll()
kp := &peer.keypairs
kp.mutex.Lock()
device.DeleteKeypair(kp.previous)
device.DeleteKeypair(kp.current)
device.DeleteKeypair(kp.next)
kp.previous = nil
kp.current = nil
kp.next = nil
kp.mutex.Unlock()
// clear handshake state
hs := &peer.handshake
hs.mutex.Lock()
device.indexTable.Delete(hs.localIndex)
hs.Clear()
hs.mutex.Unlock()
peer.FlushNonceQueue()
} }

View file

@ -107,8 +107,8 @@ func (peer *Peer) keepKeyFreshReceiving() {
if peer.timers.sentLastMinuteHandshake { if peer.timers.sentLastMinuteHandshake {
return return
} }
kp := peer.keypairs.Current() keypair := peer.keypairs.Current()
if kp != nil && kp.isInitiator && time.Now().Sub(kp.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { if keypair != nil && keypair.isInitiator && time.Now().Sub(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
peer.timers.sentLastMinuteHandshake = true peer.timers.sentLastMinuteHandshake = true
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
} }
@ -325,7 +325,6 @@ func (device *Device) RoutineHandshake() {
logDebug.Println("Routine: handshake worker - started") logDebug.Println("Routine: handshake worker - started")
var temp [MessageHandshakeSize]byte
var elem QueueHandshakeElement var elem QueueHandshakeElement
var ok bool var ok bool
@ -367,52 +366,28 @@ func (device *Device) RoutineHandshake() {
// consume reply // consume reply
if peer := entry.peer; peer.isRunning.Get() { if peer := entry.peer; peer.isRunning.Get() {
peer.mac.ConsumeReply(&reply) peer.cookieGenerator.ConsumeReply(&reply)
} }
continue continue
case MessageInitiationType, MessageResponseType: case MessageInitiationType, MessageResponseType:
// check mac fields and ratelimit // check mac fields and maybe ratelimit
if !device.mac.CheckMAC1(elem.packet) { if !device.cookieChecker.CheckMAC1(elem.packet) {
logDebug.Println("Received packet with invalid mac1") logDebug.Println("Received packet with invalid mac1")
continue continue
} }
// endpoints destination address is the source of the datagram // endpoints destination address is the source of the datagram
srcBytes := elem.endpoint.DstToBytes()
if device.IsUnderLoad() { if device.IsUnderLoad() {
// verify MAC2 field // verify MAC2 field
if !device.mac.CheckMAC2(elem.packet, srcBytes) { if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
device.SendHandshakeCookie(&elem)
// construct cookie reply
logDebug.Println(
"Sending cookie reply to:",
elem.endpoint.DstToString(),
)
sender := binary.LittleEndian.Uint32(elem.packet[4:8])
reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes)
if err != nil {
logError.Println("Failed to create cookie reply:", err)
continue
}
// marshal and send reply
writer := bytes.NewBuffer(temp[:0])
binary.Write(writer, binary.LittleEndian, reply)
device.net.bind.Send(writer.Bytes(), elem.endpoint)
if err != nil {
logDebug.Println("Failed to send cookie reply:", err)
}
continue continue
} }
@ -467,34 +442,7 @@ func (device *Device) RoutineHandshake() {
logDebug.Println(peer, ": Received handshake initiation") logDebug.Println(peer, ": Received handshake initiation")
// create response peer.SendHandshakeResponse()
response, err := device.CreateMessageResponse(peer)
if err != nil {
logError.Println("Failed to create response message:", err)
continue
}
if peer.DeriveNewKeypair() != nil {
continue
}
logDebug.Println(peer, ": Sending handshake response")
writer := bytes.NewBuffer(temp[:0])
binary.Write(writer, binary.LittleEndian, response)
packet := writer.Bytes()
peer.mac.AddMacs(packet)
// send response
peer.timers.lastSentHandshake = time.Now()
err = peer.SendBuffer(packet)
if err == nil {
peer.timersAnyAuthenticatedPacketTraversal()
} else {
logError.Println(peer, ": Failed to send handshake response", err)
}
case MessageResponseType: case MessageResponseType:
@ -534,10 +482,14 @@ func (device *Device) RoutineHandshake() {
// derive keypair // derive keypair
if peer.DeriveNewKeypair() != nil { err = peer.BeginSymmetricSession()
if err != nil {
logError.Println(peer, ": Failed to derive keypair:", err)
continue continue
} }
peer.timersSessionDerived()
peer.timersHandshakeComplete() peer.timersHandshakeComplete()
peer.SendKeepalive() peer.SendKeepalive()
select { select {
@ -640,7 +592,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
// verify IPv4 source // verify IPv4 source
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.routing.table.LookupIPv4(src) != peer { if device.allowedips.LookupIPv4(src) != peer {
logInfo.Println( logInfo.Println(
"IPv4 packet with disallowed source address from", "IPv4 packet with disallowed source address from",
peer, peer,
@ -668,7 +620,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
// verify IPv6 source // verify IPv6 source
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.routing.table.LookupIPv6(src) != peer { if device.allowedips.LookupIPv6(src) != peer {
logInfo.Println( logInfo.Println(
peer, peer,
"sent packet with disallowed IPv6 source", "sent packet with disallowed IPv6 source",

108
send.go
View file

@ -121,52 +121,114 @@ func (peer *Peer) SendKeepalive() bool {
} }
} }
/* Sends a new handshake initiation message to the peer (endpoint)
*/
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
if !isRetry { if !isRetry {
peer.timers.handshakeAttempts = 0 peer.timers.handshakeAttempts = 0
} }
if time.Now().Sub(peer.timers.lastSentHandshake) < RekeyTimeout { peer.handshake.mutex.RLock()
if time.Now().Sub(peer.handshake.lastSentHandshake) < RekeyTimeout {
peer.handshake.mutex.RUnlock()
return nil return nil
} }
peer.timers.lastSentHandshake = time.Now() //TODO: locking for this variable? peer.handshake.mutex.RUnlock()
// create initiation message peer.handshake.mutex.Lock()
if time.Now().Sub(peer.handshake.lastSentHandshake) < RekeyTimeout {
msg, err := peer.device.CreateMessageInitiation(peer) peer.handshake.mutex.Unlock()
if err != nil { return nil
return err
} }
peer.handshake.lastSentHandshake = time.Now()
peer.handshake.mutex.Unlock()
peer.device.log.Debug.Println(peer, ": Sending handshake initiation") peer.device.log.Debug.Println(peer, ": Sending handshake initiation")
// marshal handshake message msg, err := peer.device.CreateMessageInitiation(peer)
if err != nil {
peer.device.log.Error.Println(peer, ": Failed to create initiation message:", err)
return err
}
var buff [MessageInitiationSize]byte var buff [MessageInitiationSize]byte
writer := bytes.NewBuffer(buff[:0]) writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, msg) binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes() packet := writer.Bytes()
peer.mac.AddMacs(packet) peer.cookieGenerator.AddMacs(packet)
// send to endpoint
peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketTraversal()
err = peer.SendBuffer(packet)
if err != nil {
peer.device.log.Error.Println(peer, ": Failed to send handshake initiation", err)
}
peer.timersHandshakeInitiated() peer.timersHandshakeInitiated()
return peer.SendBuffer(packet)
return err
}
func (peer *Peer) SendHandshakeResponse() error {
peer.handshake.mutex.Lock()
peer.handshake.lastSentHandshake = time.Now()
peer.handshake.mutex.Unlock()
peer.device.log.Debug.Println(peer, ": Sending handshake response")
response, err := peer.device.CreateMessageResponse(peer)
if err != nil {
peer.device.log.Error.Println(peer, ": Failed to create response message:", err)
return err
}
var buff [MessageResponseSize]byte
writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, response)
packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet)
err = peer.BeginSymmetricSession()
if err != nil {
peer.device.log.Error.Println(peer, ": Failed to derive keypair:", err)
return err
}
peer.timersSessionDerived()
peer.timersAnyAuthenticatedPacketTraversal()
err = peer.SendBuffer(packet)
if err != nil {
peer.device.log.Error.Println(peer, ": Failed to send handshake response", err)
}
return err
}
func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
device.log.Debug.Println("Sending cookie reply to:", initiatingElem.endpoint.DstToString())
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
if err != nil {
device.log.Error.Println("Failed to create cookie reply:", err)
return err
}
var buff [MessageCookieReplySize]byte
writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, reply)
device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
if err != nil {
device.log.Error.Println("Failed to send cookie reply:", err)
}
return err
} }
/* Called when a new authenticated message has been send
*
*/
func (peer *Peer) keepKeyFreshSending() { func (peer *Peer) keepKeyFreshSending() {
kp := peer.keypairs.Current() keypair := peer.keypairs.Current()
if kp == nil { if keypair == nil {
return return
} }
nonce := atomic.LoadUint64(&kp.sendNonce) nonce := atomic.LoadUint64(&keypair.sendNonce)
if nonce > RekeyAfterMessages || (kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime) { if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Now().Sub(keypair.created) > RekeyAfterTime) {
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
} }
} }
@ -217,14 +279,14 @@ func (device *Device) RoutineReadFromTUN() {
continue continue
} }
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.routing.table.LookupIPv4(dst) peer = device.allowedips.LookupIPv4(dst)
case ipv6.Version: case ipv6.Version:
if len(elem.packet) < ipv6.HeaderLen { if len(elem.packet) < ipv6.HeaderLen {
continue continue
} }
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.routing.table.LookupIPv6(dst) peer = device.allowedips.LookupIPv6(dst)
default: default:
logDebug.Println("Received packet with unknown IP version") logDebug.Println("Received packet with unknown IP version")

View file

@ -104,30 +104,7 @@ func expiredNewHandshake(peer *Peer) {
func expiredZeroKeyMaterial(peer *Peer) { func expiredZeroKeyMaterial(peer *Peer) {
peer.device.log.Debug.Printf(":%s Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds())) peer.device.log.Debug.Printf(":%s Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds()))
peer.ZeroAndFlushAll()
hs := &peer.handshake
hs.mutex.Lock()
kp := &peer.keypairs
kp.mutex.Lock()
if kp.previous != nil {
peer.device.DeleteKeypair(kp.previous)
kp.previous = nil
}
if kp.current != nil {
peer.device.DeleteKeypair(kp.current)
kp.current = nil
}
if kp.next != nil {
peer.device.DeleteKeypair(kp.next)
kp.next = nil
}
kp.mutex.Unlock()
peer.device.indexTable.Delete(hs.localIndex)
hs.Clear()
hs.mutex.Unlock()
} }
func expiredPersistentKeepalive(peer *Peer) { func expiredPersistentKeepalive(peer *Peer) {
@ -209,7 +186,6 @@ func (peer *Peer) timersInit() {
peer.timers.handshakeAttempts = 0 peer.timers.handshakeAttempts = 0
peer.timers.sentLastMinuteHandshake = false peer.timers.sentLastMinuteHandshake = false
peer.timers.needAnotherKeepalive = false peer.timers.needAnotherKeepalive = false
peer.timers.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
} }
func (peer *Peer) timersStop() { func (peer *Peer) timersStop() {

34
uapi.go
View file

@ -46,19 +46,16 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
device.net.mutex.RLock() device.net.mutex.RLock()
defer device.net.mutex.RUnlock() defer device.net.mutex.RUnlock()
device.noise.mutex.RLock() device.staticIdentity.mutex.RLock()
defer device.noise.mutex.RUnlock() defer device.staticIdentity.mutex.RUnlock()
device.routing.mutex.RLock() device.peers.mutex.RLock()
defer device.routing.mutex.RUnlock() defer device.peers.mutex.RUnlock()
device.peers.mutex.Lock()
defer device.peers.mutex.Unlock()
// serialize device related values // serialize device related values
if !device.noise.privateKey.IsZero() { if !device.staticIdentity.privateKey.IsZero() {
send("private_key=" + device.noise.privateKey.ToHex()) send("private_key=" + device.staticIdentity.privateKey.ToHex())
} }
if device.net.port != 0 { if device.net.port != 0 {
@ -91,7 +88,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
send(fmt.Sprintf("rx_bytes=%d", peer.stats.rxBytes)) send(fmt.Sprintf("rx_bytes=%d", peer.stats.rxBytes))
send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval)) send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
for _, ip := range device.routing.table.EntriesForPeer(peer) { for _, ip := range device.allowedips.EntriesForPeer(peer) {
send("allowed_ip=" + ip.String()) send("allowed_ip=" + ip.String())
} }
@ -234,13 +231,12 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// ignore peer with public key of device // ignore peer with public key of device
device.noise.mutex.RLock() device.staticIdentity.mutex.RLock()
equals := device.noise.publicKey.Equals(publicKey) dummy = device.staticIdentity.publicKey.Equals(publicKey)
device.noise.mutex.RUnlock() device.staticIdentity.mutex.RUnlock()
if equals { if dummy {
peer = &Peer{} peer = &Peer{}
dummy = true
} }
// find peer referenced // find peer referenced
@ -348,9 +344,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
continue continue
} }
device.routing.mutex.Lock() device.allowedips.RemoveByPeer(peer)
device.routing.table.RemoveByPeer(peer)
device.routing.mutex.Unlock()
case "allowed_ip": case "allowed_ip":
@ -367,9 +361,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
} }
ones, _ := network.Mask.Size() ones, _ := network.Mask.Size()
device.routing.mutex.Lock() device.allowedips.Insert(network.IP, uint(ones), peer)
device.routing.table.Insert(network.IP, uint(ones), peer)
device.routing.mutex.Unlock()
default: default:
logError.Println("Invalid UAPI key (peer configuration):", key) logError.Println("Invalid UAPI key (peer configuration):", key)