diff --git a/net/core/rtnetlink.c b/net/core/rtnetlink.c index 94c4572512b8..a0fad4d8856c 100644 --- a/net/core/rtnetlink.c +++ b/net/core/rtnetlink.c @@ -6410,17 +6410,64 @@ static int rtnl_mdb_add(struct sk_buff *skb, struct nlmsghdr *nlh, return dev->netdev_ops->ndo_mdb_add(dev, tb, nlh->nlmsg_flags, extack); } +static int rtnl_validate_mdb_entry_del_bulk(const struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + struct br_mdb_entry *entry = nla_data(attr); + struct br_mdb_entry zero_entry = {}; + + if (nla_len(attr) != sizeof(struct br_mdb_entry)) { + NL_SET_ERR_MSG_ATTR(extack, attr, "Invalid attribute length"); + return -EINVAL; + } + + if (entry->state != MDB_PERMANENT && entry->state != MDB_TEMPORARY) { + NL_SET_ERR_MSG(extack, "Unknown entry state"); + return -EINVAL; + } + + if (entry->flags) { + NL_SET_ERR_MSG(extack, "Entry flags cannot be set"); + return -EINVAL; + } + + if (entry->vid >= VLAN_N_VID - 1) { + NL_SET_ERR_MSG(extack, "Invalid entry VLAN id"); + return -EINVAL; + } + + if (memcmp(&entry->addr, &zero_entry.addr, sizeof(entry->addr))) { + NL_SET_ERR_MSG(extack, "Entry address cannot be set"); + return -EINVAL; + } + + return 0; +} + +static const struct nla_policy mdba_del_bulk_policy[MDBA_SET_ENTRY_MAX + 1] = { + [MDBA_SET_ENTRY] = NLA_POLICY_VALIDATE_FN(NLA_BINARY, + rtnl_validate_mdb_entry_del_bulk, + sizeof(struct br_mdb_entry)), + [MDBA_SET_ENTRY_ATTRS] = { .type = NLA_NESTED }, +}; + static int rtnl_mdb_del(struct sk_buff *skb, struct nlmsghdr *nlh, struct netlink_ext_ack *extack) { + bool del_bulk = !!(nlh->nlmsg_flags & NLM_F_BULK); struct nlattr *tb[MDBA_SET_ENTRY_MAX + 1]; struct net *net = sock_net(skb->sk); struct br_port_msg *bpm; struct net_device *dev; int err; - err = nlmsg_parse_deprecated(nlh, sizeof(*bpm), tb, - MDBA_SET_ENTRY_MAX, mdba_policy, extack); + if (!del_bulk) + err = nlmsg_parse_deprecated(nlh, sizeof(*bpm), tb, + MDBA_SET_ENTRY_MAX, mdba_policy, + extack); + else + err = nlmsg_parse(nlh, sizeof(*bpm), tb, MDBA_SET_ENTRY_MAX, + mdba_del_bulk_policy, extack); if (err) return err;