pex: add utility function to get the sockets based on type / address family
[project/unetd.git] / host.c
diff --git a/host.c b/host.c
index bef8863292cfca764343a48c0f6dc7e0006abe26..9fea39c08a10b2a2b0e7744035e02cb0af1280c0 100644 (file)
--- a/host.c
+++ b/host.c
@@ -3,9 +3,11 @@
  * Copyright (C) 2022 Felix Fietkau <nbd@nbd.name>
  */
 #include <libubox/avl-cmp.h>
+#include <libubox/blobmsg_json.h>
 #include "unetd.h"
 
 static LIST_HEAD(old_hosts);
+static struct blob_buf b;
 
 static int avl_key_cmp(const void *k1, const void *k2, void *ptr)
 {
@@ -83,37 +85,43 @@ network_host_add_group(struct network *net, struct network_host *host,
        group->members[group->n_members - 1] = host;
 }
 
+enum {
+       NETWORK_HOST_KEY,
+       NETWORK_HOST_GROUPS,
+       NETWORK_HOST_IPADDR,
+       NETWORK_HOST_SUBNET,
+       NETWORK_HOST_PORT,
+       NETWORK_HOST_PEX_PORT,
+       NETWORK_HOST_ENDPOINT,
+       NETWORK_HOST_GATEWAY,
+       __NETWORK_HOST_MAX
+};
+
+static const struct blobmsg_policy host_policy[__NETWORK_HOST_MAX] = {
+       [NETWORK_HOST_KEY] = { "key", BLOBMSG_TYPE_STRING },
+       [NETWORK_HOST_GROUPS] = { "groups", BLOBMSG_TYPE_ARRAY },
+       [NETWORK_HOST_IPADDR] = { "ipaddr", BLOBMSG_TYPE_ARRAY },
+       [NETWORK_HOST_SUBNET] = { "subnet", BLOBMSG_TYPE_ARRAY },
+       [NETWORK_HOST_PORT] = { "port", BLOBMSG_TYPE_INT32 },
+       [NETWORK_HOST_PEX_PORT] = { "peer-exchange-port", BLOBMSG_TYPE_INT32 },
+       [NETWORK_HOST_ENDPOINT] = { "endpoint", BLOBMSG_TYPE_STRING },
+       [NETWORK_HOST_GATEWAY] = { "gateway", BLOBMSG_TYPE_STRING },
+};
+
 static void
-network_host_create(struct network *net, struct blob_attr *attr)
+network_host_create(struct network *net, struct blob_attr *attr, bool dynamic)
 {
-       enum {
-               NETWORK_HOST_KEY,
-               NETWORK_HOST_GROUPS,
-               NETWORK_HOST_IPADDR,
-               NETWORK_HOST_SUBNET,
-               NETWORK_HOST_PORT,
-               NETWORK_HOST_ENDPOINT,
-               __NETWORK_HOST_MAX
-       };
-       static const struct blobmsg_policy policy[__NETWORK_HOST_MAX] = {
-               [NETWORK_HOST_KEY] = { "key", BLOBMSG_TYPE_STRING },
-               [NETWORK_HOST_GROUPS] = { "groups", BLOBMSG_TYPE_ARRAY },
-               [NETWORK_HOST_IPADDR] = { "ipaddr", BLOBMSG_TYPE_ARRAY },
-               [NETWORK_HOST_SUBNET] = { "subnet", BLOBMSG_TYPE_ARRAY },
-               [NETWORK_HOST_PORT] = { "port", BLOBMSG_TYPE_INT32 },
-               [NETWORK_HOST_ENDPOINT] = { "endpoint", BLOBMSG_TYPE_STRING },
-       };
        struct blob_attr *tb[__NETWORK_HOST_MAX];
        struct blob_attr *cur, *ipaddr, *subnet;
        uint8_t key[CURVE25519_KEY_SIZE];
-       struct network_host *host;
+       struct network_host *host = NULL;
        struct network_peer *peer;
        int ipaddr_len, subnet_len;
-       const char *name, *endpoint;
-       char *name_buf, *endpoint_buf;
+       const char *endpoint, *gateway;
+       char *endpoint_buf, *gateway_buf;
        int rem;
 
-       blobmsg_parse(policy, __NETWORK_HOST_MAX, tb, blobmsg_data(attr), blobmsg_len(attr));
+       blobmsg_parse(host_policy, __NETWORK_HOST_MAX, tb, blobmsg_data(attr), blobmsg_len(attr));
 
        if (!tb[NETWORK_HOST_KEY])
                return;
@@ -133,21 +141,50 @@ network_host_create(struct network *net, struct blob_attr *attr)
        else
                endpoint = NULL;
 
+       if (!dynamic && (cur = tb[NETWORK_HOST_GATEWAY]) != NULL)
+               gateway = blobmsg_get_string(cur);
+       else
+               gateway = NULL;
+
        if (b64_decode(blobmsg_get_string(tb[NETWORK_HOST_KEY]), key,
                       sizeof(key)) != sizeof(key))
                return;
 
-       name = blobmsg_name(attr);
-       host = avl_find_element(&net->hosts, name, host, node);
-       if (host)
-               return;
+       if (dynamic) {
+               struct network_dynamic_peer *dyn_peer;
 
-       host = calloc_a(sizeof(*host),
-                       &name_buf, strlen(name) + 1,
-                       &ipaddr, ipaddr_len,
-                       &subnet, subnet_len,
-                       &endpoint_buf, endpoint ? strlen(endpoint) + 1 : 0);
-       peer = &host->peer;
+               /* don't override/alter hosts configured via network data */
+               peer = vlist_find(&net->peers, key, peer, node);
+               if (peer && !peer->dynamic &&
+                       peer->node.version == net->peers.version)
+                       return;
+
+               dyn_peer = calloc_a(sizeof(*dyn_peer),
+                               &ipaddr, ipaddr_len,
+                               &subnet, subnet_len,
+                               &endpoint_buf, endpoint ? strlen(endpoint) + 1 : 0);
+               list_add_tail(&dyn_peer->list, &net->dynamic_peers);
+               peer = &dyn_peer->peer;
+       } else {
+               const char *name;
+               char *name_buf;
+
+               name = blobmsg_name(attr);
+               host = avl_find_element(&net->hosts, name, host, node);
+               if (host)
+                       return;
+
+               host = calloc_a(sizeof(*host),
+                               &name_buf, strlen(name) + 1,
+                               &ipaddr, ipaddr_len,
+                               &subnet, subnet_len,
+                               &endpoint_buf, endpoint ? strlen(endpoint) + 1 : 0,
+                               &gateway_buf, gateway ? strlen(endpoint) + 1 : 0);
+               host->node.key = strcpy(name_buf, name);
+               peer = &host->peer;
+       }
+
+       peer->dynamic = dynamic;
        if ((cur = tb[NETWORK_HOST_IPADDR]) != NULL && ipaddr_len)
                peer->ipaddr = memcpy(ipaddr, cur, ipaddr_len);
        if ((cur = tb[NETWORK_HOST_SUBNET]) != NULL && subnet_len)
@@ -156,16 +193,25 @@ network_host_create(struct network *net, struct blob_attr *attr)
                peer->port = blobmsg_get_u32(cur);
        else
                peer->port = net->net_config.port;
+       if ((cur = tb[NETWORK_HOST_PEX_PORT]) != NULL)
+               peer->pex_port = blobmsg_get_u32(cur);
+       else
+               peer->pex_port = net->net_config.pex_port;
        if (endpoint)
                peer->endpoint = strcpy(endpoint_buf, endpoint);
        memcpy(peer->key, key, sizeof(key));
-       host->node.key = strcpy(name_buf, name);
 
        memcpy(&peer->local_addr.network_id,
                   &net->net_config.addr.network_id,
                   sizeof(peer->local_addr.network_id));
        network_fill_host_addr(&peer->local_addr, peer->key);
 
+       if (!host)
+               return;
+
+       if (gateway)
+               host->gateway = strcpy(gateway_buf, gateway);
+
        blobmsg_for_each_attr(cur, tb[NETWORK_HOST_GROUPS], rem) {
                if (!blobmsg_check_attr(cur, false) ||
                    blobmsg_type(cur) != BLOBMSG_TYPE_STRING)
@@ -184,6 +230,72 @@ network_host_create(struct network *net, struct blob_attr *attr)
        }
 }
 
+static void
+network_hosts_load_dynamic_file(struct network *net, const char *file)
+{
+       struct blob_attr *cur;
+       int rem;
+
+       blob_buf_init(&b, 0);
+
+    if (!blobmsg_add_json_from_file(&b, file))
+               return;
+
+       blob_for_each_attr(cur, b.head, rem)
+               network_host_create(net, cur, true);
+}
+
+static void
+network_hosts_load_dynamic_peers(struct network *net)
+{
+       struct network_dynamic_peer *dyn;
+       struct blob_attr *cur;
+       int rem;
+
+       if (!net->config.peer_data)
+               return;
+
+       blobmsg_for_each_attr(cur, net->config.peer_data, rem)
+               network_hosts_load_dynamic_file(net, blobmsg_get_string(cur));
+
+       blob_buf_free(&b);
+
+       list_for_each_entry(dyn, &net->dynamic_peers, list)
+               vlist_add(&net->peers, &dyn->peer.node, &dyn->peer.key);
+}
+
+static void
+network_host_free_dynamic_peers(struct list_head *list)
+{
+       struct network_dynamic_peer *dyn, *dyn_tmp;
+
+       list_for_each_entry_safe(dyn, dyn_tmp, list, list) {
+               list_del(&dyn->list);
+               free(dyn);
+       }
+}
+
+void network_hosts_reload_dynamic_peers(struct network *net)
+{
+       struct network_peer *peer;
+       LIST_HEAD(old_entries);
+
+       if (!net->config.peer_data)
+               return;
+
+       list_splice_init(&net->dynamic_peers, &old_entries);
+
+       vlist_for_each_element(&net->peers, peer, node)
+               if (peer->dynamic)
+                       peer->node.version = net->peers.version - 1;
+
+       network_hosts_load_dynamic_peers(net);
+
+       vlist_flush(&net->peers);
+
+       network_host_free_dynamic_peers(&old_entries);
+}
+
 void network_hosts_update_start(struct network *net)
 {
        struct network_host *host, *htmp;
@@ -200,34 +312,60 @@ void network_hosts_update_start(struct network *net)
        vlist_update(&net->peers);
 }
 
-void network_hosts_update_done(struct network *net)
+static void
+__network_hosts_update_done(struct network *net, bool free_net)
 {
-       struct network_host *host, *tmp;
+       struct network_host *local, *host, *tmp;
+       LIST_HEAD(old_dynamic);
+       const char *local_name;
+
+       list_splice_init(&net->dynamic_peers, &old_dynamic);
+       if (free_net)
+               goto out;
 
-       if (!net->net_config.local_host)
+       local = net->net_config.local_host;
+       if (!local)
                goto out;
 
+       local_name = network_host_name(local);
+
        if (net->net_config.local_host_changed)
-               wg_init_local(net, &net->net_config.local_host->peer);
+               wg_init_local(net, &local->peer);
 
-       avl_for_each_element(&net->hosts, host, node)
-               if (host != net->net_config.local_host)
-                       vlist_add(&net->peers, &host->peer.node, host->peer.key);
+       avl_for_each_element(&net->hosts, host, node) {
+               if (host == local)
+                       continue;
+               if (host->gateway && strcmp(host->gateway, local_name) != 0)
+                       continue;
+               if (local->gateway && strcmp(local->gateway, network_host_name(host)) != 0)
+                       continue;
+               vlist_add(&net->peers, &host->peer.node, host->peer.key);
+       }
+
+       network_hosts_load_dynamic_peers(net);
 
 out:
        vlist_flush(&net->peers);
 
+       network_host_free_dynamic_peers(&old_dynamic);
+
        list_for_each_entry_safe(host, tmp, &old_hosts, node.list) {
                list_del(&host->node.list);
                free(host);
        }
 }
 
+void network_hosts_update_done(struct network *net)
+{
+       return __network_hosts_update_done(net, false);
+}
+
 static void
 network_hosts_connect_cb(struct uloop_timeout *t)
 {
        struct network *net = container_of(t, struct network, connect_timer);
        struct network_host *host;
+       struct network_peer *peer;
        union network_endpoint *ep;
 
        avl_for_each_element(&net->hosts, host, node)
@@ -239,12 +377,7 @@ network_hosts_connect_cb(struct uloop_timeout *t)
 
        wg_peer_refresh(net);
 
-       avl_for_each_element(&net->hosts, host, node) {
-               struct network_peer *peer = &host->peer;
-
-               if (host == net->net_config.local_host)
-                       continue;
-
+       vlist_for_each_element(&net->peers, peer, node) {
                if (peer->state.connected)
                        continue;
 
@@ -274,11 +407,12 @@ void network_hosts_add(struct network *net, struct blob_attr *hosts)
        int rem;
 
        blobmsg_for_each_attr(cur, hosts, rem)
-               network_host_create(net, cur);
+               network_host_create(net, cur, false);
 }
 
 void network_hosts_init(struct network *net)
 {
+       INIT_LIST_HEAD(&net->dynamic_peers);
        avl_init(&net->hosts, avl_strcmp, false, NULL);
        vlist_init(&net->peers, avl_key_cmp, network_peer_update);
        avl_init(&net->groups, avl_strcmp, false, NULL);
@@ -289,5 +423,5 @@ void network_hosts_free(struct network *net)
 {
        uloop_timeout_cancel(&net->connect_timer);
        network_hosts_update_start(net);
-       network_hosts_update_done(net);
+       __network_hosts_update_done(net, true);
 }