pex: add support for figuring out the external data port via STUN servers
[project/unetd.git] / pex.c
diff --git a/pex.c b/pex.c
index c8b073104029f060d1c4b17e8aabfba397affc24..3f28f5137bcc9979e106469de17861ae0e222ba5 100644 (file)
--- a/pex.c
+++ b/pex.c
@@ -166,7 +166,7 @@ network_pex_handle_endpoint_change(struct network *net, struct network_peer *pee
 }
 
 static void
-network_pex_host_request_update(struct network *net, struct network_pex_host *host)
+network_pex_host_send_endpoint_notify(struct network *net, struct network_pex_host *host)
 {
        union {
                struct {
@@ -179,32 +179,10 @@ network_pex_host_request_update(struct network *net, struct network_pex_host *ho
                } ipv6;
        } packet = {};
        struct udphdr *udp;
-       char addrstr[INET6_ADDRSTRLEN];
        union network_endpoint dest_ep;
        union network_addr local_addr = {};
-       uint64_t version = 0;
        int len;
 
-       if (net->net_data_len)
-               version = net->net_data_version;
-
-       D("request network data from host %s",
-         inet_ntop(host->endpoint.sa.sa_family,
-                   (host->endpoint.sa.sa_family == AF_INET6 ?
-                    (const void *)&host->endpoint.in6.sin6_addr :
-                    (const void *)&host->endpoint.in.sin_addr),
-                   addrstr, sizeof(addrstr)));
-
-       if (!pex_msg_update_request_init(net->config.pubkey, net->config.key,
-                                        net->config.auth_key, &host->endpoint,
-                                        version, true))
-               return;
-
-       __pex_msg_send(-1, &host->endpoint, NULL, 0);
-
-       if (!net->net_config.local_host)
-               return;
-
        pex_msg_init_ext(net, PEX_MSG_ENDPOINT_NOTIFY, true);
 
        memcpy(&dest_ep, &host->endpoint, sizeof(dest_ep));
@@ -252,6 +230,53 @@ network_pex_host_request_update(struct network *net, struct network_pex_host *ho
                D_NET(net, "pex_msg_send_raw failed: %s", strerror(errno));
 }
 
+
+static void
+network_pex_host_send_port_notify(struct network *net, struct network_pex_host *host)
+{
+       struct pex_endpoint_port_notify *data;
+
+       if (!net->stun.port_ext)
+               return;
+
+       pex_msg_init_ext(net, PEX_MSG_ENDPOINT_PORT_NOTIFY, true);
+
+       data = pex_msg_append(sizeof(*data));
+       data->port = htons(net->stun.port_ext);
+
+       __pex_msg_send(-1, &host->endpoint, NULL, 0);
+}
+
+static void
+network_pex_host_request_update(struct network *net, struct network_pex_host *host)
+{
+       char addrstr[INET6_ADDRSTRLEN];
+       uint64_t version = 0;
+
+       if (net->net_data_len)
+               version = net->net_data_version;
+
+       D("request network data from host %s",
+         inet_ntop(host->endpoint.sa.sa_family,
+                   (host->endpoint.sa.sa_family == AF_INET6 ?
+                    (const void *)&host->endpoint.in6.sin6_addr :
+                    (const void *)&host->endpoint.in.sin_addr),
+                   addrstr, sizeof(addrstr)));
+
+       if (!pex_msg_update_request_init(net->config.pubkey, net->config.key,
+                                        net->config.auth_key, &host->endpoint,
+                                        version, true))
+               return;
+
+       __pex_msg_send(-1, &host->endpoint, NULL, 0);
+
+       if (!net->net_config.local_host)
+               return;
+
+       network_pex_host_send_port_notify(net, host);
+       network_pex_host_send_endpoint_notify(net, host);
+}
+
 static void
 network_pex_request_update_cb(struct uloop_timeout *t)
 {
@@ -300,9 +325,8 @@ network_pex_query_hosts(struct network *net)
                struct network_peer *peer = &host->peer;
                void *id;
 
-               if (host == net->net_config.local_host ||
-                   peer->state.connected ||
-                   peer->endpoint)
+               if ((net->stun.port_ext && host == net->net_config.local_host) ||
+                   peer->state.connected || peer->endpoint)
                        continue;
 
                id = pex_msg_append(PEX_ID_LEN);
@@ -434,11 +458,13 @@ network_pex_recv_peers(struct network *net, struct network_peer *peer,
                void *addr;
                int len;
 
-               cur = pex_msg_peer(net, data->peer_id);
-               if (!cur)
+               if (!memcmp(data->peer_id, &local->key, PEX_ID_LEN)) {
+                       network_stun_update_port(net, false, ntohs(data->port));
                        continue;
+               }
 
-               if (cur == peer || cur == local)
+               cur = pex_msg_peer(net, data->peer_id);
+               if (!cur || cur == peer)
                        continue;
 
                D_PEER(net, peer, "received peer address for %s",
@@ -863,6 +889,11 @@ global_pex_recv(void *msg, size_t msg_len, struct sockaddr_in6 *addr)
        void *data;
        int addr_len;
 
+       if (stun_msg_is_valid(msg, msg_len)) {
+               avl_for_each_element(&networks, net, node)
+                       network_stun_rx_packet(net, msg, msg_len);
+       }
+
        hdr = pex_rx_accept(msg, msg_len, true);
        if (!hdr)
                return;
@@ -899,6 +930,9 @@ global_pex_recv(void *msg, size_t msg_len, struct sockaddr_in6 *addr)
        case PEX_MSG_UPDATE_RESPONSE_NO_DATA:
                network_pex_recv_update_response(net, data, hdr->len, addr, hdr->opcode);
                break;
+       case PEX_MSG_ENDPOINT_PORT_NOTIFY:
+               if (hdr->len < sizeof(struct pex_endpoint_port_notify))
+                       break;
        case PEX_MSG_ENDPOINT_NOTIFY:
                peer = pex_msg_peer(net, hdr->id);
                if (!peer)
@@ -909,6 +943,11 @@ global_pex_recv(void *msg, size_t msg_len, struct sockaddr_in6 *addr)
                            buf, sizeof(buf)));
 
                memcpy(&peer->state.next_endpoint, addr, sizeof(*addr));
+               if (hdr->opcode == PEX_MSG_ENDPOINT_PORT_NOTIFY) {
+                       struct pex_endpoint_port_notify *port = data;
+
+                       peer->state.next_endpoint.in.sin_port = port->port;
+               }
                break;
        }
 }