kernel_optimize_test/net/ipv6/seg6_local.c
Ahmed Abdelsalam bb986a5042 seg6: fix seg6_validate_srh() to avoid slab-out-of-bounds
The seg6_validate_srh() is used to validate SRH for three cases:

Case1: SRH of data-plane SRv6 packets to be processed by the Linux kernel.
Case2: SRH of the netlink message received from user-space (iproute2).
Case3: SRH injected into packets through setsockopt.

In case1, the SRH can be encoded in the Reduced way (i.e., first SID is
carried in DA only and not represented as SID in the SRH) and the
seg6_validate_srh() now handles this case correctly.

In case2 and case3, the SRH shouldn’t be encoded in the Reduced way
otherwise we lose the first segment (i.e., the first hop).

The current implementation of seg6_validate_srh() allows the SRH of case2
and case3 to be encoded in the Reduced way. This leads to a slab-out-of-bounds
problem.

This patch verifies the SRH of case1, case2 and case3, allowing case1 to be
reduced while preventing the SRH of case2 and case3 from being reduced.

Reported-by: syzbot+e8c028b62439eac42073@syzkaller.appspotmail.com
Reported-by: YueHaibing <yuehaibing@huawei.com>
Fixes: 0cb7498f23 ("seg6: fix SRH processing to comply with RFC8754")
Signed-off-by: Ahmed Abdelsalam <ahabdels@gmail.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
2020-06-04 15:39:32 -07:00

1139 lines
24 KiB
C

// SPDX-License-Identifier: GPL-2.0-or-later
/*
* SR-IPv6 implementation
*
* Authors:
* David Lebrun <david.lebrun@uclouvain.be>
* eBPF support: Mathieu Xhonneux <m.xhonneux@gmail.com>
*/
#include <linux/types.h>
#include <linux/skbuff.h>
#include <linux/net.h>
#include <linux/module.h>
#include <net/ip.h>
#include <net/lwtunnel.h>
#include <net/netevent.h>
#include <net/netns/generic.h>
#include <net/ip6_fib.h>
#include <net/route.h>
#include <net/seg6.h>
#include <linux/seg6.h>
#include <linux/seg6_local.h>
#include <net/addrconf.h>
#include <net/ip6_route.h>
#include <net/dst_cache.h>
#include <net/ip_tunnels.h>
#ifdef CONFIG_IPV6_SEG6_HMAC
#include <net/seg6_hmac.h>
#endif
#include <net/seg6_local.h>
#include <linux/etherdevice.h>
#include <linux/bpf.h>
/* Forward declaration: the handler signature below takes a pointer to it. */
struct seg6_local_lwt;
/* Describes one seg6local action: its identifier, the netlink attributes
 * it requires, and the function that processes matching packets.
 */
struct seg6_action_desc {
/* SEG6_LOCAL_ACTION_* identifier matched by __get_action_desc() */
int action;
/* bitmask of required attributes, one bit per SEG6_LOCAL_* attribute index */
unsigned long attrs;
/* packet handler invoked from seg6_local_input() */
int (*input)(struct sk_buff *skb, struct seg6_local_lwt *slwt);
/* fixed headroom added to slwt->headroom by parse_nla_action() */
int static_headroom;
};
/* BPF program attached to an End.BPF action, plus the name supplied with it. */
struct bpf_lwt_prog {
/* program reference obtained through bpf_prog_get_type() in parse_nla_bpf() */
struct bpf_prog *prog;
/* heap copy of the SEG6_LOCAL_BPF_PROG_NAME attribute */
char *name;
};
/* Per-route state of a seg6local lightweight tunnel, stored in the data
 * area of the lwtunnel_state (see seg6_local_lwtunnel()).
 */
struct seg6_local_lwt {
/* value of the SEG6_LOCAL_ACTION attribute */
int action;
/* heap copy of the SEG6_LOCAL_SRH attribute, set by parse_nla_srh() */
struct ipv6_sr_hdr *srh;
/* routing table id from SEG6_LOCAL_TABLE */
int table;
/* IPv4 nexthop from SEG6_LOCAL_NH4 */
struct in_addr nh4;
/* IPv6 nexthop from SEG6_LOCAL_NH6 */
struct in6_addr nh6;
/* value of SEG6_LOCAL_IIF */
int iif;
/* value of SEG6_LOCAL_OIF; End.DX2 uses it as the egress ifindex */
int oif;
/* BPF program and name used by End.BPF */
struct bpf_lwt_prog bpf;
/* accumulated headroom, copied to lwtunnel_state->headroom at build time */
int headroom;
/* matching entry of seg6_action_table */
struct seg6_action_desc *desc;
};
/* Return the seg6local private state kept in the data area of @lwt. */
static struct seg6_local_lwt *seg6_local_lwtunnel(struct lwtunnel_state *lwt)
{
	void *priv = lwt->data;

	return priv;
}
/* Locate the routing header in @skb, make sure the whole header passes
 * pskb_may_pull(), and validate it with seg6_validate_srh().
 *
 * Returns a pointer to the SRH inside the packet, or NULL when no routing
 * header is found, the pull checks fail, or validation fails.
 */
static struct ipv6_sr_hdr *get_srh(struct sk_buff *skb)
{
	struct ipv6_sr_hdr *hdr;
	int offset = 0;
	int total;

	if (ipv6_find_hdr(skb, &offset, IPPROTO_ROUTING, NULL, NULL) < 0)
		return NULL;

	/* the fixed part must be readable before hdrlen can be used */
	if (!pskb_may_pull(skb, offset + sizeof(*hdr)))
		return NULL;

