vxlan: add bpf program to fix up tcp mss values
authorFelix Fietkau <nbd@nbd.name>
Wed, 29 Jun 2022 18:12:48 +0000 (20:12 +0200)
committerFelix Fietkau <nbd@nbd.name>
Wed, 29 Jun 2022 19:18:09 +0000 (21:18 +0200)
Signed-off-by: Felix Fietkau <nbd@nbd.name>
CMakeLists.txt
bpf.c [new file with mode: 0644]
main.c
mss-bpf.c [new file with mode: 0644]
rtnl.c [new file with mode: 0644]
unetd.h
utils.h
vxlan.c

index 73c4cfa55300003f9ba884a31f800e246425e071..4ac58f1c82babf49e44e61f6bc41c17fd6f276d0 100644 (file)
@@ -17,9 +17,11 @@ FIND_LIBRARY(libjson NAMES json-c json)
 OPTION(UBUS_SUPPORT "enable ubus support" ON)
 IF(CMAKE_SYSTEM_NAME STREQUAL "Linux")
        FIND_LIBRARY(nl nl-tiny)
-       SET(SOURCES ${SOURCES} wg-linux.c vxlan.c)
+       find_library(bpf NAMES bpf)
+       SET(SOURCES ${SOURCES} wg-linux.c vxlan.c bpf.c rtnl.c)
 ELSE()
        SET(nl "")
+       SET(bpf "")
 ENDIF()
 IF(UBUS_SUPPORT)
   SET(SOURCES ${SOURCES} ubus.c)
@@ -30,7 +32,7 @@ ELSE()
 ENDIF()
 
 ADD_EXECUTABLE(unetd ${SOURCES})
