wg: ipc: read from socket incrementally

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2017-01-10 04:50:42 +01:00
parent e92e0dca14
commit 12904a1095
2 changed files with 50 additions and 44 deletions

View file

@ -33,7 +33,7 @@ endif
CFLAGS ?= -O3 CFLAGS ?= -O3
CFLAGS += -std=gnu11 CFLAGS += -std=gnu11
CFLAGS += -pedantic -Wall -Wextra CFLAGS += -Wall -Wextra
CFLAGS += -MMD -MP CFLAGS += -MMD -MP
CFLAGS += -DRUNSTATEDIR="\"$(RUNSTATEDIR)\"" CFLAGS += -DRUNSTATEDIR="\"$(RUNSTATEDIR)\""
LDLIBS += -lresolv LDLIBS += -lresolv

View file

@ -18,7 +18,6 @@
#include <unistd.h> #include <unistd.h>
#include <time.h> #include <time.h>
#include <dirent.h> #include <dirent.h>
#include <poll.h>
#include <signal.h> #include <signal.h>
#include <sys/socket.h> #include <sys/socket.h>
#include <sys/types.h> #include <sys/types.h>
@ -41,7 +40,7 @@ struct inflatable_buffer {
size_t pos; size_t pos;
}; };
#define max(a, b) (a > b ? a : b) #define max(a, b) ((a) > (b) ? (a) : (b))
static int add_next_to_inflatable_buffer(struct inflatable_buffer *buffer) static int add_next_to_inflatable_buffer(struct inflatable_buffer *buffer)
{ {
@ -190,68 +189,75 @@ out:
return (int)ret; return (int)ret;
} }
#define READ_BYTES(bytes) ({ \
void *__p; \
size_t __bytes = (bytes); \
if (bytes_left < __bytes) { \
offset = p - buffer; \
bytes_left += buffer_size; \
buffer_size *= 2; \
ret = -ENOMEM; \
p = realloc(buffer, buffer_size); \
if (!p) \
goto out; \
buffer = p; \
p += offset; \
} \
bytes_left -= __bytes; \
ret = read(fd, p, __bytes); \
if (ret < 0) \
goto out; \
if ((size_t)ret != __bytes) { \
ret = -EBADMSG; \
goto out; \
} \
__p = p; \
p += __bytes; \
__p; \
})
static int userspace_get_device(struct wgdevice **dev, const char *interface) static int userspace_get_device(struct wgdevice **dev, const char *interface)
{ {
struct pollfd pollfd = { .events = POLLIN }; unsigned int len = 0, i;
int len; size_t buffer_size, bytes_left;
char byte = 0;
size_t i;
struct wgpeer *peer;
ssize_t ret; ssize_t ret;
ptrdiff_t offset;
uint8_t *buffer = NULL, *p, byte = 0;
int fd = userspace_interface_fd(interface); int fd = userspace_interface_fd(interface);
if (fd < 0) if (fd < 0)
return fd; return fd;
*dev = NULL;
ret = write(fd, &byte, sizeof(byte)); ret = write(fd, &byte, sizeof(byte));
if (ret < 0) if (ret < 0)
goto out; goto out;
if (ret != sizeof(byte)) {
pollfd.fd = fd; ret = -EBADMSG;
if (poll(&pollfd, 1, -1) < 0)
goto out;
ret = -ECONNABORTED;
if (!(pollfd.revents & POLLIN))
goto out;
ret = ioctl(fd, FIONREAD, &len);
if (ret < 0) {
ret = -errno;
goto out; goto out;
} }
ret = -EBADMSG;
if ((size_t)len < sizeof(struct wgdevice))
goto out;
ioctl(fd, FIONREAD, &len);
bytes_left = buffer_size = max(len, sizeof(struct wgdevice) + sizeof(struct wgpeer) + sizeof(struct wgipmask));
p = buffer = malloc(buffer_size);
ret = -ENOMEM; ret = -ENOMEM;
*dev = malloc(len); if (!buffer)
if (!*dev)
goto out; goto out;
ret = read(fd, *dev, len); len = ((struct wgdevice *)READ_BYTES(sizeof(struct wgdevice)))->num_peers;
if (ret < 0) for (i = 0; i < len; ++i)
goto out; READ_BYTES(sizeof(struct wgipmask) * ((struct wgpeer *)READ_BYTES(sizeof(struct wgpeer)))->num_ipmasks);
if (ret != len) {
ret = -EBADMSG;
goto out;
}
ret = -EBADMSG;
for_each_wgpeer(*dev, peer, i) {
if ((uint8_t *)peer + sizeof(struct wgpeer) > (uint8_t *)*dev + len)
goto out;
if ((uint8_t *)peer + sizeof(struct wgpeer) + sizeof(struct wgipmask) * peer->num_ipmasks > (uint8_t *)*dev + len)
goto out;
}
ret = 0; ret = 0;
out: out:
if (*dev && ret) { if (buffer && ret) {
free(*dev); free(buffer);
*dev = NULL; buffer = NULL;
} }
*dev = (struct wgdevice *)buffer;
close(fd); close(fd);
errno = -ret; errno = -ret;
return ret; return ret;
} }
#undef READ_BYTES
#ifdef __linux__ #ifdef __linux__
static int parse_linkinfo(const struct nlattr *attr, void *data) static int parse_linkinfo(const struct nlattr *attr, void *data)