diff --git a/main.c b/main.c
index 66c38ce..1a17f08 100644
--- a/main.c
+++ b/main.c
@@ -30,12 +30,32 @@ struct vlan_hdr {
 	__be16 h_vlan_encapsulated_proto;
 };
 
+/* Auto-learned VLAN info */
+struct vlan_learning_entry {
+	__u16 vlan_id;
+	__u16 confidence;
+	__u32 last_seen;
+};
+
 struct {
-	__uint(type, BPF_MAP_TYPE_DEVMAP);
+	__uint(type, BPF_MAP_TYPE_HASH);
 	__type(key, __u32);
-	__type(value, __u32);
+	__type(value, struct vlan_learning_entry);
 	__uint(max_entries, 512);
-} xdp_l3fwd_ports SEC(".maps");
+} xdp_vlan_learning SEC(".maps");
+
+struct vlan_parent_info { /* static per-ifindex VLAN config - presumably filled by the control plane; verify */
+	__u32 parent_ifindex;
+	__u16 vlan_id;
+	__u16 pad;
+};
+
+struct {
+	__uint(type, BPF_MAP_TYPE_HASH);
+	__type(key, __u32);
+	__type(value, struct vlan_parent_info);
+	__uint(max_entries, 512);
+} xdp_vlan_parents SEC(".maps");
 
 struct flow_key {
 	__u8 proto;
@@ -91,6 +111,60 @@ static __always_inline void record_stats(struct flow_key *key, __u64 bytes)
 }
 
