utils: fix memory leak in network_get_endpoint()
[project/unetd.git] / mss-bpf.c
1 // SPDX-License-Identifier: GPL-2.0+
2 /*
3 * Copyright (C) 2021 Felix Fietkau <nbd@nbd.name>
4 */
5 #define KBUILD_MODNAME "foo"
6 #include <uapi/linux/bpf.h>
7 #include <uapi/linux/if_ether.h>
8 #include <uapi/linux/if_packet.h>
9 #include <uapi/linux/ip.h>
10 #include <uapi/linux/ipv6.h>
11 #include <uapi/linux/in.h>
12 #include <uapi/linux/tcp.h>
13 #include <uapi/linux/filter.h>
14 #include <uapi/linux/pkt_cls.h>
15 #include <linux/ip.h>
16 #include <net/ipv6.h>
17 #include <net/tcp.h>
18 #include <bpf/bpf_helpers.h>
19 #include <bpf/bpf_endian.h>
20
/*
 * MTU from which the MSS clamp is derived (see mssfix()).  The volatile
 * qualifier stops the compiler from folding the 1420 default into the
 * code; presumably the loader overrides it in .rodata before load — confirm.
 */
static const volatile uint32_t mtu = 1420;
22
23 static __always_inline int proto_is_vlan(__u16 h_proto)
24 {
25 return !!(h_proto == bpf_htons(ETH_P_8021Q) ||
26 h_proto == bpf_htons(ETH_P_8021AD));
27 }
28
29 static __always_inline int proto_is_ip(__u16 h_proto)
30 {
31 return !!(h_proto == bpf_htons(ETH_P_IP) ||
32 h_proto == bpf_htons(ETH_P_IPV6));
33 }
34
35 static __always_inline void *skb_ptr(struct __sk_buff *skb, __u32 offset)
36 {
37 void *start = (void *)(unsigned long long)skb->data;
38
39 return start + offset;
40 }
41
42 static __always_inline void *skb_end_ptr(struct __sk_buff *skb)
43 {
44 return (void *)(unsigned long long)skb->data_end;
45 }
46
47 static __always_inline int skb_check(struct __sk_buff *skb, void *ptr)
48 {
49 if (ptr > skb_end_ptr(skb))
50 return -1;
51
52 return 0;
53 }
54
55 static __always_inline int
56 parse_ethernet(struct __sk_buff *skb, __u32 *offset)
57 {
58 struct ethhdr *eth;
59 __u16 h_proto;
60 int i;
61
62 eth = skb_ptr(skb, *offset);
63 if (skb_check(skb, eth + 1))
64 return -1;
65
66 h_proto = eth->h_proto;
67 *offset += sizeof(*eth);
68
69 #pragma unroll
70 for (i = 0; i < 2; i++) {
71 struct vlan_hdr *vlh = skb_ptr(skb, *offset);
72
73 if (!proto_is_vlan(h_proto))
74 break;
75
76 if (skb_check(skb, vlh + 1))
77 return -1;
78
79 h_proto = vlh->h_vlan_encapsulated_proto;
80 *offset += sizeof(*vlh);
81 }
82
83 return h_proto;
84 }
85
86 static __always_inline int
87 parse_ipv4(struct __sk_buff *skb, __u32 *offset)
88 {
89 struct iphdr *iph;
90 int hdr_len;
91
92 iph = skb_ptr(skb, *offset);
93 if (skb_check(skb, iph + 1))
94 return -1;
95
96 hdr_len = iph->ihl * 4;
97 if (bpf_skb_pull_data(skb, *offset + hdr_len + sizeof(struct tcphdr) + 20))
98 return -1;
99
100 iph = skb_ptr(skb, *offset);
101 *offset += hdr_len;
102
103 if (skb_check(skb, (void *)(iph + 1)))
104 return -1;
105
106 return READ_ONCE(iph->protocol);
107 }
108
109 static __always_inline bool
110 parse_ipv6(struct __sk_buff *skb, __u32 *offset)
111 {
112 struct ipv6hdr *iph;
113
114 if (bpf_skb_pull_data(skb, *offset + sizeof(*iph) + sizeof(struct tcphdr) + 20))
115 return -1;
116
117 iph = skb_ptr(skb, *offset);
118 *offset += sizeof(*iph);
119
120 if (skb_check(skb, (void *)(iph + 1)))
121 return -1;
122
123 return READ_ONCE(iph->nexthdr);
124 }
125
126 static inline unsigned int
127 optlen(const u_int8_t *opt)
128 {
129 if (opt[0] <= TCPOPT_NOP || opt[1] == 0)
130 return 1;
131
132 return opt[1];
133 }
134
135 static __always_inline void
136 fixup_tcp(struct __sk_buff *skb, __u32 offset, __u16 mss)
137 {
138 struct tcphdr *tcph;
139 __u16 oldmss;
140 __u8 *opt;
141 u8 flags;
142 int hdrlen;
143 int i;
144
145 tcph = skb_ptr(skb, offset);
146 if (skb_check(skb, tcph + 1))
147 return;
148
149 flags = tcp_flag_byte(tcph);
150 if (!(flags & TCPHDR_SYN))
151 return;
152
153 hdrlen = tcph->doff * 4;
154 if (hdrlen <= sizeof(struct tcphdr))
155 return;
156
157 hdrlen += offset;
158 offset += sizeof(*tcph);
159
160 #pragma unroll
161 for (i = 0; i < 5; i++) {
162 unsigned int len;
163
164 if (offset >= hdrlen)
165 return;
166
167 opt = skb_ptr(skb, offset);
168 if (skb_check(skb, opt + TCPOLEN_MSS))
169 return;
170
171 len = optlen(opt);
172 offset += len;
173 if (opt[0] != TCPOPT_MSS || opt[1] != TCPOLEN_MSS)
174 continue;
175
176 goto found;
177 }
178 return;
179
180 found:
181 oldmss = (opt[2] << 8) | opt[3];
182 if (oldmss <= mss)
183 return;
184
185 opt[2] = mss >> 8;
186 opt[3] = mss & 0xff;
187 csum_replace2(&tcph->check, bpf_htons(oldmss), bpf_htons(mss));
188 }
189
190 SEC("tc")
191 int mssfix(struct __sk_buff *skb)
192 {
193 __u32 offset = 0;
194 __u8 ipproto;
195 __u16 mss;
196 int type;
197
198 type = parse_ethernet(skb, &offset);
199 if (type == bpf_htons(ETH_P_IP))
200 type = parse_ipv4(skb, &offset);
201 else if (type == bpf_htons(ETH_P_IPV6))
202 type = parse_ipv6(skb, &offset);
203 else
204 return TC_ACT_UNSPEC;
205
206 if (type != IPPROTO_TCP)
207 return TC_ACT_UNSPEC;
208
209 mss = mtu;
210 mss -= offset + sizeof(struct tcphdr);
211 fixup_tcp(skb, offset, mss);
212
213 return TC_ACT_UNSPEC;
214 }