wg: store tail pointer to make coalescing peers fast

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2017-10-10 17:17:43 +02:00
parent e13b1e719b
commit 4e0e99c74d
2 changed files with 70 additions and 74 deletions

View file

@ -46,7 +46,7 @@ struct wgpeer {
uint64_t rx_bytes, tx_bytes; uint64_t rx_bytes, tx_bytes;
uint16_t persistent_keepalive_interval; uint16_t persistent_keepalive_interval;
struct wgallowedip *first_allowedip; struct wgallowedip *first_allowedip, *last_allowedip;
struct wgpeer *next_peer; struct wgpeer *next_peer;
}; };
@ -73,7 +73,7 @@ struct wgdevice {
uint32_t fwmark; uint32_t fwmark;
uint16_t listen_port; uint16_t listen_port;
struct wgpeer *first_peer; struct wgpeer *first_peer, *last_peer;
}; };
#define for_each_wgpeer(__dev, __peer) for ((__peer) = (__dev)->first_peer; (__peer); (__peer) = (__peer)->next_peer) #define for_each_wgpeer(__dev, __peer) for ((__peer) = (__dev)->first_peer; (__peer); (__peer) = (__peer)->next_peer)

140
src/ipc.c
View file

@ -678,30 +678,24 @@ out:
return ret; return ret;
} }
struct get_device_ctx {
struct wgdevice *device;
struct wgpeer *peer;
struct wgallowedip *allowedip;
};
static int parse_allowedip(const struct nlattr *attr, void *data) static int parse_allowedip(const struct nlattr *attr, void *data)
{ {
struct get_device_ctx *ctx = data; struct wgallowedip *allowedip = data;
switch (mnl_attr_get_type(attr)) { switch (mnl_attr_get_type(attr)) {
case WGALLOWEDIP_A_FAMILY: case WGALLOWEDIP_A_FAMILY:
if (!mnl_attr_validate(attr, MNL_TYPE_U16)) if (!mnl_attr_validate(attr, MNL_TYPE_U16))
ctx->allowedip->family = mnl_attr_get_u16(attr); allowedip->family = mnl_attr_get_u16(attr);
break; break;
case WGALLOWEDIP_A_IPADDR: case WGALLOWEDIP_A_IPADDR:
if (mnl_attr_get_payload_len(attr) == sizeof(ctx->allowedip->ip4)) if (mnl_attr_get_payload_len(attr) == sizeof(allowedip->ip4))
memcpy(&ctx->allowedip->ip4, mnl_attr_get_payload(attr), sizeof(ctx->allowedip->ip4)); memcpy(&allowedip->ip4, mnl_attr_get_payload(attr), sizeof(allowedip->ip4));
else if (mnl_attr_get_payload_len(attr) == sizeof(ctx->allowedip->ip6)) else if (mnl_attr_get_payload_len(attr) == sizeof(allowedip->ip6))
memcpy(&ctx->allowedip->ip6, mnl_attr_get_payload(attr), sizeof(ctx->allowedip->ip6)); memcpy(&allowedip->ip6, mnl_attr_get_payload(attr), sizeof(allowedip->ip6));
break; break;
case WGALLOWEDIP_A_CIDR_MASK: case WGALLOWEDIP_A_CIDR_MASK:
if (!mnl_attr_validate(attr, MNL_TYPE_U8)) if (!mnl_attr_validate(attr, MNL_TYPE_U8))
ctx->allowedip->cidr = mnl_attr_get_u8(attr); allowedip->cidr = mnl_attr_get_u8(attr);
break; break;
default: default:
warn_unrecognized("netlink"); warn_unrecognized("netlink");
@ -712,68 +706,70 @@ static int parse_allowedip(const struct nlattr *attr, void *data)
static int parse_allowedips(const struct nlattr *attr, void *data) static int parse_allowedips(const struct nlattr *attr, void *data)
{ {
struct get_device_ctx *ctx = data; struct wgpeer *peer = data;
struct wgallowedip *new_allowedip = calloc(1, sizeof(struct wgallowedip)); struct wgallowedip *new_allowedip = calloc(1, sizeof(struct wgallowedip));
int ret; int ret;
if (!new_allowedip) { if (!new_allowedip) {
perror("calloc"); perror("calloc");
return MNL_CB_ERROR; return MNL_CB_ERROR;
} }
if (ctx->allowedip) if (!peer->first_allowedip)
ctx->allowedip->next_allowedip = new_allowedip; peer->first_allowedip = peer->last_allowedip = new_allowedip;
else else {
ctx->peer->first_allowedip = new_allowedip; peer->last_allowedip->next_allowedip = new_allowedip;
ctx->allowedip = new_allowedip; peer->last_allowedip = new_allowedip;
ret = mnl_attr_parse_nested(attr, parse_allowedip, ctx); }
ret = mnl_attr_parse_nested(attr, parse_allowedip, new_allowedip);
if (!ret) if (!ret)
return ret; return ret;
if (!((ctx->allowedip->family == AF_INET && ctx->allowedip->cidr <= 32) || (ctx->allowedip->family == AF_INET6 && ctx->allowedip->cidr <= 128))) if (!((new_allowedip->family == AF_INET && new_allowedip->cidr <= 32) || (new_allowedip->family == AF_INET6 && new_allowedip->cidr <= 128)))
return MNL_CB_ERROR; return MNL_CB_ERROR;
return MNL_CB_OK; return MNL_CB_OK;
} }
static int parse_peer(const struct nlattr *attr, void *data) static int parse_peer(const struct nlattr *attr, void *data)
{ {
struct get_device_ctx *ctx = data; struct wgpeer *peer = data;
switch (mnl_attr_get_type(attr)) { switch (mnl_attr_get_type(attr)) {
case WGPEER_A_PUBLIC_KEY: case WGPEER_A_PUBLIC_KEY:
if (mnl_attr_get_payload_len(attr) == sizeof(ctx->peer->public_key)) if (mnl_attr_get_payload_len(attr) == sizeof(peer->public_key))
memcpy(ctx->peer->public_key, mnl_attr_get_payload(attr), sizeof(ctx->peer->public_key)); memcpy(peer->public_key, mnl_attr_get_payload(attr), sizeof(peer->public_key));
break; break;
case WGPEER_A_PRESHARED_KEY: case WGPEER_A_PRESHARED_KEY:
if (mnl_attr_get_payload_len(attr) == sizeof(ctx->peer->preshared_key)) if (mnl_attr_get_payload_len(attr) == sizeof(peer->preshared_key))
memcpy(ctx->peer->preshared_key, mnl_attr_get_payload(attr), sizeof(ctx->peer->preshared_key)); memcpy(peer->preshared_key, mnl_attr_get_payload(attr), sizeof(peer->preshared_key));
break; break;
case WGPEER_A_ENDPOINT: { case WGPEER_A_ENDPOINT: {
struct sockaddr *addr; struct sockaddr *addr;
if (mnl_attr_get_payload_len(attr) < sizeof(*addr)) if (mnl_attr_get_payload_len(attr) < sizeof(*addr))
break; break;
addr = mnl_attr_get_payload(attr); addr = mnl_attr_get_payload(attr);
if (addr->sa_family == AF_INET && mnl_attr_get_payload_len(attr) == sizeof(ctx->peer->endpoint.addr4)) if (addr->sa_family == AF_INET && mnl_attr_get_payload_len(attr) == sizeof(peer->endpoint.addr4))
memcpy(&ctx->peer->endpoint.addr4, addr, sizeof(ctx->peer->endpoint.addr4)); memcpy(&peer->endpoint.addr4, addr, sizeof(peer->endpoint.addr4));
else if (addr->sa_family == AF_INET6 && mnl_attr_get_payload_len(attr) == sizeof(ctx->peer->endpoint.addr6)) else if (addr->sa_family == AF_INET6 && mnl_attr_get_payload_len(attr) == sizeof(peer->endpoint.addr6))
memcpy(&ctx->peer->endpoint.addr6, addr, sizeof(ctx->peer->endpoint.addr6)); memcpy(&peer->endpoint.addr6, addr, sizeof(peer->endpoint.addr6));
break; break;
} }
case WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL: case WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL:
if (!mnl_attr_validate(attr, MNL_TYPE_U16)) if (!mnl_attr_validate(attr, MNL_TYPE_U16))
ctx->peer->persistent_keepalive_interval = mnl_attr_get_u16(attr); peer->persistent_keepalive_interval = mnl_attr_get_u16(attr);
break; break;
case WGPEER_A_LAST_HANDSHAKE_TIME: case WGPEER_A_LAST_HANDSHAKE_TIME:
if (mnl_attr_get_payload_len(attr) == sizeof(ctx->peer->last_handshake_time)) if (mnl_attr_get_payload_len(attr) == sizeof(peer->last_handshake_time))
memcpy(&ctx->peer->last_handshake_time, mnl_attr_get_payload(attr), sizeof(ctx->peer->last_handshake_time)); memcpy(&peer->last_handshake_time, mnl_attr_get_payload(attr), sizeof(peer->last_handshake_time));
break; break;
case WGPEER_A_RX_BYTES: case WGPEER_A_RX_BYTES:
if (!mnl_attr_validate(attr, MNL_TYPE_U64)) if (!mnl_attr_validate(attr, MNL_TYPE_U64))
ctx->peer->rx_bytes = mnl_attr_get_u64(attr); peer->rx_bytes = mnl_attr_get_u64(attr);
break; break;
case WGPEER_A_TX_BYTES: case WGPEER_A_TX_BYTES:
if (!mnl_attr_validate(attr, MNL_TYPE_U64)) if (!mnl_attr_validate(attr, MNL_TYPE_U64))
ctx->peer->tx_bytes = mnl_attr_get_u64(attr); peer->tx_bytes = mnl_attr_get_u64(attr);
break; break;
case WGPEER_A_ALLOWEDIPS: case WGPEER_A_ALLOWEDIPS:
return mnl_attr_parse_nested(attr, parse_allowedips, ctx); return mnl_attr_parse_nested(attr, parse_allowedips, peer);
default: default:
warn_unrecognized("netlink"); warn_unrecognized("netlink");
} }
@ -783,58 +779,59 @@ static int parse_peer(const struct nlattr *attr, void *data)
static int parse_peers(const struct nlattr *attr, void *data) static int parse_peers(const struct nlattr *attr, void *data)
{ {
struct get_device_ctx *ctx = data; struct wgdevice *device = data;
struct wgpeer *new_peer = calloc(1, sizeof(struct wgpeer)); struct wgpeer *new_peer = calloc(1, sizeof(struct wgpeer));
int ret; int ret;
if (!new_peer) { if (!new_peer) {
perror("calloc"); perror("calloc");
return MNL_CB_ERROR; return MNL_CB_ERROR;
} }
if (ctx->peer) if (!device->first_peer)
ctx->peer->next_peer = new_peer; device->first_peer = device->last_peer = new_peer;
else else {
ctx->device->first_peer = new_peer; device->last_peer->next_peer = new_peer;
ctx->peer = new_peer; device->last_peer = new_peer;
ctx->allowedip = NULL; }
ret = mnl_attr_parse_nested(attr, parse_peer, ctx); ret = mnl_attr_parse_nested(attr, parse_peer, new_peer);
if (!ret) if (!ret)
return ret; return ret;
if (key_is_zero(ctx->peer->public_key)) if (key_is_zero(new_peer->public_key))
return MNL_CB_ERROR; return MNL_CB_ERROR;
return MNL_CB_OK; return MNL_CB_OK;
} }
static int parse_device(const struct nlattr *attr, void *data) static int parse_device(const struct nlattr *attr, void *data)
{ {
struct get_device_ctx *ctx = data; struct wgdevice *device = data;
switch (mnl_attr_get_type(attr)) { switch (mnl_attr_get_type(attr)) {
case WGDEVICE_A_IFINDEX: case WGDEVICE_A_IFINDEX:
if (!mnl_attr_validate(attr, MNL_TYPE_U32)) if (!mnl_attr_validate(attr, MNL_TYPE_U32))
ctx->device->ifindex = mnl_attr_get_u32(attr); device->ifindex = mnl_attr_get_u32(attr);
break; break;
case WGDEVICE_A_IFNAME: case WGDEVICE_A_IFNAME:
if (!mnl_attr_validate(attr, MNL_TYPE_STRING)) if (!mnl_attr_validate(attr, MNL_TYPE_STRING))
strncpy(ctx->device->name, mnl_attr_get_str(attr), sizeof(ctx->device->name) - 1); strncpy(device->name, mnl_attr_get_str(attr), sizeof(device->name) - 1);
break; break;
case WGDEVICE_A_PRIVATE_KEY: case WGDEVICE_A_PRIVATE_KEY:
if (mnl_attr_get_payload_len(attr) == sizeof(ctx->device->private_key)) if (mnl_attr_get_payload_len(attr) == sizeof(device->private_key))
memcpy(ctx->device->private_key, mnl_attr_get_payload(attr), sizeof(ctx->device->private_key)); memcpy(device->private_key, mnl_attr_get_payload(attr), sizeof(device->private_key));
break; break;
case WGDEVICE_A_PUBLIC_KEY: case WGDEVICE_A_PUBLIC_KEY:
if (mnl_attr_get_payload_len(attr) == sizeof(ctx->device->public_key)) if (mnl_attr_get_payload_len(attr) == sizeof(device->public_key))
memcpy(ctx->device->public_key, mnl_attr_get_payload(attr), sizeof(ctx->device->public_key)); memcpy(device->public_key, mnl_attr_get_payload(attr), sizeof(device->public_key));
break; break;
case WGDEVICE_A_LISTEN_PORT: case WGDEVICE_A_LISTEN_PORT:
if (!mnl_attr_validate(attr, MNL_TYPE_U16)) if (!mnl_attr_validate(attr, MNL_TYPE_U16))
ctx->device->listen_port = mnl_attr_get_u16(attr); device->listen_port = mnl_attr_get_u16(attr);
break; break;
case WGDEVICE_A_FWMARK: case WGDEVICE_A_FWMARK:
if (!mnl_attr_validate(attr, MNL_TYPE_U32)) if (!mnl_attr_validate(attr, MNL_TYPE_U32))
ctx->device->fwmark = mnl_attr_get_u32(attr); device->fwmark = mnl_attr_get_u32(attr);
break; break;
case WGDEVICE_A_PEERS: case WGDEVICE_A_PEERS:
return mnl_attr_parse_nested(attr, parse_peers, ctx); return mnl_attr_parse_nested(attr, parse_peers, device);
default: default:
warn_unrecognized("netlink"); warn_unrecognized("netlink");
} }
@ -849,42 +846,41 @@ static int read_device_cb(const struct nlmsghdr *nlh, void *data)
static void coalesce_peers(struct wgdevice *device) static void coalesce_peers(struct wgdevice *device)
{ {
struct wgallowedip *allowedip;
struct wgpeer *old_next_peer, *peer = device->first_peer; struct wgpeer *old_next_peer, *peer = device->first_peer;
while (peer && peer->next_peer) { while (peer && peer->next_peer) {
if (memcmp(peer->public_key, peer->next_peer->public_key, WG_KEY_LEN)) { if (memcmp(peer->public_key, peer->next_peer->public_key, WG_KEY_LEN)) {
peer = peer->next_peer; peer = peer->next_peer;
continue; continue;
} }
/* TODO: It would be more efficient to store the tail, rather than having to seek to the end each time. */ if (!peer->first_allowedip) {
for (allowedip = peer->first_allowedip; allowedip && allowedip->next_allowedip; allowedip = allowedip->next_allowedip);
if (!allowedip)
peer->first_allowedip = peer->next_peer->first_allowedip; peer->first_allowedip = peer->next_peer->first_allowedip;
else peer->last_allowedip = peer->next_peer->last_allowedip;
allowedip->next_allowedip = peer->next_peer->first_allowedip; } else {
peer->last_allowedip->next_allowedip = peer->next_peer->first_allowedip;
peer->last_allowedip = peer->next_peer->last_allowedip;
}
old_next_peer = peer->next_peer; old_next_peer = peer->next_peer;
peer->next_peer = old_next_peer->next_peer; peer->next_peer = old_next_peer->next_peer;
free(old_next_peer); free(old_next_peer);
} }
} }
static int kernel_get_device(struct wgdevice **dev, const char *interface) static int kernel_get_device(struct wgdevice **device, const char *interface)
{ {
int ret = 0; int ret = 0;
struct nlmsghdr *nlh; struct nlmsghdr *nlh;
struct mnlg_socket *nlg; struct mnlg_socket *nlg;
struct get_device_ctx ctx = { 0 };
try_again: try_again:
*dev = ctx.device = calloc(1, sizeof(struct wgdevice)); *device = calloc(1, sizeof(struct wgdevice));
if (!*dev) if (!*device)
return -errno; return -errno;
nlg = mnlg_socket_open(WG_GENL_NAME, WG_GENL_VERSION); nlg = mnlg_socket_open(WG_GENL_NAME, WG_GENL_VERSION);
if (!nlg) { if (!nlg) {
free_wgdevice(*dev); free_wgdevice(*device);
*dev = NULL; *device = NULL;
return -errno; return -errno;
} }
@ -895,20 +891,20 @@ try_again:
goto out; goto out;
} }
errno = 0; errno = 0;
if (mnlg_socket_recv_run(nlg, read_device_cb, &ctx) < 0) { if (mnlg_socket_recv_run(nlg, read_device_cb, *device) < 0) {
ret = errno ? -errno : -EINVAL; ret = errno ? -errno : -EINVAL;
goto out; goto out;
} }
coalesce_peers(*dev); coalesce_peers(*device);
out: out:
if (nlg) if (nlg)
mnlg_socket_close(nlg); mnlg_socket_close(nlg);
if (ret) { if (ret) {
free_wgdevice(*dev); free_wgdevice(*device);
if (ret == -EINTR) if (ret == -EINTR)
goto try_again; goto try_again;
*dev = NULL; *device = NULL;
} }
errno = -ret; errno = -ret;
return ret; return ret;