Initial implementation of source caching

Yet untested.
This commit is contained in:
Mathias Hall-Andersen 2017-10-16 21:33:47 +02:00
parent a72b0f7ae5
commit e86d03dca2
10 changed files with 84 additions and 83 deletions

View file

@ -34,15 +34,20 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
return addr, err return addr, err
} }
func ListeningUpdate(device *Device) error { func UpdateUDPListener(device *Device) error {
device.mutex.Lock()
defer device.mutex.Unlock()
netc := &device.net netc := &device.net
netc.mutex.Lock() netc.mutex.Lock()
defer netc.mutex.Unlock() defer netc.mutex.Unlock()
// close existing sockets // close existing sockets
if err := device.net.bind.Close(); err != nil { if netc.bind != nil {
return err if err := netc.bind.Close(); err != nil {
return err
}
} }
// open new sockets // open new sockets
@ -64,13 +69,19 @@ func ListeningUpdate(device *Device) error {
return err return err
} }
// TODO: clear endpoint (src) caches // clear cached source addresses
for _, peer := range device.peers {
peer.mutex.Lock()
peer.endpoint.value.ClearSrc()
peer.mutex.Unlock()
}
} }
return nil return nil
} }
func ListeningClose(device *Device) error { func CloseUDPListener(device *Device) error {
netc := &device.net netc := &device.net
netc.mutex.Lock() netc.mutex.Lock()
defer netc.mutex.Unlock() defer netc.mutex.Unlock()

View file

@ -133,7 +133,7 @@ func sockaddrToString(addr unix.RawSockaddrInet6) string {
} }
} }
func (end *Endpoint) DestinationIP() net.IP { func (end *Endpoint) DstIP() net.IP {
switch end.dst.Family { switch end.dst.Family {
case unix.AF_INET6: case unix.AF_INET6:
return end.dst.Addr[:] return end.dst.Addr[:]
@ -150,20 +150,24 @@ func (end *Endpoint) DestinationIP() net.IP {
} }
} }
func (end *Endpoint) SourceToBytes() []byte { func (end *Endpoint) SrcToBytes() []byte {
ptr := unsafe.Pointer(&end.src) ptr := unsafe.Pointer(&end.src)
arr := (*[unix.SizeofSockaddrInet6]byte)(ptr) arr := (*[unix.SizeofSockaddrInet6]byte)(ptr)
return arr[:] return arr[:]
} }
func (end *Endpoint) SourceToString() string { func (end *Endpoint) SrcToString() string {
return sockaddrToString(end.src) return sockaddrToString(end.src)
} }
func (end *Endpoint) DestinationToString() string { func (end *Endpoint) DstToString() string {
return sockaddrToString(end.dst) return sockaddrToString(end.dst)
} }
func (end *Endpoint) ClearDst() {
end.dst = unix.RawSockaddrInet6{}
}
func (end *Endpoint) ClearSrc() { func (end *Endpoint) ClearSrc() {
end.src = unix.RawSockaddrInet6{} end.src = unix.RawSockaddrInet6{}
} }

View file

@ -205,7 +205,7 @@ func (device *Device) RemoveAllPeers() {
func (device *Device) Close() { func (device *Device) Close() {
device.RemoveAllPeers() device.RemoveAllPeers()
close(device.signal.stop) close(device.signal.stop)
ListeningClose(device) CloseUDPListener(device)
} }
func (device *Device) WaitChannel() chan struct{} { func (device *Device) WaitChannel() chan struct{} {

View file

@ -14,8 +14,6 @@ func printUsage() {
} }
func main() { func main() {
test()
// parse arguments // parse arguments
var foreground bool var foreground bool

View file

@ -14,9 +14,12 @@ type Peer struct {
persistentKeepaliveInterval uint64 persistentKeepaliveInterval uint64
keyPairs KeyPairs keyPairs KeyPairs
handshake Handshake handshake Handshake
endpoint Endpoint
device *Device device *Device
stats struct { endpoint struct {
set bool // has a known endpoint been discovered
value Endpoint // source / destination cache
}
stats struct {
txBytes uint64 // bytes send to peer (endpoint) txBytes uint64 // bytes send to peer (endpoint)
rxBytes uint64 // bytes received from peer rxBytes uint64 // bytes received from peer
lastHandshakeNano int64 // nano seconds since epoch lastHandshakeNano int64 // nano seconds since epoch
@ -105,6 +108,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic) handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic)
handshake.mutex.Unlock() handshake.mutex.Unlock()
// reset endpoint
peer.endpoint.set = false
peer.endpoint.value.ClearDst()
peer.endpoint.value.ClearSrc()
// prepare queuing // prepare queuing
peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize) peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
@ -129,11 +138,20 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
return peer, nil return peer, nil
} }
/* Returns a short string identification for logging
*/
func (peer *Peer) String() string { func (peer *Peer) String() string {
if !peer.endpoint.set {
return fmt.Sprintf(
"peer(%d unknown %s)",
peer.id,
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
)
}
return fmt.Sprintf( return fmt.Sprintf(
"peer(%d %s %s)", "peer(%d %s %s)",
peer.id, peer.id,
peer.endpoint.DestinationToString(), peer.endpoint.value.DstToString(),
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
) )
} }

View file

@ -331,7 +331,7 @@ func (device *Device) RoutineHandshake() {
return return
} }
srcBytes := elem.endpoint.SourceToBytes() srcBytes := elem.endpoint.SrcToBytes()
if device.IsUnderLoad() { if device.IsUnderLoad() {
// verify MAC2 field // verify MAC2 field
@ -340,8 +340,7 @@ func (device *Device) RoutineHandshake() {
// construct cookie reply // construct cookie reply
logDebug.Println("Sending cookie reply to:", elem.endpoint.SourceToString()) logDebug.Println("Sending cookie reply to:", elem.endpoint.SrcToString())
sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type" sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes) reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes)
if err != nil { if err != nil {
@ -365,9 +364,7 @@ func (device *Device) RoutineHandshake() {
// check ratelimiter // check ratelimiter
if !device.ratelimiter.Allow( if !device.ratelimiter.Allow(elem.endpoint.DstIP()) {
elem.endpoint.DestinationIP(),
) {
continue continue
} }
} }
@ -398,7 +395,7 @@ func (device *Device) RoutineHandshake() {
if peer == nil { if peer == nil {
logInfo.Println( logInfo.Println(
"Recieved invalid initiation message from", "Recieved invalid initiation message from",
elem.endpoint.DestinationToString(), elem.endpoint.DstToString(),
) )
continue continue
} }
@ -412,7 +409,8 @@ func (device *Device) RoutineHandshake() {
// TODO: Discover destination address also, only update on change // TODO: Discover destination address also, only update on change
peer.mutex.Lock() peer.mutex.Lock()
peer.endpoint = elem.endpoint peer.endpoint.set = true
peer.endpoint.value = elem.endpoint
peer.mutex.Unlock() peer.mutex.Unlock()
// create response // create response
@ -435,7 +433,7 @@ func (device *Device) RoutineHandshake() {
// send response // send response
_, err = peer.SendBuffer(packet) err = peer.SendBuffer(packet)
if err == nil { if err == nil {
peer.TimerAnyAuthenticatedPacketTraversal() peer.TimerAnyAuthenticatedPacketTraversal()
} }
@ -458,7 +456,7 @@ func (device *Device) RoutineHandshake() {
if peer == nil { if peer == nil {
logInfo.Println( logInfo.Println(
"Recieved invalid response message from", "Recieved invalid response message from",
elem.endpoint.DestinationToString(), elem.endpoint.DstToString(),
) )
continue continue
} }

View file

@ -105,24 +105,15 @@ func addToEncryptionQueue(
} }
} }
func (peer *Peer) SendBuffer(buffer []byte) (int, error) { func (peer *Peer) SendBuffer(buffer []byte) error {
peer.device.net.mutex.RLock() peer.device.net.mutex.RLock()
defer peer.device.net.mutex.RUnlock() defer peer.device.net.mutex.RUnlock()
peer.mutex.RLock() peer.mutex.RLock()
defer peer.mutex.RUnlock() defer peer.mutex.RUnlock()
if !peer.endpoint.set {
endpoint := peer.endpoint return errors.New("No known endpoint for peer")
if endpoint == nil {
return 0, errors.New("No known endpoint for peer")
} }
return peer.device.net.bind.Send(buffer, &peer.endpoint.value)
conn := peer.device.net.conn
if conn == nil {
return 0, errors.New("No UDP socket for device")
}
return conn.WriteToUDP(buffer, endpoint)
} }
/* Reads packets from the TUN and inserts /* Reads packets from the TUN and inserts
@ -343,7 +334,7 @@ func (peer *Peer) RoutineSequentialSender() {
// send message and return buffer to pool // send message and return buffer to pool
length := uint64(len(elem.packet)) length := uint64(len(elem.packet))
_, err := peer.SendBuffer(elem.packet) err := peer.SendBuffer(elem.packet)
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
if err != nil { if err != nil {
logDebug.Println("Failed to send authenticated packet to peer", peer.String()) logDebug.Println("Failed to send authenticated packet to peer", peer.String())

View file

@ -288,7 +288,7 @@ func (peer *Peer) RoutineHandshakeInitiator() {
packet := writer.Bytes() packet := writer.Bytes()
peer.mac.AddMacs(packet) peer.mac.AddMacs(packet)
_, err = peer.SendBuffer(packet) err = peer.SendBuffer(packet)
if err != nil { if err != nil {
logError.Println( logError.Println(
"Failed to send handshake initiation message to", "Failed to send handshake initiation message to",

View file

@ -47,7 +47,7 @@ func (device *Device) RoutineTUNEventReader() {
if !device.tun.isUp.Get() { if !device.tun.isUp.Get() {
logInfo.Println("Interface set up") logInfo.Println("Interface set up")
device.tun.isUp.Set(true) device.tun.isUp.Set(true)
updateUDPConn(device) UpdateUDPListener(device)
} }
} }
@ -55,7 +55,7 @@ func (device *Device) RoutineTUNEventReader() {
if device.tun.isUp.Get() { if device.tun.isUp.Get() {
logInfo.Println("Interface set down") logInfo.Println("Interface set down")
device.tun.isUp.Set(false) device.tun.isUp.Set(false)
closeUDPConn(device) CloseUDPListener(device)
} }
} }
} }

View file

@ -39,9 +39,10 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
send("private_key=" + device.privateKey.ToHex()) send("private_key=" + device.privateKey.ToHex())
} }
if device.net.addr != nil { if device.net.port != 0 {
send(fmt.Sprintf("listen_port=%d", device.net.addr.Port)) send(fmt.Sprintf("listen_port=%d", device.net.port))
} }
if device.net.fwmark != 0 { if device.net.fwmark != 0 {
send(fmt.Sprintf("fwmark=%d", device.net.fwmark)) send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
} }
@ -52,8 +53,8 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
defer peer.mutex.RUnlock() defer peer.mutex.RUnlock()
send("public_key=" + peer.handshake.remoteStatic.ToHex()) send("public_key=" + peer.handshake.remoteStatic.ToHex())
send("preshared_key=" + peer.handshake.presharedKey.ToHex()) send("preshared_key=" + peer.handshake.presharedKey.ToHex())
if peer.endpoint != nil { if peer.endpoint.set {
send("endpoint=" + peer.endpoint.String()) send("endpoint=" + peer.endpoint.value.DstToString())
} }
nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano) nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
@ -137,53 +138,24 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logError.Println("Failed to set listen_port:", err) logError.Println("Failed to set listen_port:", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
device.net.port = uint16(port)
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port)) if err := UpdateUDPListener(device); err != nil {
if err != nil {
logError.Println("Failed to set listen_port:", err)
return &IPCError{Code: ipcErrorInvalid}
}
device.net.mutex.Lock()
device.net.addr = addr
device.net.mutex.Unlock()
err = updateUDPConn(device)
if err != nil {
logError.Println("Failed to set listen_port:", err) logError.Println("Failed to set listen_port:", err)
return &IPCError{Code: ipcErrorPortInUse} return &IPCError{Code: ipcErrorPortInUse}
} }
// TODO: Clear source address of all peers
case "fwmark": case "fwmark":
fwmark, err := strconv.ParseUint(value, 10, 32) fwmark, err := strconv.ParseUint(value, 10, 32)
if err != nil { if err != nil {
logError.Println("Invalid fwmark", err) logError.Println("Invalid fwmark", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
device.net.mutex.Lock() device.net.mutex.Lock()
if fwmark > 0 || device.net.fwmark > 0 { device.net.fwmark = uint32(fwmark)
device.net.fwmark = uint32(fwmark)
err := SetMark(
device.net.conn,
device.net.fwmark,
)
if err != nil {
logError.Println("Failed to set fwmark:", err)
device.net.mutex.Unlock()
return &IPCError{Code: ipcErrorIO}
}
// TODO: Clear source address of all peers
}
device.net.mutex.Unlock() device.net.mutex.Unlock()
case "public_key": case "public_key":
// switch to peer configuration // switch to peer configuration
deviceConfig = false deviceConfig = false
case "replace_peers": case "replace_peers":
@ -218,7 +190,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
device.mutex.RLock() device.mutex.RLock()
if device.publicKey.Equals(pubKey) { if device.publicKey.Equals(pubKey) {
// create dummy instance // create dummy instance (not added to device)
peer = &Peer{} peer = &Peer{}
dummy = true dummy = true
@ -244,6 +216,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
} }
case "remove": case "remove":
// remove currently selected peer from device
if value != "true" { if value != "true" {
logError.Println("Failed to set remove, invalid value:", value) logError.Println("Failed to set remove, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
@ -256,6 +231,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
dummy = true dummy = true
case "preshared_key": case "preshared_key":
// update PSK
peer.mutex.Lock() peer.mutex.Lock()
err := peer.handshake.presharedKey.FromHex(value) err := peer.handshake.presharedKey.FromHex(value)
peer.mutex.Unlock() peer.mutex.Unlock()
@ -265,14 +243,17 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
} }
case "endpoint": case "endpoint":
addr, err := parseEndpoint(value)
// set endpoint destination and reset handshake timer
peer.mutex.Lock()
err := peer.endpoint.value.Set(value)
peer.endpoint.set = (err == nil)
peer.mutex.Unlock()
if err != nil { if err != nil {
logError.Println("Failed to set endpoint:", value) logError.Println("Failed to set endpoint:", value)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
peer.mutex.Lock()
peer.endpoint = addr
peer.mutex.Unlock()
signalSend(peer.signal.handshakeReset) signalSend(peer.signal.handshakeReset)
case "persistent_keepalive_interval": case "persistent_keepalive_interval":