Moved endpoint into interface and simplified peer

This commit is contained in:
Mathias Hall-Andersen 2017-11-18 23:34:02 +01:00
parent fa399a91d5
commit d10126f883
6 changed files with 101 additions and 80 deletions

View file

@ -7,26 +7,28 @@ import (
"net" "net"
) )
type UDPBind interface { /* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
*/
type Bind interface {
SetMark(value uint32) error SetMark(value uint32) error
ReceiveIPv6(buff []byte, end *Endpoint) (int, error) ReceiveIPv6(buff []byte) (int, Endpoint, error)
ReceiveIPv4(buff []byte, end *Endpoint) (int, error) ReceiveIPv4(buff []byte) (int, Endpoint, error)
Send(buff []byte, end *Endpoint) error Send(buff []byte, end Endpoint) error
Close() error Close() error
} }
/* An Endpoint maintains the source/destination caching for a peer /* An Endpoint maintains the source/destination caching for a peer
* *
* dst : the remote address of a peer * dst : the remote address of a peer ("endpoint" in uapi terminology)
* src : the local address from which datagrams originate going to the peer * src : the local address from which datagrams originate going to the peer
*
*/ */
type UDPEndpoint interface { type Endpoint interface {
ClearSrc() // clears the source address ClearSrc() // clears the source address
ClearDst() // clears the destination address ClearDst() // clears the destination address
SrcToString() string // returns the local source address (ip:port) SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port) DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations DstToBytes() []byte // used for mac2 cookie calculations
SetDst(string) error // used for manually setting the endpoint (uapi)
DstIP() net.IP DstIP() net.IP
SrcIP() net.IP SrcIP() net.IP
} }
@ -107,7 +109,9 @@ func UpdateUDPListener(device *Device) error {
for _, peer := range device.peers { for _, peer := range device.peers {
peer.mutex.Lock() peer.mutex.Lock()
peer.endpoint.value.ClearSrc() if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.mutex.Unlock() peer.mutex.Unlock()
} }

View file

@ -21,22 +21,24 @@ import (
* See e.g. https://github.com/golang/go/issues/17930 * See e.g. https://github.com/golang/go/issues/17930
* So this code is remains platform dependent. * So this code is remains platform dependent.
*/ */
type NativeEndpoint struct {
type Endpoint struct {
src unix.RawSockaddrInet6 src unix.RawSockaddrInet6
dst unix.RawSockaddrInet6 dst unix.RawSockaddrInet6
} }
type IPv4Source struct {
src unix.RawSockaddrInet4
Ifindex int32
}
type NativeBind struct { type NativeBind struct {
sock4 int sock4 int
sock6 int sock6 int
} }
var _ Endpoint = (*NativeEndpoint)(nil)
var _ Bind = NativeBind{}
type IPv4Source struct {
src unix.RawSockaddrInet4
Ifindex int32
}
func htons(val uint16) uint16 { func htons(val uint16) uint16 {
var out [unsafe.Sizeof(val)]byte var out [unsafe.Sizeof(val)]byte
binary.BigEndian.PutUint16(out[:], val) binary.BigEndian.PutUint16(out[:], val)
@ -48,7 +50,11 @@ func ntohs(val uint16) uint16 {
return binary.BigEndian.Uint16((*tmp)[:]) return binary.BigEndian.Uint16((*tmp)[:])
} }
func CreateUDPBind(port uint16) (UDPBind, uint16, error) { func NewEndpoint() Endpoint {
return &NativeEndpoint{}
}
func CreateUDPBind(port uint16) (Bind, uint16, error) {
var err error var err error
var bind NativeBind var bind NativeBind
@ -99,28 +105,33 @@ func (bind NativeBind) Close() error {
return err2 return err2
} }
func (bind NativeBind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) { func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
return receive6( var end NativeEndpoint
n, err := receive6(
bind.sock6, bind.sock6,
buff, buff,
end, &end,
) )
return n, &end, err
} }
func (bind NativeBind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) { func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
return receive4( var end NativeEndpoint
n, err := receive4(
bind.sock4, bind.sock4,
buff, buff,
end, &end,
) )
return n, &end, err
} }
func (bind NativeBind) Send(buff []byte, end *Endpoint) error { func (bind NativeBind) Send(buff []byte, end Endpoint) error {
switch end.dst.Family { nend := end.(*NativeEndpoint)
switch nend.dst.Family {
case unix.AF_INET6: case unix.AF_INET6:
return send6(bind.sock6, end, buff) return send6(bind.sock6, nend, buff)
case unix.AF_INET: case unix.AF_INET:
return send4(bind.sock4, end, buff) return send4(bind.sock4, nend, buff)
default: default:
return errors.New("Unknown address family of destination") return errors.New("Unknown address family of destination")
} }
@ -151,12 +162,12 @@ func sockaddrToString(addr unix.RawSockaddrInet6) string {
} }
} }
func (end *Endpoint) DstIP() net.IP { func rawAddrToIP(addr unix.RawSockaddrInet6) net.IP {
switch end.dst.Family { switch addr.Family {
case unix.AF_INET6: case unix.AF_INET6:
return end.dst.Addr[:] return addr.Addr[:]
case unix.AF_INET: case unix.AF_INET:
ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
return net.IPv4( return net.IPv4(
ptr.Addr[0], ptr.Addr[0],
ptr.Addr[1], ptr.Addr[1],
@ -168,25 +179,33 @@ func (end *Endpoint) DstIP() net.IP {
} }
} }
func (end *Endpoint) DstToBytes() []byte { func (end *NativeEndpoint) SrcIP() net.IP {
return rawAddrToIP(end.src)
}
func (end *NativeEndpoint) DstIP() net.IP {
return rawAddrToIP(end.dst)
}
func (end *NativeEndpoint) DstToBytes() []byte {
ptr := unsafe.Pointer(&end.src) ptr := unsafe.Pointer(&end.src)
arr := (*[unix.SizeofSockaddrInet6]byte)(ptr) arr := (*[unix.SizeofSockaddrInet6]byte)(ptr)
return arr[:] return arr[:]
} }
func (end *Endpoint) SrcToString() string { func (end *NativeEndpoint) SrcToString() string {
return sockaddrToString(end.src) return sockaddrToString(end.src)
} }
func (end *Endpoint) DstToString() string { func (end *NativeEndpoint) DstToString() string {
return sockaddrToString(end.dst) return sockaddrToString(end.dst)
} }
func (end *Endpoint) ClearDst() { func (end *NativeEndpoint) ClearDst() {
end.dst = unix.RawSockaddrInet6{} end.dst = unix.RawSockaddrInet6{}
} }
func (end *Endpoint) ClearSrc() { func (end *NativeEndpoint) ClearSrc() {
end.src = unix.RawSockaddrInet6{} end.src = unix.RawSockaddrInet6{}
} }
@ -306,7 +325,7 @@ func create6(port uint16) (int, uint16, error) {
return fd, uint16(addr.Port), err return fd, uint16(addr.Port), err
} }
func (end *Endpoint) SetDst(s string) error { func (end *NativeEndpoint) SetDst(s string) error {
addr, err := parseEndpoint(s) addr, err := parseEndpoint(s)
if err != nil { if err != nil {
return err return err
@ -342,7 +361,7 @@ func (end *Endpoint) SetDst(s string) error {
return errors.New("Failed to recognize IP address format") return errors.New("Failed to recognize IP address format")
} }
func send6(sock int, end *Endpoint, buff []byte) error { func send6(sock int, end *NativeEndpoint, buff []byte) error {
// construct message header // construct message header
@ -404,7 +423,7 @@ func send6(sock int, end *Endpoint, buff []byte) error {
return errno return errno
} }
func send4(sock int, end *Endpoint, buff []byte) error { func send4(sock int, end *NativeEndpoint, buff []byte) error {
// construct message header // construct message header
@ -470,7 +489,7 @@ func send4(sock int, end *Endpoint, buff []byte) error {
return errno return errno
} }
func receive4(sock int, buff []byte, end *Endpoint) (int, error) { func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header // contruct message header
@ -518,7 +537,7 @@ func receive4(sock int, buff []byte, end *Endpoint) (int, error) {
return int(size), nil return int(size), nil
} }
func receive6(sock int, buff []byte, end *Endpoint) (int, error) { func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header // contruct message header

View file

@ -22,7 +22,7 @@ type Device struct {
} }
net struct { net struct {
mutex sync.RWMutex mutex sync.RWMutex
bind UDPBind // bind interface bind Bind // bind interface
port uint16 // listening port port uint16 // listening port
fwmark uint32 // mark value (0 = disabled) fwmark uint32 // mark value (0 = disabled)
} }

View file

@ -15,10 +15,7 @@ type Peer struct {
keyPairs KeyPairs keyPairs KeyPairs
handshake Handshake handshake Handshake
device *Device device *Device
endpoint struct { endpoint Endpoint
set bool // has a known endpoint been discovered
value Endpoint // source / destination cache
}
stats struct { stats struct {
txBytes uint64 // bytes send to peer (endpoint) txBytes uint64 // bytes send to peer (endpoint)
rxBytes uint64 // bytes received from peer rxBytes uint64 // bytes received from peer
@ -110,9 +107,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// reset endpoint // reset endpoint
peer.endpoint.set = false peer.endpoint = nil
peer.endpoint.value.ClearDst()
peer.endpoint.value.ClearSrc()
// prepare queuing // prepare queuing
@ -143,16 +138,16 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
defer peer.device.net.mutex.RUnlock() defer peer.device.net.mutex.RUnlock()
peer.mutex.RLock() peer.mutex.RLock()
defer peer.mutex.RUnlock() defer peer.mutex.RUnlock()
if !peer.endpoint.set { if peer.endpoint == nil {
return errors.New("No known endpoint for peer") return errors.New("No known endpoint for peer")
} }
return peer.device.net.bind.Send(buffer, &peer.endpoint.value) return peer.device.net.bind.Send(buffer, peer.endpoint)
} }
/* Returns a short string identification for logging /* Returns a short string identification for logging
*/ */
func (peer *Peer) String() string { func (peer *Peer) String() string {
if !peer.endpoint.set { if peer.endpoint == nil {
return fmt.Sprintf( return fmt.Sprintf(
"peer(%d unknown %s)", "peer(%d unknown %s)",
peer.id, peer.id,
@ -162,7 +157,7 @@ func (peer *Peer) String() string {
return fmt.Sprintf( return fmt.Sprintf(
"peer(%d %s %s)", "peer(%d %s %s)",
peer.id, peer.id,
peer.endpoint.value.DstToString(), peer.endpoint.DstToString(),
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
) )
} }

View file

@ -93,7 +93,7 @@ func (device *Device) addToHandshakeQueue(
} }
} }
func (device *Device) RoutineReceiveIncomming(IP int, bind UDPBind) { func (device *Device) RoutineReceiveIncomming(IP int, bind Bind) {
logDebug := device.log.Debug logDebug := device.log.Debug
logDebug.Println("Routine, receive incomming, IP version:", IP) logDebug.Println("Routine, receive incomming, IP version:", IP)
@ -104,20 +104,21 @@ func (device *Device) RoutineReceiveIncomming(IP int, bind UDPBind) {
buffer := device.GetMessageBuffer() buffer := device.GetMessageBuffer()
var size int var (
var err error err error
size int
endpoint Endpoint
)
for { for {
// read next datagram // read next datagram
var endpoint Endpoint
switch IP { switch IP {
case ipv4.Version: case ipv4.Version:
size, err = bind.ReceiveIPv4(buffer[:], &endpoint) size, endpoint, err = bind.ReceiveIPv4(buffer[:])
case ipv6.Version: case ipv6.Version:
size, err = bind.ReceiveIPv6(buffer[:], &endpoint) size, endpoint, err = bind.ReceiveIPv6(buffer[:])
default: default:
return return
} }
@ -339,10 +340,7 @@ func (device *Device) RoutineHandshake() {
writer := bytes.NewBuffer(temp[:0]) writer := bytes.NewBuffer(temp[:0])
binary.Write(writer, binary.LittleEndian, reply) binary.Write(writer, binary.LittleEndian, reply)
device.net.bind.Send( device.net.bind.Send(writer.Bytes(), elem.endpoint)
writer.Bytes(),
&elem.endpoint,
)
if err != nil { if err != nil {
logDebug.Println("Failed to send cookie reply:", err) logDebug.Println("Failed to send cookie reply:", err)
} }
@ -395,8 +393,7 @@ func (device *Device) RoutineHandshake() {
// update endpoint // update endpoint
peer.mutex.Lock() peer.mutex.Lock()
peer.endpoint.set = true peer.endpoint = elem.endpoint
peer.endpoint.value = elem.endpoint
peer.mutex.Unlock() peer.mutex.Unlock()
// create response // create response
@ -452,8 +449,7 @@ func (device *Device) RoutineHandshake() {
// update endpoint // update endpoint
peer.mutex.Lock() peer.mutex.Lock()
peer.endpoint.set = true peer.endpoint = elem.endpoint
peer.endpoint.value = elem.endpoint
peer.mutex.Unlock() peer.mutex.Unlock()
logDebug.Println("Received handshake initation from", peer) logDebug.Println("Received handshake initation from", peer)
@ -527,8 +523,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
// update endpoint // update endpoint
peer.mutex.Lock() peer.mutex.Lock()
peer.endpoint.set = true peer.endpoint = elem.endpoint
peer.endpoint.value = elem.endpoint
peer.mutex.Unlock() peer.mutex.Unlock()
// check for keep-alive // check for keep-alive

View file

@ -53,8 +53,8 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
defer peer.mutex.RUnlock() defer peer.mutex.RUnlock()
send("public_key=" + peer.handshake.remoteStatic.ToHex()) send("public_key=" + peer.handshake.remoteStatic.ToHex())
send("preshared_key=" + peer.handshake.presharedKey.ToHex()) send("preshared_key=" + peer.handshake.presharedKey.ToHex())
if peer.endpoint.set { if peer.endpoint != nil {
send("endpoint=" + peer.endpoint.value.DstToString()) send("endpoint=" + peer.endpoint.DstToString())
} }
nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano) nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
@ -255,17 +255,25 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "endpoint": case "endpoint":
// set endpoint destination and reset handshake timer // set endpoint destination
err := func() error {
peer.mutex.Lock() peer.mutex.Lock()
err := peer.endpoint.value.SetDst(value) defer peer.mutex.Unlock()
peer.endpoint.set = (err == nil)
peer.mutex.Unlock() endpoint := NewEndpoint()
if err := endpoint.SetDst(value); err != nil {
return err
}
peer.endpoint = endpoint
signalSend(peer.signal.handshakeReset)
return nil
}()
if err != nil { if err != nil {
logError.Println("Failed to set endpoint:", value) logError.Println("Failed to set endpoint:", value)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
signalSend(peer.signal.handshakeReset)
case "persistent_keepalive_interval": case "persistent_keepalive_interval":