diff --git a/main.c b/main.c index 7b44b6b..f81801f 100644 --- a/main.c +++ b/main.c @@ -22,6 +22,7 @@ #endif #define IPV6_FLOWINFO_MASK bpf_htonl(0x0FFFFFFF) +#define VLAN_MAX_DEPTH 2 /* Support double-tagged VLANs */ /* Forwarding ports */ struct { @@ -35,14 +36,19 @@ struct { struct flow_key { __u8 proto; __u8 pad[3]; /* alignment */ + __u16 vlan_id; /* VLAN ID (0 if untagged) */ + __u16 pad2; /* alignment */ + union { __u32 ipv4_src; __u8 ipv6_src[16]; }; + union { __u32 ipv4_dst; __u8 ipv6_dst[16]; }; + __u16 sport; __u16 dport; }; @@ -88,6 +94,36 @@ static __always_inline void record_stats(struct xdp_md *ctx, struct flow_key *ke } } +/* Parse VLAN headers and return next protocol and offset */ +static __always_inline int parse_vlan(void *data, void *data_end, __u64 *nh_off, __u16 *h_proto, __u16 *vlan_id) +{ + struct vlan_hdr { + __be16 h_vlan_TCI; + __be16 h_vlan_encapsulated_proto; + } *vhdr; + int i; + + /* Parse up to VLAN_MAX_DEPTH VLAN headers */ + #pragma unroll + for (i = 0; i < VLAN_MAX_DEPTH; i++) { + if (*h_proto != bpf_htons(ETH_P_8021Q) && *h_proto != bpf_htons(ETH_P_8021AD)) + break; + + vhdr = data + *nh_off; + if ((void *)(vhdr + 1) > data_end) + return -1; + + /* Store the outermost VLAN ID */ + if (i == 0) + *vlan_id = bpf_ntohs(vhdr->h_vlan_TCI) & 0x0FFF; + + *nh_off += sizeof(*vhdr); + *h_proto = vhdr->h_vlan_encapsulated_proto; + } + + return 0; +} + static __always_inline int xdp_l3fwd_flags(struct xdp_md *ctx, __u32 flags) { void *data_end = (void *)(long)ctx->data_end; @@ -99,6 +135,7 @@ static __always_inline int xdp_l3fwd_flags(struct xdp_md *ctx, __u32 flags) __u16 h_proto; __u64 nh_off; int rc; + __u16 vlan_id = 0; nh_off = sizeof(*eth); if (data + nh_off > data_end) @@ -107,7 +144,12 @@ static __always_inline int xdp_l3fwd_flags(struct xdp_md *ctx, __u32 flags) __builtin_memset(&fib_params, 0, sizeof(fib_params)); h_proto = eth->h_proto; + /* Parse VLAN headers if present */ + if (parse_vlan(data, data_end, &nh_off, &h_proto, &vlan_id) < 0) + return XDP_DROP; + struct flow_key key = {}; + key.vlan_id = vlan_id; __u64 bytes = data_end - data; if (h_proto == bpf_htons(ETH_P_IP)) { @@ -214,4 +256,4 @@ int xdp_l3fwd_direct_prog(struct xdp_md *ctx) return xdp_l3fwd_flags(ctx, BPF_FIB_LOOKUP_DIRECT); } -char _license[] SEC("license") = "GPL"; \ No newline at end of file +char _license[] SEC("license") = "GPL";