	hdr = (struct ipv6_sr_hdr *)(skb->data + offset);
	total = (hdr->hdrlen + 1) << 3;

	if (!pskb_may_pull(skb, offset + total))
		return NULL;

	/* pskb_may_pull() may change the header pointers, so the SRH
	 * pointer is recomputed from skb->data before it is used again.
	 */
	hdr = (struct ipv6_sr_hdr *)(skb->data + offset);

	return seg6_validate_srh(hdr, total, true) ? hdr : NULL;
}
/* Fetch the SRH of @skb for actions that need a next segment.
 *
 * Returns NULL when there is no valid SRH, when segments_left is zero,
 * or (with CONFIG_IPV6_SEG6_HMAC) when HMAC validation fails.
 */
static struct ipv6_sr_hdr *get_and_validate_srh(struct sk_buff *skb)
{
	struct ipv6_sr_hdr *hdr = get_srh(skb);

	if (!hdr || hdr->segments_left == 0)
		return NULL;

#ifdef CONFIG_IPV6_SEG6_HMAC
	if (!seg6_hmac_validate_skb(skb))
		return NULL;
#endif

	return hdr;
}
/* Strip the outer headers of @skb up to the inner header of type @proto.
 *
 * An SRH is optional, but if one is present it must have segments_left == 0
 * and (with CONFIG_IPV6_SEG6_HMAC) pass HMAC validation.
 *
 * Returns true once the outer headers are pulled and the network and
 * transport header offsets are reset to the inner packet; false otherwise.
 */
static bool decap_and_validate(struct sk_buff *skb, int proto)
{
	struct ipv6_sr_hdr *hdr = get_srh(skb);
	unsigned int inner_off = 0;

	if (hdr) {
		if (hdr->segments_left > 0)
			return false;
#ifdef CONFIG_IPV6_SEG6_HMAC
		if (!seg6_hmac_validate_skb(skb))
			return false;
#endif
	}

	if (ipv6_find_hdr(skb, &inner_off, proto, NULL, NULL) < 0)
		return false;

	if (!pskb_pull(skb, inner_off))
		return false;

	skb_postpull_rcsum(skb, skb_network_header(skb), inner_off);

	skb_reset_network_header(skb);
	skb_reset_transport_header(skb);

	return !iptunnel_pull_offloads(skb);
}
/* Decrement segments_left in @srh and copy the segment it now indexes
 * into @daddr.
 */
static void advance_nextseg(struct ipv6_sr_hdr *srh, struct in6_addr *daddr)
{
	srh->segments_left--;
	*daddr = srh->segments[srh->segments_left];
}
/* Perform an IPv6 route lookup for @skb and attach the result as its dst.
 *
 * @skb:            packet whose IPv6 header supplies the flow key
 * @nhaddr:         if non-NULL, used as lookup destination instead of the
 *                  packet's daddr (FLOWI_FLAG_KNOWN_NH is set in that case)
 * @tbl_id:         routing table to use; 0 selects the regular input lookup
 * @local_delivery: when false, a result whose device has IFF_LOOPBACK set
 *                  (and no dst error) is discarded
 *
 * If the table does not exist or the result is discarded, the namespace's
 * ip6_blk_hole_entry is attached instead. Returns the error field of the
 * dst that ends up attached to @skb.
 */
static int
seg6_lookup_any_nexthop(struct sk_buff *skb, struct in6_addr *nhaddr,
			u32 tbl_id, bool local_delivery)
{
	struct net *net = dev_net(skb->dev);
	struct ipv6hdr *hdr = ipv6_hdr(skb);
	int flags = RT6_LOOKUP_F_HAS_SADDR;
	struct dst_entry *dst = NULL;
	struct rt6_info *rt;
	struct flowi6 fl6;
	int dev_flags = 0;

	/* Zero the whole flow key first: only some members are assigned
	 * below (flowi6_flags only when @nhaddr is given), and the rest
	 * must not reach the lookup as uninitialised stack contents.
	 */
	memset(&fl6, 0, sizeof(fl6));

	fl6.flowi6_iif = skb->dev->ifindex;
	fl6.daddr = nhaddr ? *nhaddr : hdr->daddr;
	fl6.saddr = hdr->saddr;
	fl6.flowlabel = ip6_flowinfo(hdr);
	fl6.flowi6_mark = skb->mark;
	fl6.flowi6_proto = hdr->nexthdr;

	if (nhaddr)
		fl6.flowi6_flags = FLOWI_FLAG_KNOWN_NH;

	if (!tbl_id) {
		dst = ip6_route_input_lookup(net, skb->dev, &fl6, skb, flags);
	} else {
		struct fib6_table *table;

		table = fib6_get_table(net, tbl_id);
		if (!table)
			goto out;

		rt = ip6_pol_route(net, table, 0, &fl6, skb, flags);
		dst = &rt->dst;
	}

	/* we want to discard traffic destined for local packet processing,
	 * if @local_delivery is set to false.
	 */
	if (!local_delivery)
		dev_flags |= IFF_LOOPBACK;

	if (dst && (dst->dev->flags & dev_flags) && !dst->error) {
		dst_release(dst);
		dst = NULL;
	}

out:
	/* no usable route: fall back to the blackhole entry */
	if (!dst) {
		rt = net->ipv6.ip6_blk_hole_entry;
		dst = &rt->dst;
		dst_hold(dst);
	}

	skb_dst_drop(skb);
	skb_dst_set(skb, dst);
	return dst->error;
}
/* Route lookup for @skb with local delivery disabled; see
 * seg6_lookup_any_nexthop() for the meaning of @nhaddr and @tbl_id.
 */
