network: add support for specifying a host gateway
[project/unetd.git] / host.c
diff --git a/host.c b/host.c
index bef8863292cfca764343a48c0f6dc7e0006abe26..996dbcf5982de7a0bfeff9c86eca38f920297c3b 100644 (file)
--- a/host.c
+++ b/host.c
@@ -93,6 +93,7 @@ network_host_create(struct network *net, struct blob_attr *attr)
                NETWORK_HOST_SUBNET,
                NETWORK_HOST_PORT,
                NETWORK_HOST_ENDPOINT,
+               NETWORK_HOST_GATEWAY,
                __NETWORK_HOST_MAX
        };
        static const struct blobmsg_policy policy[__NETWORK_HOST_MAX] = {
@@ -102,6 +103,7 @@ network_host_create(struct network *net, struct blob_attr *attr)
                [NETWORK_HOST_SUBNET] = { "subnet", BLOBMSG_TYPE_ARRAY },
                [NETWORK_HOST_PORT] = { "port", BLOBMSG_TYPE_INT32 },
                [NETWORK_HOST_ENDPOINT] = { "endpoint", BLOBMSG_TYPE_STRING },
+               [NETWORK_HOST_GATEWAY] = { "gateway", BLOBMSG_TYPE_STRING },
        };
        struct blob_attr *tb[__NETWORK_HOST_MAX];
        struct blob_attr *cur, *ipaddr, *subnet;
@@ -109,8 +111,8 @@ network_host_create(struct network *net, struct blob_attr *attr)
        struct network_host *host;
        struct network_peer *peer;
        int ipaddr_len, subnet_len;
-       const char *name, *endpoint;
-       char *name_buf, *endpoint_buf;
+       const char *name, *endpoint, *gateway;
+       char *name_buf, *endpoint_buf, *gateway_buf;
        int rem;
 
        blobmsg_parse(policy, __NETWORK_HOST_MAX, tb, blobmsg_data(attr), blobmsg_len(attr));
@@ -133,6 +135,11 @@ network_host_create(struct network *net, struct blob_attr *attr)
        else
                endpoint = NULL;
 
+       if ((cur = tb[NETWORK_HOST_GATEWAY]) != NULL)
+               gateway = blobmsg_get_string(cur);
+       else
+               gateway = NULL;
+
        if (b64_decode(blobmsg_get_string(tb[NETWORK_HOST_KEY]), key,
                       sizeof(key)) != sizeof(key))
                return;
@@ -146,7 +153,8 @@ network_host_create(struct network *net, struct blob_attr *attr)
                        &name_buf, strlen(name) + 1,
                        &ipaddr, ipaddr_len,
                        &subnet, subnet_len,
-                       &endpoint_buf, endpoint ? strlen(endpoint) + 1 : 0);
+                       &endpoint_buf, endpoint ? strlen(endpoint) + 1 : 0,
+                       &gateway_buf, gateway ? strlen(endpoint) + 1 : 0);
        peer = &host->peer;
        if ((cur = tb[NETWORK_HOST_IPADDR]) != NULL && ipaddr_len)
                peer->ipaddr = memcpy(ipaddr, cur, ipaddr_len);
@@ -158,6 +166,8 @@ network_host_create(struct network *net, struct blob_attr *attr)
                peer->port = net->net_config.port;
        if (endpoint)
                peer->endpoint = strcpy(endpoint_buf, endpoint);
+       if (gateway)
+               host->gateway = strcpy(gateway_buf, gateway);
        memcpy(peer->key, key, sizeof(key));
        host->node.key = strcpy(name_buf, name);
 
@@ -202,17 +212,27 @@ void network_hosts_update_start(struct network *net)
 
 void network_hosts_update_done(struct network *net)
 {
-       struct network_host *host, *tmp;
+       struct network_host *local, *host, *tmp;
+       const char *local_name;
 
-       if (!net->net_config.local_host)
+       local = net->net_config.local_host;
+       if (!local)
                goto out;
 
+       local_name = network_host_name(local);
+
        if (net->net_config.local_host_changed)
-               wg_init_local(net, &net->net_config.local_host->peer);
+               wg_init_local(net, &local->peer);
 
-       avl_for_each_element(&net->hosts, host, node)
-               if (host != net->net_config.local_host)
-                       vlist_add(&net->peers, &host->peer.node, host->peer.key);
+       avl_for_each_element(&net->hosts, host, node) {
+               if (host == local)
+                       continue;
+               if (host->gateway && strcmp(host->gateway, local_name) != 0)
+                       continue;
+               if (local->gateway && strcmp(local->gateway, network_host_name(host)) != 0)
+                       continue;
+               vlist_add(&net->peers, &host->peer.node, host->peer.key);
+       }
 
 out:
        vlist_flush(&net->peers);
@@ -242,7 +262,7 @@ network_hosts_connect_cb(struct uloop_timeout *t)
        avl_for_each_element(&net->hosts, host, node) {
                struct network_peer *peer = &host->peer;
 
-               if (host == net->net_config.local_host)
+               if (!network_host_is_peer(host))
                        continue;
 
                if (peer->state.connected)