diff --git a/net/core/devlink.c b/net/core/devlink.c index 44f880d3b816..b84cf0df4a0e 100644 --- a/net/core/devlink.c +++ b/net/core/devlink.c @@ -119,7 +119,8 @@ static struct devlink_port *devlink_port_get_from_info(struct devlink *devlink, return devlink_port_get_from_attrs(devlink, info->attrs); } -#define DEVLINK_NL_FLAG_NEED_PORT BIT(0) +#define DEVLINK_NL_FLAG_NEED_DEVLINK BIT(0) +#define DEVLINK_NL_FLAG_NEED_PORT BIT(1) static int devlink_nl_pre_doit(const struct genl_ops *ops, struct sk_buff *skb, struct genl_info *info) @@ -132,8 +133,9 @@ static int devlink_nl_pre_doit(const struct genl_ops *ops, mutex_unlock(&devlink_mutex); return PTR_ERR(devlink); } - info->user_ptr[0] = devlink; - if (ops->internal_flags & DEVLINK_NL_FLAG_NEED_PORT) { + if (ops->internal_flags & DEVLINK_NL_FLAG_NEED_DEVLINK) { + info->user_ptr[0] = devlink; + } else if (ops->internal_flags & DEVLINK_NL_FLAG_NEED_PORT) { struct devlink_port *devlink_port; mutex_lock(&devlink_port_mutex); @@ -143,7 +145,7 @@ static int devlink_nl_pre_doit(const struct genl_ops *ops, mutex_unlock(&devlink_mutex); return PTR_ERR(devlink_port); } - info->user_ptr[1] = devlink_port; + info->user_ptr[0] = devlink_port; } return 0; } @@ -356,8 +358,8 @@ out: static int devlink_nl_cmd_port_get_doit(struct sk_buff *skb, struct genl_info *info) { - struct devlink *devlink = info->user_ptr[0]; - struct devlink_port *devlink_port = info->user_ptr[1]; + struct devlink_port *devlink_port = info->user_ptr[0]; + struct devlink *devlink = devlink_port->devlink; struct sk_buff *msg; int err; @@ -436,8 +438,8 @@ static int devlink_port_type_set(struct devlink *devlink, static int devlink_nl_cmd_port_set_doit(struct sk_buff *skb, struct genl_info *info) { - struct devlink *devlink = info->user_ptr[0]; - struct devlink_port *devlink_port = info->user_ptr[1]; + struct devlink_port *devlink_port = info->user_ptr[0]; + struct devlink *devlink = devlink_port->devlink; int err; if (info->attrs[DEVLINK_ATTR_PORT_TYPE]) { @@ -511,6 +513,7 @@ static const struct genl_ops devlink_nl_ops[] = { .doit = devlink_nl_cmd_get_doit, .dumpit = devlink_nl_cmd_get_dumpit, .policy = devlink_nl_policy, + .internal_flags = DEVLINK_NL_FLAG_NEED_DEVLINK, /* can be retrieved by unprivileged users */ }, { @@ -533,12 +536,14 @@ static const struct genl_ops devlink_nl_ops[] = { .doit = devlink_nl_cmd_port_split_doit, .policy = devlink_nl_policy, .flags = GENL_ADMIN_PERM, + .internal_flags = DEVLINK_NL_FLAG_NEED_DEVLINK, }, { .cmd = DEVLINK_CMD_PORT_UNSPLIT, .doit = devlink_nl_cmd_port_unsplit_doit, .policy = devlink_nl_policy, .flags = GENL_ADMIN_PERM, + .internal_flags = DEVLINK_NL_FLAG_NEED_DEVLINK, }, };