int seg6_lookup_nexthop(struct sk_buff *skb,
			struct in6_addr *nhaddr, u32 tbl_id)
{
	const bool allow_local = false;

	return seg6_lookup_any_nexthop(skb, nhaddr, tbl_id, allow_local);
}
/* End: regular endpoint. Advance to the next segment and route the packet
 * on its new destination address. Packets without a usable SRH are freed
 * and -EINVAL is returned.
 */
static int input_action_end(struct sk_buff *skb, struct seg6_local_lwt *slwt)
{
	struct ipv6_sr_hdr *hdr = get_and_validate_srh(skb);

	if (!hdr) {
		kfree_skb(skb);
		return -EINVAL;
	}

	advance_nextseg(hdr, &ipv6_hdr(skb)->daddr);
	seg6_lookup_nexthop(skb, NULL, 0);

	return dst_input(skb);
}
/* End.X: regular endpoint that forwards to the configured IPv6 nexthop
 * (slwt->nh6). Packets without a usable SRH are freed and -EINVAL is
 * returned.
 */
static int input_action_end_x(struct sk_buff *skb, struct seg6_local_lwt *slwt)
{
	struct ipv6_sr_hdr *hdr = get_and_validate_srh(skb);

	if (!hdr) {
		kfree_skb(skb);
		return -EINVAL;
	}

	advance_nextseg(hdr, &ipv6_hdr(skb)->daddr);
	seg6_lookup_nexthop(skb, &slwt->nh6, 0);

	return dst_input(skb);
}
/* End.T: regular endpoint whose route lookup is done in the configured
 * table (slwt->table). Packets without a usable SRH are freed and -EINVAL
 * is returned.
 */
static int input_action_end_t(struct sk_buff *skb, struct seg6_local_lwt *slwt)
{
	struct ipv6_sr_hdr *hdr = get_and_validate_srh(skb);

	if (!hdr) {
		kfree_skb(skb);
		return -EINVAL;
	}

	advance_nextseg(hdr, &ipv6_hdr(skb)->daddr);
	seg6_lookup_nexthop(skb, NULL, slwt->table);

	return dst_input(skb);
}
/* End.DX2: decapsulate and transmit the inner L2 frame on the interface
 * whose index is slwt->oif. On any failure the packet is freed and
 * -EINVAL is returned.
 */
static int input_action_end_dx2(struct sk_buff *skb,
				struct seg6_local_lwt *slwt)
{
	struct net *net = dev_net(skb->dev);
	struct net_device *out_dev;
	struct ethhdr *inner_eth;

	if (!decap_and_validate(skb, IPPROTO_ETHERNET) ||
	    !pskb_may_pull(skb, ETH_HLEN))
		goto drop;

	skb_reset_mac_header(skb);
	inner_eth = (struct ethhdr *)skb->data;

	/* To determine the frame's protocol, we assume it is 802.3. This avoids
	 * a call to eth_type_trans(), which is not really relevant for our
	 * use case.
	 */
	if (!eth_proto_is_802_3(inner_eth->h_proto))
		goto drop;

	out_dev = dev_get_by_index_rcu(net, slwt->oif);

	/* As we accept Ethernet frames, make sure the egress device exists
	 * and is of the correct type.
	 */
	if (!out_dev || out_dev->type != ARPHRD_ETHER)
		goto drop;

	if (!(out_dev->flags & IFF_UP) || !netif_carrier_ok(out_dev))
		goto drop;

	skb_orphan(skb);

	if (skb_warn_if_lro(skb))
		goto drop;

	skb_forward_csum(skb);

	/* the frame payload must fit the egress device MTU */
	if (skb->len - ETH_HLEN > out_dev->mtu)
		goto drop;

	skb->dev = out_dev;
	skb->protocol = inner_eth->h_proto;

	return dev_queue_xmit(skb);

drop:
	kfree_skb(skb);
	return -EINVAL;
}
/* End.DX6: decapsulate the inner IPv6 packet and forward it to the
 * configured nexthop. On failure the packet is freed and -EINVAL is
 * returned.
 */
static int input_action_end_dx6(struct sk_buff *skb,
				struct seg6_local_lwt *slwt)
{
	struct in6_addr *nexthop;

	/* this function accepts IPv6 encapsulated packets, with either
	 * an SRH with SL=0, or no SRH.
	 */
	if (!decap_and_validate(skb, IPPROTO_IPV6) ||
	    !pskb_may_pull(skb, sizeof(struct ipv6hdr))) {
		kfree_skb(skb);
		return -EINVAL;
	}

	/* The inner packet is not associated to any local interface,
	 * so we do not call netif_rx().
	 *
	 * If slwt->nh6 is set to ::, then lookup the nexthop for the
	 * inner packet's DA. Otherwise, use the specified nexthop.
	 */
	nexthop = ipv6_addr_any(&slwt->nh6) ? NULL : &slwt->nh6;

	skb_set_transport_header(skb, sizeof(struct ipv6hdr));
	seg6_lookup_nexthop(skb, nexthop, 0);

	return dst_input(skb);
}
/* End.DX4: decapsulate the inner IPv4 packet and route it towards
 * slwt->nh4, or towards the inner destination address when nh4 is zero.
 * On failure the packet is freed and -EINVAL is returned.
 */
static int input_action_end_dx4(struct sk_buff *skb,
				struct seg6_local_lwt *slwt)
{
	struct iphdr *inner;
	__be32 nexthop;

	if (!decap_and_validate(skb, IPPROTO_IPIP) ||
	    !pskb_may_pull(skb, sizeof(struct iphdr)))
		goto drop;