+/* Hysteresis-based VLAN learner: repeated sightings of the same id raise
+ * confidence; a differing id erodes it and eventually replaces the entry. */
+static __always_inline void learn_vlan(struct xdp_md *ctx, __u16 vlan_id)
+{
+	__u32 ifindex = ctx->ingress_ifindex;
+	struct vlan_learning_entry *entry = bpf_map_lookup_elem(&xdp_vlan_learning, &ifindex);
+
+	if (entry) {
+		if (vlan_id > 0) {
+			if (entry->vlan_id == vlan_id) {
+				if (entry->confidence < 65535)
+					entry->confidence++;
+			} else { /* fix: an entry already at confidence 0 could previously never re-learn */
+				if (entry->confidence > 0) entry->confidence--;
+				if (entry->confidence == 0) {
+					entry->vlan_id = vlan_id;
+					entry->confidence = 1;
+				}
+			}
+		}
+	} else if (vlan_id > 0) {
+		struct vlan_learning_entry new_entry = {
+			.vlan_id = vlan_id,
+			.confidence = 1,
+			.last_seen = 0, /* NOTE(review): never updated anywhere - dead field, or wire up bpf_ktime_get_ns() */
+		};
+		bpf_map_update_elem(&xdp_vlan_learning, &ifindex, &new_entry, BPF_ANY);
+	}
+}
+
+/* Resolve the VLAN id to use on egress ifindex: explicit config wins,
+ * then sufficiently-confident learned state, else untagged (0). */
+static __always_inline __u16 get_interface_vlan(struct xdp_md *ctx, __u32 ifindex)
+{
+	struct vlan_parent_info *parent_info = bpf_map_lookup_elem(&xdp_vlan_parents, &ifindex);
+	if (parent_info && parent_info->vlan_id > 0) {
+		return parent_info->vlan_id;
+	}
+
+	struct vlan_learning_entry *learned = bpf_map_lookup_elem(&xdp_vlan_learning, &ifindex);
+	if (learned && learned->confidence > 5) {
+		return learned->vlan_id;
+	}
+
+	__u32 ingress_idx = ctx->ingress_ifindex; /* NOTE(review): the block below is dead code - every path returns 0, same as the fall-through; delete it */
+	if (ingress_idx != ifindex) {
+		struct vlan_learning_entry *ingress_learned = bpf_map_lookup_elem(&xdp_vlan_learning, &ingress_idx);
+		if (ingress_learned && ingress_learned->confidence > 10) {
+			struct vlan_learning_entry *egress_learned = bpf_map_lookup_elem(&xdp_vlan_learning, &ifindex);
+			if (!egress_learned || egress_learned->confidence < 3) {
+				return 0;
+			}
+		}
+	}
+
+	return 0;
+}
+
 static __always_inline int parse_vlan(void *data, void *data_end, __u64 *nh_off,
 				      __u16 *h_proto, __u16 *vlan_id)
 {
 	struct vlan_hdr *vh;
@@ -144,6 +218,84 @@ static __always_inline int skip_ip6hdrext(void *data, void *data_end, __u64 *nh_off,
 	return -1;
 }
 
+/* Insert VLAN tag using head adjustment */
+static __always_inline int insert_vlan_tag(struct xdp_md *ctx, __u16 vlan_id)
+{
+	void *data_end = (void *)(long)ctx->data_end;
+	void *data = (void *)(long)ctx->data;
+
+	struct ethhdr *old_eth = data;
+
+	if ((void *)(old_eth + 1) > data_end)
+		return -1;
+
+	struct ethhdr orig_eth;
+	__builtin_memcpy(&orig_eth, old_eth, sizeof(orig_eth));
+
+	/* Expand headroom */
+	if (bpf_xdp_adjust_head(ctx, -(int)sizeof(struct vlan_hdr)))
+		return -1;
+
+	/* Re-read pointers after head adjustment */
+	data = (void *)(long)ctx->data;
+	data_end = (void *)(long)ctx->data_end;
+
+	struct ethhdr *new_eth = data;
+	struct vlan_hdr *vlan = (struct vlan_hdr *)(new_eth + 1);
+
+	if ((void *)(vlan + 1) > data_end)
+		return -1;
+
+	/* Copy ethernet header to new position */
+	__builtin_memcpy(new_eth->h_dest, orig_eth.h_dest, ETH_ALEN);
+	__builtin_memcpy(new_eth->h_source, orig_eth.h_source, ETH_ALEN);
+
+	/* Set up VLAN header (PCP/DEI bits deliberately left 0) */
+	vlan->h_vlan_TCI = bpf_htons(vlan_id & 0x0FFF);
+	vlan->h_vlan_encapsulated_proto = orig_eth.h_proto;
+
+	/* Update ethernet proto to VLAN */
+	new_eth->h_proto = bpf_htons(ETH_P_8021Q);
+
+	return 0;
+}
+
+/* Remove VLAN tag */
+static __always_inline int remove_vlan_tag(struct xdp_md *ctx)
+{
+	void *data_end = (void *)(long)ctx->data_end;
+	void *data = (void *)(long)ctx->data;
+
+	struct ethhdr *eth = data;
+	struct vlan_hdr *vlan = (struct vlan_hdr *)(eth + 1);
+
+	if ((void *)(vlan + 1) > data_end)
+		return -1;
+
+	__be16 encap_proto = vlan->h_vlan_encapsulated_proto;
+
+	struct ethhdr tmp_eth;
+	__builtin_memcpy(&tmp_eth, eth, sizeof(tmp_eth));
+
+	/* Adjust head to remove VLAN header */
+	if (bpf_xdp_adjust_head(ctx, (int)sizeof(struct vlan_hdr)))
+		return -1;
+
+	/* Re-read pointers after head adjustment.
+	 * Failing below leaves the packet already shortened - caller must drop. */
+	data = (void *)(long)ctx->data;
+	data_end = (void *)(long)ctx->data_end;
+	eth = data;
+
+	if ((void *)(eth + 1) > data_end)
+		return -1;
+
+	__builtin_memcpy(eth->h_dest, tmp_eth.h_dest, ETH_ALEN);
+	__builtin_memcpy(eth->h_source, tmp_eth.h_source, ETH_ALEN);
+	eth->h_proto = encap_proto;
+
+	return 0;
+}
+
 static __always_inline int xdp_l3fwd_flags(struct xdp_md *ctx, __u32 flags)
 {
 	void *data_end = (void *)(long)ctx->data_end;
@@ -157,9 +309,19 @@ static __always_inline int xdp_l3fwd_flags(struct xdp_md *ctx, __u32 flags)
 	struct bpf_fib_lookup fib_params = {};
 	__u16 h_proto = eth->h_proto;
 	__u16 vlan_id = 0;
+	__u16 orig_vlan_id = 0;
+	int had_vlan = 0;
+
+	if (h_proto == bpf_htons(ETH_P_8021Q) || h_proto == bpf_htons(ETH_P_8021AD))
+		had_vlan = 1;
 
 	if (parse_vlan(data, data_end, &nh_off, &h_proto, &vlan_id) < 0)
 		return XDP_DROP;
+
+	orig_vlan_id = vlan_id;
+
+	if (vlan_id > 0)
+		learn_vlan(ctx, vlan_id);
 
 	struct flow_key key = {};
 	key.vlan_id = vlan_id;
@@ -183,7 +345,6 @@ static __always_inline int xdp_l3fwd_flags(struct xdp_md *ctx, __u32 flags)
 
 		__u64 l4_off = nh_off + (ihl * 4);
 
-		/* Parse L4 ports - check exactly 4 bytes (sport + dport) */
 		void *l4_hdr = (void *)((char *)data + l4_off);
 		if ((void *)((char *)l4_hdr + 4) <= data_end) {
 			if (iph->protocol == IPPROTO_TCP || iph->protocol == IPPROTO_UDP) {
@@ -220,7 +381,6 @@ static __always_inline int xdp_l3fwd_flags(struct xdp_md *ctx, __u32 flags)
 
 		key.proto = l4_proto;
 
-		/* Parse L4 ports - check exactly 4 bytes */
 		void *l4_hdr = (void *)((char *)data + l4_off);
 		if ((void *)((char *)l4_hdr + 4) <= data_end) {
 			if (l4_proto == IPPROTO_TCP || l4_proto == IPPROTO_UDP) {
@@ -249,18 +409,80 @@ static __always_inline int xdp_l3fwd_flags(struct xdp_md *ctx, __u32 flags)
 	if (rc == 0) {
 		record_stats(&key, bytes);
+
+		__u16 egress_vlan = get_interface_vlan(ctx, fib_params.ifindex);
+
+		if (egress_vlan > 0 && !had_vlan) {
+			/* Need to add VLAN tag */
+			if (insert_vlan_tag(ctx, egress_vlan) < 0)
+				return XDP_DROP;
+
+		} else if (egress_vlan == 0 && had_vlan) {
+			/* Need to remove VLAN tag */
+			if (remove_vlan_tag(ctx) < 0)
+				/* fix: head may already be adjusted - packet is mangled, never forward it */
+				return XDP_DROP;
+
+		} else if (egress_vlan > 0 && had_vlan && egress_vlan != orig_vlan_id) {
+			/* Need to change VLAN ID - reload pointers first */
+			data = (void *)(long)ctx->data;
+			data_end = (void *)(long)ctx->data_end;
+			eth = data;
+
+			if ((void *)(eth + 1) > data_end)
+				return XDP_DROP;
+
+			if (eth->h_proto == bpf_htons(ETH_P_8021Q) ||
+			    eth->h_proto == bpf_htons(ETH_P_8021AD)) {
+				struct vlan_hdr *vlan = (struct vlan_hdr *)(eth + 1);
+				if ((void *)(vlan + 1) > data_end)
+					return XDP_DROP;
+
+				vlan->h_vlan_TCI = bpf_htons(egress_vlan & 0x0FFF);
+			}
+		}
+
+		/* CRITICAL: Always reload pointers after FIB lookup to satisfy verifier */
+		data = (void *)(long)ctx->data;
+		data_end = (void *)(long)ctx->data_end;
+		eth = data;
+
+		/* Re-establish packet bounds for verifier */
+		if ((void *)(eth + 1) > data_end)
+			return XDP_DROP;
+
+		nh_off = sizeof(*eth);
+
+		/* Skip VLAN header if present */
+		if (eth->h_proto == bpf_htons(ETH_P_8021Q) ||
+		    eth->h_proto == bpf_htons(ETH_P_8021AD)) {
+			nh_off += sizeof(struct vlan_hdr);
+		}
+
+		/* Verify nh_off is within bounds */
+		if ((void *)((char *)data + nh_off) > data_end)
+			return XDP_DROP;
+
+		/* Decrease TTL/hop_limit */
 		if (h_proto == bpf_htons(ETH_P_IP)) {
 			struct iphdr *iph = (void *)((char *)data + nh_off);
+			if ((void *)(iph + 1) > data_end)
+				return XDP_DROP;
 			ip_decrease_ttl(iph);
 		} else if (h_proto == bpf_htons(ETH_P_IPV6)) {
 			struct ipv6hdr *ip6h = (void *)((char *)data + nh_off);
+			if ((void *)(ip6h + 1) > data_end)
+				return XDP_DROP;
 			ip6h->hop_limit--;
 		}
 
+		/* Update MAC addresses - verify eth is still valid */
+		if ((void *)(eth + 1) > data_end)
+			return XDP_DROP;
+
 		__builtin_memcpy(eth->h_dest, fib_params.dmac, ETH_ALEN);
 		__builtin_memcpy(eth->h_source, fib_params.smac, ETH_ALEN);
 
-		return bpf_redirect_map(&xdp_l3fwd_ports, fib_params.ifindex, XDP_PASS);
+		return bpf_redirect(fib_params.ifindex, 0); /* NOTE(review): devmap's XDP_PASS fallback for unknown ifindex is gone - confirm intended */
 	}
 
 	return XDP_PASS;