pex: add support for sending endpoint notification from the wg port via raw socket
[project/unetd.git] / pex.c
diff --git a/pex.c b/pex.c
index 62a30f48428822230739acc38a79240ccf8cc655..839567dd2763a7b4ec1c339b10e34eba8a7e87c4 100644 (file)
--- a/pex.c
+++ b/pex.c
@@ -5,6 +5,10 @@
 #include <sys/types.h>
 #include <sys/socket.h>
 #include <arpa/inet.h>
+#include <netinet/in.h>
+#include <netinet/ip.h>
+#include <netinet/ip6.h>
+#include <netinet/udp.h>
 #include <fcntl.h>
 #include <stdlib.h>
 #include <inttypes.h>
@@ -70,7 +74,7 @@ static void pex_msg_send(struct network *net, struct network_peer *peer)
                return;
 
        pex_get_peer_addr(&sin6, net, peer);
-       if (__pex_msg_send(net->pex.fd.fd, &sin6) < 0)
+       if (__pex_msg_send(net->pex.fd.fd, &sin6, NULL, 0) < 0)
                D_PEER(net, peer, "pex_msg_send failed: %s", strerror(errno));
 }
 
@@ -82,7 +86,7 @@ static void pex_msg_send_ext(struct network *net, struct network_peer *peer,
        if (!addr)
                return pex_msg_send(net, peer);
 
-       if (__pex_msg_send(-1, addr) < 0)
+       if (__pex_msg_send(-1, addr, NULL, 0) < 0)
                D_NET(net, "pex_msg_send_ext(%s) failed: %s",
                      inet_ntop(addr->sin6_family, (const void *)&addr->sin6_addr, addrbuf,
                                sizeof(addrbuf)),
@@ -164,8 +168,22 @@ network_pex_handle_endpoint_change(struct network *net, struct network_peer *pee
 static void
 network_pex_host_request_update(struct network *net, struct network_pex_host *host)
 {
+       union {
+               struct {
+                       struct ip ip;
+                       struct udphdr udp;
+               } ipv4;
+               struct {
+                       struct ip6_hdr ip;
+                       struct udphdr udp;
+               } ipv6;
+       } packet = {};
+       struct udphdr *udp;
        char addrstr[INET6_ADDRSTRLEN];
+       union network_endpoint dest_ep;
+       union network_addr local_addr = {};
        uint64_t version = 0;
+       int len;
 
        if (net->net_data_len)
                version = net->net_data_version;
@@ -181,7 +199,57 @@ network_pex_host_request_update(struct network *net, struct network_pex_host *ho
                                         net->config.auth_key, &host->endpoint,
                                         version, true))
                return;
-       __pex_msg_send(-1, &host->endpoint);
+
+       __pex_msg_send(-1, &host->endpoint, NULL, 0);
+
+       if (!net->net_config.local_host)
+               return;
+
+       pex_msg_init_ext(net, PEX_MSG_ENDPOINT_NOTIFY, true);
+
+       memcpy(&dest_ep, &host->endpoint, sizeof(dest_ep));
+
+       /* work around issue with local address lookup for local broadcast */
+       if (host->endpoint.sa.sa_family == AF_INET) {
+               uint8_t *data = (uint8_t *)&dest_ep.in.sin_addr;
+
+               if (data[3] == 0xff)
+                       data[3] = 0xfe;
+       }
+       network_get_local_addr(&local_addr, &dest_ep);
+
+       memset(&dest_ep, 0, sizeof(dest_ep));
+       dest_ep.sa.sa_family = host->endpoint.sa.sa_family;
+       if (host->endpoint.sa.sa_family == AF_INET) {
+               packet.ipv4.ip = (struct ip){
+                       .ip_hl = 5,
+                       .ip_v = 4,
+                       .ip_ttl = 64,
+                       .ip_p = IPPROTO_UDP,
+                       .ip_src = local_addr.in,
+                       .ip_dst = host->endpoint.in.sin_addr,
+               };
+               dest_ep.in.sin_addr = host->endpoint.in.sin_addr;
+               udp = &packet.ipv4.udp;
+               len = sizeof(packet.ipv4);
+       } else {
+               packet.ipv6.ip = (struct ip6_hdr){
+                       .ip6_flow = htonl(6 << 28),
+                       .ip6_hops = 128,
+                       .ip6_nxt = IPPROTO_UDP,
+                       .ip6_src = local_addr.in6,
+                       .ip6_dst = host->endpoint.in6.sin6_addr,
+               };
+               dest_ep.in6.sin6_addr = host->endpoint.in6.sin6_addr;
+               udp = &packet.ipv6.udp;
+               len = sizeof(packet.ipv6);
+       }
+
+       udp->uh_sport = htons(net->net_config.local_host->peer.port);
+       udp->uh_dport = host->endpoint.in6.sin6_port;
+
+       if (__pex_msg_send(-1, &dest_ep, &packet, len) < 0)
+               D_NET(net, "pex_msg_send_raw failed: %s", strerror(errno));
 }
 
 static void
@@ -543,6 +611,8 @@ network_pex_recv(struct network *net, struct network_peer *peer, struct pex_hdr
                network_pex_recv_update_response(net, data, hdr->len,
                                              NULL, hdr->opcode);
                break;
+       case PEX_MSG_ENDPOINT_NOTIFY:
+               break;
        }
 }
 
@@ -740,6 +810,8 @@ global_pex_recv(struct pex_hdr *hdr, struct sockaddr_in6 *addr)
        struct network_peer *peer;
        struct network *net;
        void *data = (void *)(ehdr + 1);
+       char buf[INET6_ADDRSTRLEN];
+       int addr_len;
 
        if (hdr->version != 0)
                return;
@@ -768,6 +840,28 @@ global_pex_recv(struct pex_hdr *hdr, struct sockaddr_in6 *addr)
        case PEX_MSG_UPDATE_RESPONSE_NO_DATA:
                network_pex_recv_update_response(net, data, hdr->len, addr, hdr->opcode);
                break;
+       case PEX_MSG_ENDPOINT_NOTIFY:
+               peer = pex_msg_peer(net, hdr->id);
+               if (!peer)
+                       break;
+
+               if (IN6_IS_ADDR_V4MAPPED(&addr->sin6_addr)) {
+                       struct sockaddr_in *sin = (struct sockaddr_in *)addr;
+                       struct in_addr in = *(struct in_addr *)&addr->sin6_addr.s6_addr[12];
+                       int port = addr->sin6_port;
+
+                       memset(addr, 0, sizeof(*addr));
+                       sin->sin_port = port;
+                       sin->sin_family = AF_INET;
+                       sin->sin_addr = in;
+               }
+
+               D_PEER(net, peer, "receive endpoint notification from %s",
+                 inet_ntop(addr->sin6_family, network_endpoint_addr((void *)addr, &addr_len),
+                           buf, sizeof(buf)));
+
+               memcpy(&peer->state.next_endpoint, addr, sizeof(*addr));
+               break;
        }
 }