	skb->protocol = htons(ETH_P_IP);
	inner = ip_hdr(skb);

	nexthop = slwt->nh4.s_addr;
	if (!nexthop)
		nexthop = inner->daddr;

	skb_dst_drop(skb);
	skb_set_transport_header(skb, sizeof(struct iphdr));

	if (ip_route_input(skb, nexthop, inner->saddr, 0, skb->dev))
		goto drop;

	return dst_input(skb);

drop:
	kfree_skb(skb);
	return -EINVAL;
}
/* End.DT6: decapsulate the inner IPv6 packet and look it up in the
 * configured table (slwt->table), with local delivery allowed. On failure
 * the packet is freed and -EINVAL is returned.
 */
static int input_action_end_dt6(struct sk_buff *skb,
				struct seg6_local_lwt *slwt)
{
	if (!decap_and_validate(skb, IPPROTO_IPV6) ||
	    !pskb_may_pull(skb, sizeof(struct ipv6hdr))) {
		kfree_skb(skb);
		return -EINVAL;
	}

	skb_set_transport_header(skb, sizeof(struct ipv6hdr));
	seg6_lookup_any_nexthop(skb, NULL, slwt->table, true);

	return dst_input(skb);
}
/* End.B6: push the configured SRH (slwt->srh) on top of the current one.
 * Returns -EINVAL when the packet has no usable SRH, or the error from
 * seg6_do_srh_inline(); the packet is freed in both cases.
 */
static int input_action_end_b6(struct sk_buff *skb, struct seg6_local_lwt *slwt)
{
	int err;

	if (!get_and_validate_srh(skb)) {
		err = -EINVAL;
		goto drop;
	}

	err = seg6_do_srh_inline(skb, slwt->srh);
	if (err)
		goto drop;

	/* the packet grew: refresh payload length and transport offset */
	ipv6_hdr(skb)->payload_len = htons(skb->len - sizeof(struct ipv6hdr));
	skb_set_transport_header(skb, sizeof(struct ipv6hdr));

	seg6_lookup_nexthop(skb, NULL, 0);

	return dst_input(skb);

drop:
	kfree_skb(skb);
	return err;
}
/* End.B6.Encaps: advance to the next segment, then encapsulate the packet
 * within an outer IPv6 header carrying the configured SRH (slwt->srh).
 * Returns -EINVAL when the packet has no usable SRH, or the error from
 * seg6_do_srh_encap(); the packet is freed in both cases.
 */
static int input_action_end_b6_encap(struct sk_buff *skb,
				     struct seg6_local_lwt *slwt)
{
	struct ipv6_sr_hdr *hdr = get_and_validate_srh(skb);
	int err;

	if (!hdr) {
		err = -EINVAL;
		goto drop;
	}

	advance_nextseg(hdr, &ipv6_hdr(skb)->daddr);

	skb_reset_inner_headers(skb);
	skb->encapsulation = 1;

	err = seg6_do_srh_encap(skb, slwt->srh, IPPROTO_IPV6);
	if (err)
		goto drop;

	/* ipv6_hdr() now refers to the new outer header */
	ipv6_hdr(skb)->payload_len = htons(skb->len - sizeof(struct ipv6hdr));
	skb_set_transport_header(skb, sizeof(struct ipv6hdr));

	seg6_lookup_nexthop(skb, NULL, 0);

	return dst_input(skb);

drop:
	kfree_skb(skb);
	return err;
}
/* Per-CPU SRH state set up by input_action_end_bpf() around the BPF run. */
DEFINE_PER_CPU(struct seg6_bpf_srh_state, seg6_bpf_srh_states);

/* Report whether this CPU's recorded SRH is currently valid.
 *
 * If the state is flagged invalid, the recorded hdrlen (in bytes, must be
 * a multiple of 8) is written back into the SRH and the header is
 * re-validated; on success the state is flagged valid again.
 * @skb is not used.
 */
bool seg6_bpf_has_valid_srh(struct sk_buff *skb)
{
	struct seg6_bpf_srh_state *state = this_cpu_ptr(&seg6_bpf_srh_states);
	struct ipv6_sr_hdr *hdr = state->srh;

	if (unlikely(!hdr))
		return false;

	if (likely(state->valid))
		return true;

	if (state->hdrlen & 7)
		return false;

	hdr->hdrlen = (u8)(state->hdrlen >> 3);
	if (!seg6_validate_srh(hdr, (hdr->hdrlen + 1) << 3, true))
		return false;

	state->valid = true;
	return true;
}
/* End.BPF: advance to the next segment, then run the attached BPF program
 * (slwt->bpf.prog) on the packet.
 *
 * BPF_OK and BPF_REDIRECT let the packet continue; with BPF_OK a route
 * lookup is done first. BPF_DROP, any other return value, or an SRH that
 * no longer validates after the program ran cause the packet to be freed
 * and -EINVAL to be returned.
 */
static int input_action_end_bpf(struct sk_buff *skb,
				struct seg6_local_lwt *slwt)
{
	struct seg6_bpf_srh_state *srh_state;
	struct ipv6_sr_hdr *srh;
	int ret;

	srh = get_and_validate_srh(skb);
	if (!srh) {
		kfree_skb(skb);
		return -EINVAL;
	}
	advance_nextseg(srh, &ipv6_hdr(skb)->daddr);

	/* preempt_disable is needed to protect the per-CPU buffer srh_state,
	 * which is also accessed by the bpf_lwt_seg6_* helpers. The per-CPU
	 * pointer is taken only once preemption is off, so that it is
	 * obtained inside the section that protects it.
	 */
	preempt_disable();
	srh_state = this_cpu_ptr(&seg6_bpf_srh_states);
	srh_state->srh = srh;
	srh_state->hdrlen = srh->hdrlen << 3;
	srh_state->valid = true;

