diff --git a/net/devlink/devl_internal.h b/net/devlink/devl_internal.h index 14767e809178..6342552e5f99 100644 --- a/net/devlink/devl_internal.h +++ b/net/devlink/devl_internal.h @@ -131,7 +131,8 @@ struct devlink_gen_cmd { extern const struct genl_small_ops devlink_nl_ops[56]; -struct devlink *devlink_get_from_attrs(struct net *net, struct nlattr **attrs); +struct devlink * +devlink_get_from_attrs_lock(struct net *net, struct nlattr **attrs); void devlink_notify_unregister(struct devlink *devlink); void devlink_notify_register(struct devlink *devlink); diff --git a/net/devlink/leftover.c b/net/devlink/leftover.c index e6d6c7f74ae7..bec408da4dbe 100644 --- a/net/devlink/leftover.c +++ b/net/devlink/leftover.c @@ -6314,12 +6314,10 @@ static int devlink_nl_cmd_region_read_dumpit(struct sk_buff *skb, start_offset = state->start_offset; - devlink = devlink_get_from_attrs(sock_net(cb->skb->sk), attrs); + devlink = devlink_get_from_attrs_lock(sock_net(cb->skb->sk), attrs); if (IS_ERR(devlink)) return PTR_ERR(devlink); - devl_lock(devlink); - if (!attrs[DEVLINK_ATTR_REGION_NAME]) { NL_SET_ERR_MSG(cb->extack, "No region name provided"); err = -EINVAL; @@ -7735,9 +7733,10 @@ devlink_health_reporter_get_from_cb(struct netlink_callback *cb) struct nlattr **attrs = info->attrs; struct devlink *devlink; - devlink = devlink_get_from_attrs(sock_net(cb->skb->sk), attrs); + devlink = devlink_get_from_attrs_lock(sock_net(cb->skb->sk), attrs); if (IS_ERR(devlink)) return NULL; + devl_unlock(devlink); reporter = devlink_health_reporter_get_from_attrs(devlink, attrs); devlink_put(devlink); diff --git a/net/devlink/netlink.c b/net/devlink/netlink.c index a552e723f4a6..69111746f5d9 100644 --- a/net/devlink/netlink.c +++ b/net/devlink/netlink.c @@ -82,7 +82,8 @@ static const struct nla_policy devlink_nl_policy[DEVLINK_ATTR_MAX + 1] = { [DEVLINK_ATTR_REGION_DIRECT] = { .type = NLA_FLAG }, }; -struct devlink *devlink_get_from_attrs(struct net *net, struct nlattr **attrs) +struct devlink * +devlink_get_from_attrs_lock(struct net *net, struct nlattr **attrs) { struct devlink *devlink; unsigned long index; @@ -96,9 +97,11 @@ struct devlink *devlink_get_from_attrs(struct net *net, struct nlattr **attrs) devname = nla_data(attrs[DEVLINK_ATTR_DEV_NAME]); devlinks_xa_for_each_registered_get(net, index, devlink) { + devl_lock(devlink); if (strcmp(devlink->dev->bus->name, busname) == 0 && strcmp(dev_name(devlink->dev), devname) == 0) return devlink; + devl_unlock(devlink); devlink_put(devlink); } @@ -113,10 +116,10 @@ static int devlink_nl_pre_doit(const struct genl_split_ops *ops, struct devlink *devlink; int err; - devlink = devlink_get_from_attrs(genl_info_net(info), info->attrs); + devlink = devlink_get_from_attrs_lock(genl_info_net(info), info->attrs); if (IS_ERR(devlink)) return PTR_ERR(devlink); - devl_lock(devlink); + info->user_ptr[0] = devlink; if (ops->internal_flags & DEVLINK_NL_FLAG_NEED_PORT) { devlink_port = devlink_port_get_from_info(devlink, info);