-TARGET_LINK_LIBRARIES(unetd ubox ${ubus} blobmsg_json ${libjson} ${nl})
+TARGET_LINK_LIBRARIES(unetd ubox ${ubus} blobmsg_json ${libjson} ${nl} ${bpf})
 
 INSTALL(TARGETS unetd
        RUNTIME DESTINATION sbin
diff --git a/bpf.c b/bpf.c
new file mode 100644 (file)
index 0000000..d0ad683
--- /dev/null
+++ b/bpf.c
@@ -0,0 +1,110 @@
+// SPDX-License-Identifier: GPL-2.0+
+/*
+ * Copyright (C) 2022 Felix Fietkau <nbd@nbd.name>
+ */
+#define _GNU_SOURCE
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <sys/resource.h>
+#include <netinet/if_ether.h>
+#include <netlink/msg.h>
+#include <netlink/attr.h>
+#include <netlink/socket.h>
+#include <linux/rtnetlink.h>
+#include <linux/pkt_cls.h>
+#include <bpf/bpf.h>
+#include <bpf/libbpf.h>
+#include "unetd.h"
+
+static int unetd_bpf_pr(enum libbpf_print_level level, const char *format,
+                    va_list args)
+{
+       return vfprintf(stderr, format, args);
+}
+
+static void unetd_init_env(void)
+{
+       struct rlimit limit = {
+               .rlim_cur = RLIM_INFINITY,
+               .rlim_max = RLIM_INFINITY,
+       };
+
+       setrlimit(RLIMIT_MEMLOCK, &limit);
+}
+
+static void
+unetd_set_prog_mtu(struct bpf_object *obj, uint32_t mtu)
+{
+       struct bpf_map *map = NULL;
+
+       while ((map = bpf_object__next_map(obj, map)) != NULL) {
+               if (!strstr(bpf_map__name(map), ".rodata"))
+                       continue;
+
+               bpf_map__set_initial_value(map, &mtu, sizeof(mtu));
+       }
+}
+
+static int
+unetd_attach_bpf_prog(int ifindex, int fd, bool egress)
+{
+       DECLARE_LIBBPF_OPTS(bpf_tc_hook, hook,
+                           .attach_point = egress ? BPF_TC_EGRESS : BPF_TC_INGRESS,
+                           .ifindex = ifindex);
+       DECLARE_LIBBPF_OPTS(bpf_tc_opts, attach_tc,
+                           .flags = BPF_TC_F_REPLACE,
+                           .handle = 1,
+                               .prog_fd = fd,
+                           .priority = UNETD_MSS_PRIO_BASE);
+
+       bpf_tc_hook_create(&hook);
+
+       return bpf_tc_attach(&hook, &attach_tc);
+}
+
+int unetd_attach_mssfix(int ifindex, int mtu)
+{
+       struct bpf_program *prog;
+       struct bpf_object *obj;
+       int prog_fd;
+       int ret = -1;
+
+       if (rtnl_init())
+               return -1;
+
+       unetd_init_env();
+       libbpf_set_print(unetd_bpf_pr);
+
+       obj = bpf_object__open_file(mssfix_path, NULL);
+       if (libbpf_get_error(obj)) {
+               perror("bpf_object__open_file");
+               goto out;
+       }
+
+       prog = bpf_object__find_program_by_name(obj, "mssfix");
+       if (!prog) {
+               perror("bpf_object__find_program_by_name");
+               goto out;
+       }
+
+       bpf_program__set_type(prog, BPF_PROG_TYPE_SCHED_CLS);
+
+       unetd_set_prog_mtu(obj, mtu);
+
+       if (bpf_object__load(obj)) {
+               perror("bpf_object__load");
+               goto out;
+       }
+
+       prog_fd = bpf_program__fd(prog);
+       if (unetd_attach_bpf_prog(ifindex, prog_fd, true) ||
+           unetd_attach_bpf_prog(ifindex, prog_fd, false))
+               goto out;
+
+       ret = 0;
+
+out:
+       bpf_object__close(obj);
+
+       return ret;
+}
diff --git a/main.c b/main.c
index 74fe9648d3acf3a2272ac6f4192e9712cd0e4e2d..3bd774471b8fd8baf3d4917b61de59a548dc566f 100644 (file)
--- a/main.c
+++ b/main.c
@@ -16,6 +16,7 @@ struct cmdline_network {
 
 static struct cmdline_network *cmd_nets;
 static const char *hosts_file;
+const char *mssfix_path = UNETD_MSS_BPF_PATH;
 bool dummy_mode;
 bool debug;
 
@@ -97,7 +98,7 @@ int main(int argc, char **argv)
        struct cmdline_network *net;
        int ch;
 
-       while ((ch = getopt(argc, argv, "Ddh:N:")) != -1) {
+       while ((ch = getopt(argc, argv, "Ddh:M:N:")) != -1) {
                switch (ch) {
                case 'd':
                        debug = true;
@@ -114,6 +115,9 @@ int main(int argc, char **argv)
                        net->data = optarg;
                        cmd_nets = net;
                        break;
+               case 'M':
+                       mssfix_path = optarg;
+                       break;
                }
        }
 
diff --git a/mss-bpf.c b/mss-bpf.c
new file mode 100644 (file)
index 0000000..5f5cb61
--- /dev/null
+++ b/mss-bpf.c
@@ -0,0 +1,214 @@
+// SPDX-License-Identifier: GPL-2.0+
+/*
+ * Copyright (C) 2021 Felix Fietkau <nbd@nbd.name>
+ */
+#define KBUILD_MODNAME "foo"
+#include <uapi/linux/bpf.h>
+#include <uapi/linux/if_ether.h>
+#include <uapi/linux/if_packet.h>
+#include <uapi/linux/ip.h>
+#include <uapi/linux/ipv6.h>
+#include <uapi/linux/in.h>
+#include <uapi/linux/tcp.h>
+#include <uapi/linux/filter.h>
+#include <uapi/linux/pkt_cls.h>
+#include <linux/ip.h>
+#include <net/ipv6.h>
+#include <net/tcp.h>
+#include <bpf/bpf_helpers.h>
+#include <bpf/bpf_endian.h>
+
+const volatile static uint32_t mtu = 1420;
+
+static __always_inline int proto_is_vlan(__u16 h_proto)
+{
+       return !!(h_proto == bpf_htons(ETH_P_8021Q) ||
+                 h_proto == bpf_htons(ETH_P_8021AD));
+}
+
+static __always_inline int proto_is_ip(__u16 h_proto)
+{
+       return !!(h_proto == bpf_htons(ETH_P_IP) ||
+                 h_proto == bpf_htons(ETH_P_IPV6));
+}
+
+static __always_inline void *skb_ptr(struct __sk_buff *skb, __u32 offset)
+{
+       void *start = (void *)(unsigned long long)skb->data;
+
+       return start + offset;
+}
+
+static __always_inline void *skb_end_ptr(struct __sk_buff *skb)
+{
+       return (void *)(unsigned long long)skb->data_end;
+}
+
+static __always_inline int skb_check(struct __sk_buff *skb, void *ptr)
+{
+       if (ptr > skb_end_ptr(skb))
+               return -1;
+
+       return 0;
+}
+
+static __always_inline int
+parse_ethernet(struct __sk_buff *skb, __u32 *offset)
+{
+       struct ethhdr *eth;
+       __u16 h_proto;
+       int i;
+
+       eth = skb_ptr(skb, *offset);
+       if (skb_check(skb, eth + 1))
+               return -1;
+
+       h_proto = eth->h_proto;
+       *offset += sizeof(*eth);
+
+#pragma unroll
+       for (i = 0; i < 2; i++) {
+               struct vlan_hdr *vlh = skb_ptr(skb, *offset);
+
+               if (!proto_is_vlan(h_proto))
+                       break;
+
+               if (skb_check(skb, vlh + 1))
+                       return -1;
+
+               h_proto = vlh->h_vlan_encapsulated_proto;
+               *offset += sizeof(*vlh);
+       }
+
+       return h_proto;
+}
+
+static __always_inline int
+parse_ipv4(struct __sk_buff *skb, __u32 *offset)
+{
+       struct iphdr *iph;
+       int hdr_len;
+
+       iph = skb_ptr(skb, *offset);
+       if (skb_check(skb, iph + 1))
+               return -1;
+
+       hdr_len = iph->ihl * 4;
+       if (bpf_skb_pull_data(skb, *offset + hdr_len + sizeof(struct tcphdr) + 20))
+               return -1;
+
+       iph = skb_ptr(skb, *offset);
+       *offset += hdr_len;
+
+       if (skb_check(skb, (void *)(iph + 1)))
+               return -1;
+
+       return READ_ONCE(iph->protocol);
+}
+
+static __always_inline bool
+parse_ipv6(struct __sk_buff *skb, __u32 *offset)
+{
+       struct ipv6hdr *iph;
+
+       if (bpf_skb_pull_data(skb, *offset + sizeof(*iph) + sizeof(struct tcphdr) + 20))
+               return -1;
+
+       iph = skb_ptr(skb, *offset);
+       *offset += sizeof(*iph);
+
+       if (skb_check(skb, (void *)(iph + 1)))
+               return -1;
+
+       return READ_ONCE(iph->nexthdr);
+}
+
+static inline unsigned int
+optlen(const u_int8_t *opt)
+{
+       if (opt[0] <= TCPOPT_NOP || opt[1] == 0)
+               return 1;
+
+       return opt[1];
+}
+
+static __always_inline void
+fixup_tcp(struct __sk_buff *skb, __u32 offset, __u16 mss)
+{
+       struct tcphdr *tcph;
+       __u16 oldmss;
+       __u8 *opt;
+       u8 flags;
+       int hdrlen;
+       int i;
+
+       tcph = skb_ptr(skb, offset);
+       if (skb_check(skb, tcph + 1))
+               return;
+
+       flags = tcp_flag_byte(tcph);
+       if (!(flags & TCPHDR_SYN))
+               return;
+
+       hdrlen = tcph->doff * 4;
+       if (hdrlen <= sizeof(struct tcphdr))
+               return;
+
+       hdrlen += offset;
+       offset += sizeof(*tcph);
+
+#pragma unroll
+       for (i = 0; i < 5; i++) {
+               unsigned int len;
+
+               if (offset >= hdrlen)
+                       return;
+
+               opt = skb_ptr(skb, offset);
+               if (skb_check(skb, opt + TCPOLEN_MSS))
+                       return;
+
+               len = optlen(opt);
+               offset += len;
+               if (opt[0] != TCPOPT_MSS || opt[1] != TCPOLEN_MSS)
+                       continue;
+
+               goto found;
+       }
+       return;
+
+found:
+       oldmss = (opt[2] << 8) | opt[3];
+       if (oldmss <= mss)
+               return;
+
+       opt[2] = mss >> 8;
+       opt[3] = mss & 0xff;
+       csum_replace2(&tcph->check, bpf_htons(oldmss), bpf_htons(mss));
+}
+
+SEC("tc")
+int mssfix(struct __sk_buff *skb)
+{
+       __u32 offset = 0;
+       __u8 ipproto;
+       __u16 mss;
+       int type;
+
+       type = parse_ethernet(skb, &offset);
+       if (type == bpf_htons(ETH_P_IP))
+               type = parse_ipv4(skb, &offset);
+       else if (type == bpf_htons(ETH_P_IPV6))
+               type = parse_ipv6(skb, &offset);
+       else
+               return TC_ACT_UNSPEC;
+
+       if (type != IPPROTO_TCP)
+               return TC_ACT_UNSPEC;
+
+       mss = mtu;
+       mss -= offset + sizeof(struct tcphdr);
+       fixup_tcp(skb, offset, mss);
+
+       return TC_ACT_UNSPEC;
+}
diff --git a/rtnl.c b/rtnl.c
new file mode 100644 (file)
index 0000000..4180575
--- /dev/null
+++ b/rtnl.c
@@ -0,0 +1,94 @@
+// SPDX-License-Identifier: GPL-2.0+
+/*
+ * Copyright (C) 2022 Felix Fietkau <nbd@nbd.name>
+ */
+#define _GNU_SOURCE
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netlink/msg.h>
+#include <netlink/attr.h>
+#include <netlink/socket.h>
+#include <linux/rtnetlink.h>
+#include "unetd.h"
+
+static struct nl_sock *rtnl;
+bool rtnl_ignore_errors;
+
+static int
+unetd_nl_error_cb(struct sockaddr_nl *nla, struct nlmsgerr *err,
+                  void *arg)
+{
+       struct nlmsghdr *nlh = (struct nlmsghdr *) err - 1;
+       struct nlattr *tb[NLMSGERR_ATTR_MAX + 1];
+       struct nlattr *attrs;
+       int ack_len = sizeof(*nlh) + sizeof(int) + sizeof(*nlh);
+       int len = nlh->nlmsg_len;
+       const char *errstr = "(unknown)";
+
+       if (rtnl_ignore_errors)
+               return NL_STOP;
+
+       if (!(nlh->nlmsg_flags & NLM_F_ACK_TLVS))
+               return NL_STOP;
+
+       if (!(nlh->nlmsg_flags & NLM_F_CAPPED))
+               ack_len += err->msg.nlmsg_len - sizeof(*nlh);
+
+       attrs = (void *) ((unsigned char *) nlh + ack_len);
+       len -= ack_len;
+
+       nla_parse(tb, NLMSGERR_ATTR_MAX, attrs, len, NULL);
+       if (tb[NLMSGERR_ATTR_MSG])
+               errstr = nla_data(tb[NLMSGERR_ATTR_MSG]);
+
+       D("Netlink error(%d): %s\n", err->error, errstr);
+
+       return NL_STOP;
+}
+
+int rtnl_call(struct nl_msg *msg)
+{
+       int ret;
+
+       ret = nl_send_auto_complete(rtnl, msg);
+       nlmsg_free(msg);
+
+       if (ret < 0)
+               return ret;
+
+       return nl_wait_for_ack(rtnl);
+}
+
+int rtnl_init(void)
+{
+       int fd, opt;
+
+       if (rtnl)
+               return 0;
+
+       rtnl = nl_socket_alloc();
+       if (!rtnl)
+               return -1;
+
+       if (nl_connect(rtnl, NETLINK_ROUTE))
+               goto free;
+
+       nl_socket_disable_seq_check(rtnl);
+       nl_socket_set_buffer_size(rtnl, 65536, 0);
+       nl_cb_err(nl_socket_get_cb(rtnl), NL_CB_CUSTOM, unetd_nl_error_cb, NULL);
+
+       fd = nl_socket_get_fd(rtnl);
+
+       opt = 1;
+       setsockopt(fd, SOL_NETLINK, NETLINK_EXT_ACK, &opt, sizeof(opt));
+
+       opt = 1;
+       setsockopt(fd, SOL_NETLINK, NETLINK_CAP_ACK, &opt, sizeof(opt));
+
+       return 0;
+
+free:
+       nl_socket_free(rtnl);
+       rtnl = NULL;
+       return -1;
+}
diff --git a/unetd.h b/unetd.h
index b6fd43746747e60782d5d0c32ed5c2c1a07eb650..799bff72873d2e7b28467f9b0e486515d26cf74b 100644 (file)
--- a/unetd.h
+++ b/unetd.h
@@ -19,6 +19,7 @@
 #include "service.h"
 #include "ubus.h"
 
+extern const char *mssfix_path;
 extern bool dummy_mode;
 extern bool debug;
 
@@ -34,7 +35,10 @@ extern bool debug;
 #define D_PEER(net, peer, format, ...) D_NET(net, "host %s " format, network_peer_name(peer), ##__VA_ARGS__)
 #define D_SERVICE(net, service, format, ...) D_NET(net, "service %s " format, network_service_name(service), ##__VA_ARGS__)
 
+#define UNETD_MSS_BPF_PATH     "/lib/bpf/mss.o"
+#define UNETD_MSS_PRIO_BASE    0x130
 
 void unetd_write_hosts(void);
+int unetd_attach_mssfix(int ifindex, int mtu);
 
 #endif
diff --git a/utils.h b/utils.h
index b0243743e8db6c8e5fb3d253fab9d38d2c615577..077080783c314af51b09fee04fac8295b12e2293 100644 (file)
--- a/utils.h
+++ b/utils.h
@@ -7,6 +7,8 @@
 
 #include <netinet/in.h>
 
+struct nl_msg;
+
 union network_addr {
        struct {
                uint8_t network_id[8];
@@ -82,4 +84,7 @@ static inline void bitmask_set_val(uint32_t *mask, unsigned int i, bool val)
                bitmask_clear(mask, i);
 }
 
+int rtnl_init(void);
+int rtnl_call(struct nl_msg *msg);
+
 #endif
diff --git a/vxlan.c b/vxlan.c
index b15d4c316da650a14debf7f9850bb50bc5455251..f042a487baebd3c13cce41e742878cbaa68b603d 100644 (file)
--- a/vxlan.c
+++ b/vxlan.c
@@ -11,6 +11,8 @@
 #include <netinet/if_ether.h>
 #include <net/if.h>
 #include <linux/rtnetlink.h>
+#include <linux/ipv6.h>
+#include <linux/udp.h>
 #include "unetd.h"
 
 struct vxlan_tunnel {
@@ -26,39 +28,19 @@ struct vxlan_tunnel {
        bool active;
 };
 
-static struct nl_sock *rtnl;
-static bool ignore_errors;
-
-static int
-unetd_nl_error_cb(struct sockaddr_nl *nla, struct nlmsgerr *err,
-                  void *arg)
+static uint32_t
+vxlan_tunnel_id(struct vxlan_tunnel *vt)
 {
-       struct nlmsghdr *nlh = (struct nlmsghdr *) err - 1;
-       struct nlattr *tb[NLMSGERR_ATTR_MAX + 1];
-       struct nlattr *attrs;
-       int ack_len = sizeof(*nlh) + sizeof(int) + sizeof(*nlh);
-       int len = nlh->nlmsg_len;
-       const char *errstr = "(unknown)";
-
-       if (ignore_errors)
-               return NL_STOP;
-
-       if (!(nlh->nlmsg_flags & NLM_F_ACK_TLVS))
-               return NL_STOP;
-
-       if (!(nlh->nlmsg_flags & NLM_F_CAPPED))
-               ack_len += err->msg.nlmsg_len - sizeof(*nlh);
-
-       attrs = (void *) ((unsigned char *) nlh + ack_len);
-       len -= ack_len;
+       siphash_key_t key = {};
+       const char *name = network_service_name(vt->s);
+       uint64_t val;
 
-       nla_parse(tb, NLMSGERR_ATTR_MAX, attrs, len, NULL);
-       if (tb[NLMSGERR_ATTR_MSG])
-               errstr = nla_data(tb[NLMSGERR_ATTR_MSG]);
+       if (vt->vni != ~0)
+               return vt->vni;
 
-       D("Netlink error(%d): %s\n", err->error, errstr);
+       siphash_to_le64(&val, name, strlen(name), &key);
 
-       return NL_STOP;
+       return val & 0x00ffffff;
 }
 
 static struct nl_msg *vxlan_rtnl_msg(const char *ifname, int type, int flags)
@@ -78,69 +60,6 @@ static struct nl_msg *vxlan_rtnl_msg(const char *ifname, int type, int flags)
        return msg;
 }
 
-static int vxlan_rtnl_call(struct nl_msg *msg)
-{
-       int ret;
-
-       ret = nl_send_auto_complete(rtnl, msg);
-       nlmsg_free(msg);
-
-       if (ret < 0)
-               return ret;
-
-       return nl_wait_for_ack(rtnl);
-}
-
-static int
-vxlan_rtnl_init(void)
-{
-       int fd, opt;
-
-       if (rtnl)
-               return 0;
-
-       rtnl = nl_socket_alloc();
-       if (!rtnl)
-               return -1;
-
-       if (nl_connect(rtnl, NETLINK_ROUTE))
-               goto free;
-
-       nl_socket_disable_seq_check(rtnl);
-       nl_socket_set_buffer_size(rtnl, 65536, 0);
-       nl_cb_err(nl_socket_get_cb(rtnl), NL_CB_CUSTOM, unetd_nl_error_cb, NULL);
-
-       fd = nl_socket_get_fd(rtnl);
-
-       opt = 1;
-       setsockopt(fd, SOL_NETLINK, NETLINK_EXT_ACK, &opt, sizeof(opt));
-
-       opt = 1;
-       setsockopt(fd, SOL_NETLINK, NETLINK_CAP_ACK, &opt, sizeof(opt));
-
-       return 0;
-
-free:
-       nl_socket_free(rtnl);
-       rtnl = NULL;
-       return -1;
-}
-
-static uint32_t
-vxlan_tunnel_id(struct vxlan_tunnel *vt)
-{
-       siphash_key_t key = {};
-       const char *name = network_service_name(vt->s);
-       uint64_t val;
-
-       if (vt->vni != ~0)
-               return vt->vni;
-
-       siphash_to_le64(&val, name, strlen(name), &key);
-
-       return val & 0x00ffffff;
-}
-
 static int
 vxlan_update_host_fdb_entry(struct vxlan_tunnel *vt, struct network_host *host, bool add)
 {
@@ -163,7 +82,7 @@ vxlan_update_host_fdb_entry(struct vxlan_tunnel *vt, struct network_host *host,
        nla_put(msg, NDA_DST, sizeof(struct in6_addr), &host->peer.local_addr);
        nla_put_u32(msg, NDA_IFINDEX, vt->net->ifindex);
 
-       return vxlan_rtnl_call(msg);
+       return rtnl_call(msg);
 }
 
 static void
@@ -208,8 +127,9 @@ vxlan_tunnel_init(struct vxlan_tunnel *vt)
        struct nlattr *linkinfo, *data;
        struct nl_msg *msg;
        struct in6_addr group_addr;
+       int mtu;
 
-       if (vxlan_rtnl_init())
+       if (rtnl_init())
                return;
 
        memset(&group_addr, 0xff, sizeof(group_addr));
@@ -230,7 +150,7 @@ vxlan_tunnel_init(struct vxlan_tunnel *vt)
 
        nla_nest_end(msg, linkinfo);
 
-       if (vxlan_rtnl_call(msg) < 0)
+       if (rtnl_call(msg) < 0)
                return;
 
        vt->ifindex = if_nametoindex(vt->ifname);
@@ -241,6 +161,9 @@ vxlan_tunnel_init(struct vxlan_tunnel *vt)
 
        vt->active = true;
        vxlan_update_fdb_hosts(vt);
+
+       mtu = 1420 - sizeof(struct ipv6hdr) - sizeof(struct udphdr) - 8;
+       unetd_attach_mssfix(vt->ifindex, mtu);
 }
 
 static void
@@ -248,12 +171,9 @@ vxlan_tunnel_teardown(struct vxlan_tunnel *vt)
 {
        struct nl_msg *msg;
 
-       if (!rtnl)
-               return;
-
        vt->active = false;
        msg = vxlan_rtnl_msg(vt->ifname, RTM_DELLINK, 0);
-       vxlan_rtnl_call(msg);
+       rtnl_call(msg);
 }
 
 static const char *