	rcu_read_lock();
	bpf_compute_data_pointers(skb);
	ret = bpf_prog_run_save_cb(slwt->bpf.prog, skb);
	rcu_read_unlock();

	switch (ret) {
	case BPF_OK:
	case BPF_REDIRECT:
		break;
	case BPF_DROP:
		goto drop;
	default:
		pr_warn_once("bpf-seg6local: Illegal return value %u\n", ret);
		goto drop;
	}

	/* re-check the SRH recorded in the per-CPU state after the run */
	if (srh_state->srh && !seg6_bpf_has_valid_srh(skb))
		goto drop;

	preempt_enable();
	if (ret != BPF_REDIRECT)
		seg6_lookup_nexthop(skb, NULL, 0);

	return dst_input(skb);

drop:
	preempt_enable();
	kfree_skb(skb);
	return -EINVAL;
}
/* Table of supported seg6local actions. Each entry names the action, the
 * attribute bits that parse_nla_action() requires for it, and the handler
 * called from seg6_local_input().
 */
static struct seg6_action_desc seg6_action_table[] = {
{
.action = SEG6_LOCAL_ACTION_END,
.attrs = 0,
.input = input_action_end,
},
{
.action = SEG6_LOCAL_ACTION_END_X,
.attrs = (1 << SEG6_LOCAL_NH6),
.input = input_action_end_x,
},
{
.action = SEG6_LOCAL_ACTION_END_T,
.attrs = (1 << SEG6_LOCAL_TABLE),
.input = input_action_end_t,
},
{
.action = SEG6_LOCAL_ACTION_END_DX2,
.attrs = (1 << SEG6_LOCAL_OIF),
.input = input_action_end_dx2,
},
{
.action = SEG6_LOCAL_ACTION_END_DX6,
.attrs = (1 << SEG6_LOCAL_NH6),
.input = input_action_end_dx6,
},
{
.action = SEG6_LOCAL_ACTION_END_DX4,
.attrs = (1 << SEG6_LOCAL_NH4),
.input = input_action_end_dx4,
},
{
.action = SEG6_LOCAL_ACTION_END_DT6,
.attrs = (1 << SEG6_LOCAL_TABLE),
.input = input_action_end_dt6,
},
{
.action = SEG6_LOCAL_ACTION_END_B6,
.attrs = (1 << SEG6_LOCAL_SRH),
.input = input_action_end_b6,
},
{
.action = SEG6_LOCAL_ACTION_END_B6_ENCAP,
.attrs = (1 << SEG6_LOCAL_SRH),
.input = input_action_end_b6_encap,
/* this action prepends an outer IPv6 header */
.static_headroom = sizeof(struct ipv6hdr),
},
{
.action = SEG6_LOCAL_ACTION_END_BPF,
.attrs = (1 << SEG6_LOCAL_BPF),
.input = input_action_end_bpf,
},
};
static struct seg6_action_desc *__get_action_desc(int action)
{
struct seg6_action_desc *desc;
int i, count;
count = ARRAY_SIZE(seg6_action_table);
for (i = 0; i < count; i++) {
desc = &seg6_action_table[i];
if (desc->action == action)
return desc;
}
return NULL;
}
/* lwtunnel input hook: hand IPv6 packets to the handler of the action
 * configured on the route. Non-IPv6 packets are freed and -EINVAL is
 * returned.
 */
static int seg6_local_input(struct sk_buff *skb)
{
	struct dst_entry *orig_dst = skb_dst(skb);
	struct seg6_local_lwt *slwt;

	if (skb->protocol != htons(ETH_P_IPV6)) {
		kfree_skb(skb);
		return -EINVAL;
	}

	slwt = seg6_local_lwtunnel(orig_dst->lwtstate);

	return slwt->desc->input(skb, slwt);
}
/* Netlink validation policy for the SEG6_LOCAL_* attributes, applied by
 * seg6_local_build_state() when parsing the nested encap attribute.
 */
static const struct nla_policy seg6_local_policy[SEG6_LOCAL_MAX + 1] = {
[SEG6_LOCAL_ACTION] = { .type = NLA_U32 },
[SEG6_LOCAL_SRH] = { .type = NLA_BINARY },
[SEG6_LOCAL_TABLE] = { .type = NLA_U32 },
[SEG6_LOCAL_NH4] = { .type = NLA_BINARY,
.len = sizeof(struct in_addr) },
[SEG6_LOCAL_NH6] = { .type = NLA_BINARY,
.len = sizeof(struct in6_addr) },
[SEG6_LOCAL_IIF] = { .type = NLA_U32 },
[SEG6_LOCAL_OIF] = { .type = NLA_U32 },
[SEG6_LOCAL_BPF] = { .type = NLA_NESTED },
};
/* Parse SEG6_LOCAL_SRH: validate the supplied SRH and store a heap copy in
 * slwt->srh, adding its length to slwt->headroom.
 *
 * Returns 0, -EINVAL for a too short or invalid SRH, or -ENOMEM.
 */
