Added source verification

This commit is contained in:
Mathias Hall-Andersen 2017-07-08 09:23:10 +02:00
parent ed31e75739
commit 5c1ccbddf0
5 changed files with 115 additions and 44 deletions

View file

@ -61,8 +61,8 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
if peer.endpoint != nil { if peer.endpoint != nil {
send("endpoint=" + peer.endpoint.String()) send("endpoint=" + peer.endpoint.String())
} }
send(fmt.Sprintf("tx_bytes=%d", peer.tx_bytes)) send(fmt.Sprintf("tx_bytes=%d", peer.txBytes))
send(fmt.Sprintf("rx_bytes=%d", peer.rx_bytes)) send(fmt.Sprintf("rx_bytes=%d", peer.rxBytes))
send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval)) send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
for _, ip := range device.routingTable.AllowedIPs(peer) { for _, ip := range device.routingTable.AllowedIPs(peer) {
send("allowed_ip=" + ip.String()) send("allowed_ip=" + ip.String())
@ -73,7 +73,6 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
// send lines // send lines
for _, line := range lines { for _, line := range lines {
device.log.Debug.Println("Response:", line)
_, err := socket.WriteString(line + "\n") _, err := socket.WriteString(line + "\n")
if err != nil { if err != nil {
return err return err

View file

@ -31,10 +31,16 @@ type Device struct {
signal struct { signal struct {
stop chan struct{} stop chan struct{}
} }
congestionState int32 // used as an atomic bool
peers map[NoisePublicKey]*Peer peers map[NoisePublicKey]*Peer
mac MACStateDevice mac MACStateDevice
} }
const (
CongestionStateUnderLoad = iota
CongestionStateOkay
)
func (device *Device) SetPrivateKey(sk NoisePrivateKey) { func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
device.mutex.Lock() device.mutex.Lock()
defer device.mutex.Unlock() defer device.mutex.Unlock()
@ -93,6 +99,7 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
go device.RoutineDecryption() go device.RoutineDecryption()
go device.RoutineHandshake() go device.RoutineHandshake()
} }
go device.RoutineBusyMonitor()
go device.RoutineReadFromTUN(tun) go device.RoutineReadFromTUN(tun)
go device.RoutineReceiveIncomming() go device.RoutineReceiveIncomming()
go device.RoutineWriteToTUN(tun) go device.RoutineWriteToTUN(tun)

View file

@ -17,8 +17,8 @@ type Peer struct {
keyPairs KeyPairs keyPairs KeyPairs
handshake Handshake handshake Handshake
device *Device device *Device
tx_bytes uint64 txBytes uint64
rx_bytes uint64 rxBytes uint64
time struct { time struct {
lastSend time.Time // last send message lastSend time.Time // last send message
lastHandshake time.Time // last completed handshake lastHandshake time.Time // last completed handshake

View file

@ -72,12 +72,48 @@ func addToHandshakeQueue(
} }
} }
/* Routine determining the busy state of the interface
*
* TODO: prehaps nicer to do this in response to events
* TODO: more well reasoned definition of "busy"
*/
func (device *Device) RoutineBusyMonitor() {
samples := 0
interval := time.Second
for timer := time.NewTimer(interval); ; {
select {
case <-device.signal.stop:
return
case <-timer.C:
}
// compute busy heuristic
if len(device.queue.handshake) > QueueHandshakeBusySize {
samples += 1
} else if samples > 0 {
samples -= 1
}
samples %= 30
busy := samples > 5
// update busy state
if busy {
atomic.StoreInt32(&device.congestionState, CongestionStateUnderLoad)
} else {
atomic.StoreInt32(&device.congestionState, CongestionStateOkay)
}
timer.Reset(interval)
}
}
func (device *Device) RoutineReceiveIncomming() { func (device *Device) RoutineReceiveIncomming() {
debugLog := device.log.Debug logDebug := device.log.Debug
debugLog.Println("Routine, receive incomming, started") logDebug.Println("Routine, receive incomming, started")
errorLog := device.log.Error
var buffer []byte var buffer []byte
@ -122,33 +158,6 @@ func (device *Device) RoutineReceiveIncomming() {
case MessageInitiationType, MessageResponseType: case MessageInitiationType, MessageResponseType:
// verify mac1
if !device.mac.CheckMAC1(packet) {
debugLog.Println("Received packet with invalid mac1")
return
}
// check if busy, TODO: refine definition of "busy"
busy := len(device.queue.handshake) > QueueHandshakeBusySize
if busy && !device.mac.CheckMAC2(packet, raddr) {
sender := binary.LittleEndian.Uint32(packet[4:8]) // "sender" always follows "type"
reply, err := device.CreateMessageCookieReply(packet, sender, raddr)
if err != nil {
errorLog.Println("Failed to create cookie reply:", err)
return
}
writer := bytes.NewBuffer(packet[:0])
binary.Write(writer, binary.LittleEndian, reply)
packet = writer.Bytes()
_, err = device.net.conn.WriteToUDP(packet, raddr)
if err != nil {
debugLog.Println("Failed to send cookie reply:", err)
}
return
}
// add to handshake queue // add to handshake queue
addToHandshakeQueue( addToHandshakeQueue(
@ -173,7 +182,7 @@ func (device *Device) RoutineReceiveIncomming() {
reader := bytes.NewReader(packet) reader := bytes.NewReader(packet)
err := binary.Read(reader, binary.LittleEndian, &reply) err := binary.Read(reader, binary.LittleEndian, &reply)
if err != nil { if err != nil {
debugLog.Println("Failed to decode cookie reply") logDebug.Println("Failed to decode cookie reply")
return return
} }
device.ConsumeMessageCookieReply(&reply) device.ConsumeMessageCookieReply(&reply)
@ -218,7 +227,7 @@ func (device *Device) RoutineReceiveIncomming() {
default: default:
// unknown message type // unknown message type
debugLog.Println("Got unknown message from:", raddr) logDebug.Println("Got unknown message from:", raddr)
} }
}() }()
} }
@ -285,6 +294,38 @@ func (device *Device) RoutineHandshake() {
func() { func() {
// verify mac1
if !device.mac.CheckMAC1(elem.packet) {
logDebug.Println("Received packet with invalid mac1")
return
}
// verify mac2
busy := atomic.LoadInt32(&device.congestionState) == CongestionStateUnderLoad
if busy && !device.mac.CheckMAC2(elem.packet, elem.source) {
sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
reply, err := device.CreateMessageCookieReply(elem.packet, sender, elem.source)
if err != nil {
logError.Println("Failed to create cookie reply:", err)
return
}
writer := bytes.NewBuffer(elem.packet[:0])
binary.Write(writer, binary.LittleEndian, reply)
elem.packet = writer.Bytes()
_, err = device.net.conn.WriteToUDP(elem.packet, elem.source)
if err != nil {
logDebug.Println("Failed to send cookie reply:", err)
}
return
}
// ratelimit
// handle messages
switch elem.msgType { switch elem.msgType {
case MessageInitiationType: case MessageInitiationType:
@ -321,12 +362,12 @@ func (device *Device) RoutineHandshake() {
logError.Println("Failed to create response message:", err) logError.Println("Failed to create response message:", err)
return return
} }
outElem := device.NewOutboundElement() outElem := device.NewOutboundElement()
writer := bytes.NewBuffer(outElem.data[:0]) writer := bytes.NewBuffer(outElem.data[:0])
binary.Write(writer, binary.LittleEndian, response) binary.Write(writer, binary.LittleEndian, response)
elem.packet = writer.Bytes() elem.packet = writer.Bytes()
peer.mac.AddMacs(elem.packet) peer.mac.AddMacs(elem.packet)
device.log.Debug.Println(elem.packet)
addToOutboundQueue(peer.queue.outbound, outElem) addToOutboundQueue(peer.queue.outbound, outElem)
case MessageResponseType: case MessageResponseType:
@ -388,7 +429,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
} }
elem.mutex.Lock() elem.mutex.Lock()
// process IP packet // process packet
func() { func() {
if elem.IsDropped() { if elem.IsDropped() {
@ -407,30 +448,54 @@ func (peer *Peer) RoutineSequentialReceiver() {
return return
} }
// strip padding // verify source and strip padding
switch elem.packet[0] >> 4 { switch elem.packet[0] >> 4 {
case IPv4version: case IPv4version:
// strip padding
if len(elem.packet) < IPv4headerSize { if len(elem.packet) < IPv4headerSize {
return return
} }
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field) length := binary.BigEndian.Uint16(field)
elem.packet = elem.packet[:length] elem.packet = elem.packet[:length]
// verify IPv4 source
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
if device.routingTable.LookupIPv4(dst) != peer {
return
}
case IPv6version: case IPv6version:
// strip padding
if len(elem.packet) < IPv6headerSize { if len(elem.packet) < IPv6headerSize {
return return
} }
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field) length := binary.BigEndian.Uint16(field)
length += IPv6headerSize length += IPv6headerSize
elem.packet = elem.packet[:length] elem.packet = elem.packet[:length]
// verify IPv6 source
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
if device.routingTable.LookupIPv6(dst) != peer {
return
}
default: default:
device.log.Debug.Println("Receieved packet with unknown IP version") device.log.Debug.Println("Receieved packet with unknown IP version")
return return
} }
atomic.AddUint64(&peer.rxBytes, uint64(len(elem.packet)))
addToInboundQueue(device.queue.inbound, elem) addToInboundQueue(device.queue.inbound, elem)
}() }()
} }

View file

@ -329,7 +329,7 @@ func (peer *Peer) RoutineSequentialSender() {
if err != nil { if err != nil {
return return
} }
atomic.AddUint64(&peer.tx_bytes, uint64(len(work.packet))) atomic.AddUint64(&peer.txBytes, uint64(len(work.packet)))
// shift keep-alive timer // shift keep-alive timer