static int parse_nla_srh(struct nlattr **attrs, struct seg6_local_lwt *slwt)
{
	struct nlattr *attr = attrs[SEG6_LOCAL_SRH];
	struct ipv6_sr_hdr *srh = nla_data(attr);
	int len = nla_len(attr);

	/* SRH must contain at least one segment */
	if (len < sizeof(*srh) + sizeof(struct in6_addr) ||
	    !seg6_validate_srh(srh, len, false))
		return -EINVAL;

	slwt->srh = kmemdup(srh, len, GFP_KERNEL);
	if (!slwt->srh)
		return -ENOMEM;

	slwt->headroom += len;

	return 0;
}
/* Dump slwt->srh as a SEG6_LOCAL_SRH attribute; -EMSGSIZE if @skb is full. */
static int put_nla_srh(struct sk_buff *skb, struct seg6_local_lwt *slwt)
{
	struct ipv6_sr_hdr *srh = slwt->srh;
	int len = (srh->hdrlen + 1) << 3;
	struct nlattr *attr;

	attr = nla_reserve(skb, SEG6_LOCAL_SRH, len);
	if (!attr)
		return -EMSGSIZE;

	memcpy(nla_data(attr), srh, len);

	return 0;
}
/* Compare the SRHs of @a and @b; returns 0 only if length and bytes match. */
static int cmp_nla_srh(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
{
	int len_a = (a->srh->hdrlen + 1) << 3;
	int len_b = (b->srh->hdrlen + 1) << 3;

	return len_a != len_b ? 1 : memcmp(a->srh, b->srh, len_a);
}
static int parse_nla_table(struct nlattr **attrs, struct seg6_local_lwt *slwt)
{
slwt->table = nla_get_u32(attrs[SEG6_LOCAL_TABLE]);
return 0;
}
/* Dump slwt->table as SEG6_LOCAL_TABLE; -EMSGSIZE if @skb is full. */
static int put_nla_table(struct sk_buff *skb, struct seg6_local_lwt *slwt)
{
	return nla_put_u32(skb, SEG6_LOCAL_TABLE, slwt->table) ? -EMSGSIZE : 0;
}
/* Returns 1 if the table ids of @a and @b differ, 0 otherwise. */
static int cmp_nla_table(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
{
	return a->table != b->table;
}
static int parse_nla_nh4(struct nlattr **attrs, struct seg6_local_lwt *slwt)
{
memcpy(&slwt->nh4, nla_data(attrs[SEG6_LOCAL_NH4]),
sizeof(struct in_addr));
return 0;
}
/* Dump slwt->nh4 as SEG6_LOCAL_NH4; -EMSGSIZE if @skb is full. */
static int put_nla_nh4(struct sk_buff *skb, struct seg6_local_lwt *slwt)
{
	struct nlattr *attr;

	attr = nla_reserve(skb, SEG6_LOCAL_NH4, sizeof(slwt->nh4));
	if (!attr)
		return -EMSGSIZE;

	memcpy(nla_data(attr), &slwt->nh4, sizeof(slwt->nh4));

	return 0;
}
static int cmp_nla_nh4(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
{
return memcmp(&a->nh4, &b->nh4, sizeof(struct in_addr));
}
static int parse_nla_nh6(struct nlattr **attrs, struct seg6_local_lwt *slwt)
{
memcpy(&slwt->nh6, nla_data(attrs[SEG6_LOCAL_NH6]),
sizeof(struct in6_addr));
return 0;
}
/* Dump slwt->nh6 as SEG6_LOCAL_NH6; -EMSGSIZE if @skb is full. */
static int put_nla_nh6(struct sk_buff *skb, struct seg6_local_lwt *slwt)
{
	struct nlattr *attr;

	attr = nla_reserve(skb, SEG6_LOCAL_NH6, sizeof(slwt->nh6));
	if (!attr)
		return -EMSGSIZE;

	memcpy(nla_data(attr), &slwt->nh6, sizeof(slwt->nh6));

	return 0;
}
static int cmp_nla_nh6(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
{
return memcmp(&a->nh6, &b->nh6, sizeof(struct in6_addr));
}
static int parse_nla_iif(struct nlattr **attrs, struct seg6_local_lwt *slwt)
{
slwt->iif = nla_get_u32(attrs[SEG6_LOCAL_IIF]);
return 0;
}
/* Dump slwt->iif as SEG6_LOCAL_IIF; -EMSGSIZE if @skb is full. */
static int put_nla_iif(struct sk_buff *skb, struct seg6_local_lwt *slwt)
{
	return nla_put_u32(skb, SEG6_LOCAL_IIF, slwt->iif) ? -EMSGSIZE : 0;
}
/* Returns 1 if the iif values of @a and @b differ, 0 otherwise. */
static int cmp_nla_iif(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
{
	return a->iif != b->iif;
}
static int parse_nla_oif(struct nlattr **attrs, struct seg6_local_lwt *slwt)
{
slwt->oif = nla_get_u32(attrs[SEG6_LOCAL_OIF]);
return 0;
}
/* Dump slwt->oif as SEG6_LOCAL_OIF; -EMSGSIZE if @skb is full. */
static int put_nla_oif(struct sk_buff *skb, struct seg6_local_lwt *slwt)
{
	return nla_put_u32(skb, SEG6_LOCAL_OIF, slwt->oif) ? -EMSGSIZE : 0;
}
/* Returns 1 if the oif values of @a and @b differ, 0 otherwise. */
static int cmp_nla_oif(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
{
	return a->oif != b->oif;
}
/* Length bound for the BPF program name attribute; also used when sizing
 * the dump in seg6_local_get_encap_size().
 */
#define MAX_PROG_NAME 256
/* Netlink policy for the attributes nested inside SEG6_LOCAL_BPF. */
static const struct nla_policy bpf_prog_policy[SEG6_LOCAL_BPF_PROG_MAX + 1] = {
[SEG6_LOCAL_BPF_PROG] = { .type = NLA_U32, },
[SEG6_LOCAL_BPF_PROG_NAME] = { .type = NLA_NUL_STRING,
.len = MAX_PROG_NAME },
};
/* Parse SEG6_LOCAL_BPF: duplicate the program name and take a reference on
 * the BPF program identified by the supplied file descriptor.
 *
 * Returns 0 on success, the nested-parse error, -EINVAL if either nested
 * attribute is missing, -ENOMEM, or the error from bpf_prog_get_type()
 * (the duplicated name is freed in that case).
 */
static int parse_nla_bpf(struct nlattr **attrs, struct seg6_local_lwt *slwt)
{
	struct nlattr *tb[SEG6_LOCAL_BPF_PROG_MAX + 1];
	struct bpf_prog *prog;
	u32 prog_fd;
	int err;

	err = nla_parse_nested_deprecated(tb, SEG6_LOCAL_BPF_PROG_MAX,
					  attrs[SEG6_LOCAL_BPF],
					  bpf_prog_policy, NULL);
	if (err < 0)
		return err;

	if (!tb[SEG6_LOCAL_BPF_PROG] || !tb[SEG6_LOCAL_BPF_PROG_NAME])
		return -EINVAL;

	slwt->bpf.name = nla_memdup(tb[SEG6_LOCAL_BPF_PROG_NAME], GFP_KERNEL);
	if (!slwt->bpf.name)
		return -ENOMEM;

	prog_fd = nla_get_u32(tb[SEG6_LOCAL_BPF_PROG]);
	prog = bpf_prog_get_type(prog_fd, BPF_PROG_TYPE_LWT_SEG6LOCAL);
	if (IS_ERR(prog)) {
		kfree(slwt->bpf.name);
		return PTR_ERR(prog);
	}

	slwt->bpf.prog = prog;

	return 0;
}
static int put_nla_bpf(struct sk_buff *skb, struct seg6_local_lwt *slwt)
{
struct nlattr *nest;
if (!slwt->bpf.prog)
return 0;
nest = nla_nest_start_noflag(skb, SEG6_LOCAL_BPF);
if (!nest)
return -EMSGSIZE;
if (nla_put_u32(skb, SEG6_LOCAL_BPF_PROG, slwt->bpf.prog->aux->id))
return -EMSGSIZE;
if (slwt->bpf.name &&
nla_put_string(skb, SEG6_LOCAL_BPF_PROG_NAME, slwt->bpf.name))
return -EMSGSIZE;
return nla_nest_end(skb, nest);
}
/* Compare the BPF program names of @a and @b: 0 if both are unset, 1 if
 * only one is set, otherwise the strcmp() result.
 */
static int cmp_nla_bpf(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
{
	const char *name_a = a->bpf.name;
	const char *name_b = b->bpf.name;

	if (!name_a || !name_b)
		return name_a != name_b;

	return strcmp(name_a, name_b);
}
/* Per-attribute callbacks, indexed by SEG6_LOCAL_* in seg6_action_params. */
struct seg6_action_param {
/* read the attribute from @attrs into @slwt; negative value on error */
int (*parse)(struct nlattr **attrs, struct seg6_local_lwt *slwt);
/* write the attribute from @slwt into @skb; negative value on error */
int (*put)(struct sk_buff *skb, struct seg6_local_lwt *slwt);
/* non-zero when the attribute differs between @a and @b */
int (*cmp)(struct seg6_local_lwt *a, struct seg6_local_lwt *b);
};
/* parse/put/cmp callbacks for each attribute an action may require.
 * Indexed by SEG6_LOCAL_* attribute number; only indexes whose bit is set
 * in the action's attrs mask are ever dereferenced.
 */
static struct seg6_action_param seg6_action_params[SEG6_LOCAL_MAX + 1] = {
[SEG6_LOCAL_SRH] = { .parse = parse_nla_srh,
.put = put_nla_srh,
.cmp = cmp_nla_srh },
[SEG6_LOCAL_TABLE] = { .parse = parse_nla_table,
.put = put_nla_table,
.cmp = cmp_nla_table },
[SEG6_LOCAL_NH4] = { .parse = parse_nla_nh4,
.put = put_nla_nh4,
.cmp = cmp_nla_nh4 },
[SEG6_LOCAL_NH6] = { .parse = parse_nla_nh6,
.put = put_nla_nh6,
.cmp = cmp_nla_nh6 },
[SEG6_LOCAL_IIF] = { .parse = parse_nla_iif,
.put = put_nla_iif,
.cmp = cmp_nla_iif },
[SEG6_LOCAL_OIF] = { .parse = parse_nla_oif,
.put = put_nla_oif,
.cmp = cmp_nla_oif },
[SEG6_LOCAL_BPF] = { .parse = parse_nla_bpf,
.put = put_nla_bpf,
.cmp = cmp_nla_bpf },
};
static int parse_nla_action(struct nlattr **attrs, struct seg6_local_lwt *slwt)
{
struct seg6_action_param *param;
struct seg6_action_desc *desc;
int i, err;
desc = __get_action_desc(slwt->action);
if (!desc)
return -EINVAL;
if (!desc->input)
return -EOPNOTSUPP;
slwt->desc = desc;
slwt->headroom += desc->static_headroom;
for (i = 0; i < SEG6_LOCAL_MAX + 1; i++) {
if (desc->attrs & (1 << i)) {
if (!attrs[i])
return -EINVAL;
param = &seg6_action_params[i];
err = param->parse(attrs, slwt);
if (err < 0)
return err;
}
}
return 0;
}
/* lwtunnel build_state hook: parse the nested seg6local attributes in @nla
 * and allocate the tunnel state returned through @ts.
 *
 * Returns 0 on success, -EINVAL for a non-IPv6 family or a missing action
 * attribute, -ENOMEM, or the error from attribute parsing. On a parse
 * error the SRH copy and the state are freed.
 */
static int seg6_local_build_state(struct net *net, struct nlattr *nla,
				  unsigned int family, const void *cfg,
				  struct lwtunnel_state **ts,
				  struct netlink_ext_ack *extack)
{
	struct nlattr *tb[SEG6_LOCAL_MAX + 1];
	struct lwtunnel_state *state;
	struct seg6_local_lwt *slwt;
	int err;

	if (family != AF_INET6)
		return -EINVAL;

	err = nla_parse_nested_deprecated(tb, SEG6_LOCAL_MAX, nla,
					  seg6_local_policy, extack);
	if (err < 0)
		return err;

	if (!tb[SEG6_LOCAL_ACTION])
		return -EINVAL;

	state = lwtunnel_state_alloc(sizeof(*slwt));
	if (!state)
		return -ENOMEM;

	slwt = seg6_local_lwtunnel(state);
	slwt->action = nla_get_u32(tb[SEG6_LOCAL_ACTION]);

	err = parse_nla_action(tb, slwt);
	if (err < 0) {
		kfree(slwt->srh);
		kfree(state);
		return err;
	}

	state->type = LWTUNNEL_ENCAP_SEG6_LOCAL;
	state->flags = LWTUNNEL_STATE_INPUT_REDIRECT;
	state->headroom = slwt->headroom;

	*ts = state;

	return 0;
}
/* lwtunnel destroy_state hook: free the SRH copy and, for actions that
 * take the BPF attribute, the program name and the program reference.
 */
static void seg6_local_destroy_state(struct lwtunnel_state *lwt)
{
	struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt);

	kfree(slwt->srh);

	if (!(slwt->desc->attrs & (1 << SEG6_LOCAL_BPF)))
		return;

	kfree(slwt->bpf.name);
	bpf_prog_put(slwt->bpf.prog);
}
static int seg6_local_fill_encap(struct sk_buff *skb,
struct lwtunnel_state *lwt)
{
struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt);
struct seg6_action_param *param;
int i, err;
if (nla_put_u32(skb, SEG6_LOCAL_ACTION, slwt->action))
return -EMSGSIZE;
for (i = 0; i < SEG6_LOCAL_MAX + 1; i++) {
if (slwt->desc->attrs & (1 << i)) {
param = &seg6_action_params[i];
err = param->put(skb, slwt);
if (err < 0)
return err;
}
}
return 0;
}
/* lwtunnel get_encap_size hook: return the netlink space reserved for
 * dumping this state, summed over the attributes its action uses.
 */
static int seg6_local_get_encap_size(struct lwtunnel_state *lwt)
{
	struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt);
	unsigned long attrs = slwt->desc->attrs;
	int size = nla_total_size(4); /* action */

	if (attrs & (1 << SEG6_LOCAL_SRH))
		size += nla_total_size((slwt->srh->hdrlen + 1) << 3);

	if (attrs & (1 << SEG6_LOCAL_TABLE))
		size += nla_total_size(4);

	if (attrs & (1 << SEG6_LOCAL_NH4))
		size += nla_total_size(4);

	if (attrs & (1 << SEG6_LOCAL_NH6))
		size += nla_total_size(16);

	if (attrs & (1 << SEG6_LOCAL_IIF))
		size += nla_total_size(4);

	if (attrs & (1 << SEG6_LOCAL_OIF))
		size += nla_total_size(4);

	/* nest header, program name, and program id */
	if (attrs & (1 << SEG6_LOCAL_BPF))
		size += nla_total_size(sizeof(struct nlattr)) +
			nla_total_size(MAX_PROG_NAME) +
			nla_total_size(4);

	return size;
}
static int seg6_local_cmp_encap(struct lwtunnel_state *a,
struct lwtunnel_state *b)
{
struct seg6_local_lwt *slwt_a, *slwt_b;
struct seg6_action_param *param;
int i;
slwt_a = seg6_local_lwtunnel(a);
slwt_b = seg6_local_lwtunnel(b);
if (slwt_a->action != slwt_b->action)
return 1;
if (slwt_a->desc->attrs != slwt_b->desc->attrs)
return 1;
for (i = 0; i < SEG6_LOCAL_MAX + 1; i++) {
if (slwt_a->desc->attrs & (1 << i)) {
param = &seg6_action_params[i];
if (param->cmp(slwt_a, slwt_b))
return 1;
}
}
return 0;
}
/* lwtunnel operations for LWTUNNEL_ENCAP_SEG6_LOCAL, registered by
 * seg6_local_init() and removed by seg6_local_exit().
 */
static const struct lwtunnel_encap_ops seg6_local_ops = {
.build_state = seg6_local_build_state,
.destroy_state = seg6_local_destroy_state,
.input = seg6_local_input,
.fill_encap = seg6_local_fill_encap,
.get_encap_size = seg6_local_get_encap_size,
.cmp_encap = seg6_local_cmp_encap,
.owner = THIS_MODULE,
};
/* Register the seg6local encap ops; returns the result of
 * lwtunnel_encap_add_ops().
 */
int __init seg6_local_init(void)
{
	int err;

	err = lwtunnel_encap_add_ops(&seg6_local_ops,
				     LWTUNNEL_ENCAP_SEG6_LOCAL);

	return err;
}
/* Unregister the seg6local encap ops. */
void seg6_local_exit(void)
{
	const struct lwtunnel_encap_ops *ops = &seg6_local_ops;

	lwtunnel_encap_del_ops(ops, LWTUNNEL_ENCAP_SEG6_LOCAL);
}