From 4c173f596b3ffe6b967f5818043665c565648809 Mon Sep 17 00:00:00 2001 From: Jason Gunthorpe Date: Tue, 12 Feb 2019 21:12:52 -0700 Subject: [PATCH] RDMA/rxe: Use ib_device_get_by_netdev() instead of open coding The core API handles the locking correctly and is faster if there are multiple devices. Signed-off-by: Jason Gunthorpe --- drivers/infiniband/sw/rxe/rxe.h | 12 ++++++++- drivers/infiniband/sw/rxe/rxe_net.c | 39 +++++++++++---------------- drivers/infiniband/sw/rxe/rxe_sysfs.c | 17 ++++++------ drivers/infiniband/sw/rxe/rxe_verbs.c | 17 +++--------- 4 files changed, 38 insertions(+), 47 deletions(-) diff --git a/drivers/infiniband/sw/rxe/rxe.h b/drivers/infiniband/sw/rxe/rxe.h index 5bde2ad964d2..2b875875962f 100644 --- a/drivers/infiniband/sw/rxe/rxe.h +++ b/drivers/infiniband/sw/rxe/rxe.h @@ -105,9 +105,19 @@ static inline void rxe_dev_put(struct rxe_dev *rxe) { kref_put(&rxe->ref_cnt, rxe_release); } -struct rxe_dev *net_to_rxe(struct net_device *ndev); struct rxe_dev *get_rxe_by_name(const char *name); +/* The caller must do a matching ib_device_put(&dev->ib_dev) */ +static inline struct rxe_dev *rxe_get_dev_from_net(struct net_device *ndev) +{ + struct ib_device *ibdev = + ib_device_get_by_netdev(ndev, RDMA_DRIVER_RXE); + + if (!ibdev) + return NULL; + return container_of(ibdev, struct rxe_dev, ib_dev); +} + void rxe_port_up(struct rxe_dev *rxe); void rxe_port_down(struct rxe_dev *rxe); void rxe_set_port_state(struct rxe_dev *rxe); diff --git a/drivers/infiniband/sw/rxe/rxe_net.c b/drivers/infiniband/sw/rxe/rxe_net.c index 3b162e92e8e8..56878453f1ae 100644 --- a/drivers/infiniband/sw/rxe/rxe_net.c +++ b/drivers/infiniband/sw/rxe/rxe_net.c @@ -48,23 +48,6 @@ static LIST_HEAD(rxe_dev_list); static DEFINE_SPINLOCK(dev_list_lock); /* spinlock for device list */ -struct rxe_dev *net_to_rxe(struct net_device *ndev) -{ - struct rxe_dev *rxe; - struct rxe_dev *found = NULL; - - spin_lock_bh(&dev_list_lock); - list_for_each_entry(rxe, &rxe_dev_list, list) { - if (rxe->ndev == ndev) { - found = rxe; - break; - } - } - spin_unlock_bh(&dev_list_lock); - - return found; -} - struct rxe_dev *get_rxe_by_name(const char *name) { struct rxe_dev *rxe; @@ -81,7 +64,6 @@ struct rxe_dev *get_rxe_by_name(const char *name) return found; } - static struct rxe_recv_sockets recv_sockets; struct device *rxe_dma_device(struct rxe_dev *rxe) @@ -229,18 +211,19 @@ static int rxe_udp_encap_recv(struct sock *sk, struct sk_buff *skb) struct udphdr *udph; struct net_device *ndev = skb->dev; struct net_device *rdev = ndev; - struct rxe_dev *rxe = net_to_rxe(ndev); + struct rxe_dev *rxe = rxe_get_dev_from_net(ndev); struct rxe_pkt_info *pkt = SKB_TO_PKT(skb); if (!rxe && is_vlan_dev(rdev)) { rdev = vlan_dev_real_dev(ndev); - rxe = net_to_rxe(rdev); + rxe = rxe_get_dev_from_net(rdev); } if (!rxe) goto drop; if (skb_linearize(skb)) { pr_err("skb_linearize failed\n"); + ib_device_put(&rxe->ib_dev); goto drop; } @@ -253,6 +236,12 @@ static int rxe_udp_encap_recv(struct sock *sk, struct sk_buff *skb) rxe_rcv(skb); + /* + * FIXME: this is in the wrong place, it needs to be done when pkt is + * destroyed + */ + ib_device_put(&rxe->ib_dev); + return 0; drop: kfree_skb(skb); @@ -635,16 +624,17 @@ static int rxe_notify(struct notifier_block *not_blk, void *arg) { struct net_device *ndev = netdev_notifier_info_to_dev(arg); - struct rxe_dev *rxe = net_to_rxe(ndev); + struct rxe_dev *rxe = rxe_get_dev_from_net(ndev); if (!rxe) - goto out; + return NOTIFY_OK; switch (event) { case NETDEV_UNREGISTER: list_del(&rxe->list); + ib_device_put(&rxe->ib_dev); rxe_remove(rxe); - break; + return NOTIFY_OK; case NETDEV_UP: rxe_port_up(rxe); break; @@ -668,7 +658,8 @@ static int rxe_notify(struct notifier_block *not_blk, event, ndev->name); break; } -out: + + ib_device_put(&rxe->ib_dev); return NOTIFY_OK; } diff --git a/drivers/infiniband/sw/rxe/rxe_sysfs.c b/drivers/infiniband/sw/rxe/rxe_sysfs.c index 95a15892f7e6..6802be71bf9b 100644 --- a/drivers/infiniband/sw/rxe/rxe_sysfs.c +++ b/drivers/infiniband/sw/rxe/rxe_sysfs.c @@ -58,24 +58,25 @@ static int rxe_param_set_add(const char *val, const struct kernel_param *kp) int len; int err = 0; char intf[32]; - struct net_device *ndev = NULL; + struct net_device *ndev; + struct rxe_dev *exists; struct rxe_dev *rxe; len = sanitize_arg(val, intf, sizeof(intf)); if (!len) { pr_err("add: invalid interface name\n"); - err = -EINVAL; - goto err; + return -EINVAL; } ndev = dev_get_by_name(&init_net, intf); if (!ndev) { pr_err("interface %s not found\n", intf); - err = -EINVAL; - goto err; + return -EINVAL; } - if (net_to_rxe(ndev)) { + exists = rxe_get_dev_from_net(ndev); + if (exists) { + ib_device_put(&exists->ib_dev); pr_err("already configured on %s\n", intf); err = -EINVAL; goto err; @@ -90,9 +91,9 @@ static int rxe_param_set_add(const char *val, const struct kernel_param *kp) rxe_set_port_state(rxe); dev_info(&rxe->ib_dev.dev, "added %s\n", intf); + err: - if (ndev) - dev_put(ndev); + dev_put(ndev); return err; } diff --git a/drivers/infiniband/sw/rxe/rxe_verbs.c b/drivers/infiniband/sw/rxe/rxe_verbs.c index ffca654c8697..55f793ed1e77 100644 --- a/drivers/infiniband/sw/rxe/rxe_verbs.c +++ b/drivers/infiniband/sw/rxe/rxe_verbs.c @@ -80,19 +80,6 @@ static int rxe_query_port(struct ib_device *dev, return rc; } -static struct net_device *rxe_get_netdev(struct ib_device *device, - u8 port_num) -{ - struct rxe_dev *rxe = to_rdev(device); - - if (rxe->ndev) { - dev_hold(rxe->ndev); - return rxe->ndev; - } - - return NULL; -} - static int rxe_query_pkey(struct ib_device *device, u8 port_num, u16 index, u16 *pkey) { @@ -1159,7 +1146,6 @@ static const struct ib_device_ops rxe_dev_ops = { .get_dma_mr = rxe_get_dma_mr, .get_hw_stats = rxe_ib_get_hw_stats, .get_link_layer = rxe_get_link_layer, - .get_netdev = rxe_get_netdev, .get_port_immutable = rxe_port_immutable, .map_mr_sg = rxe_map_mr_sg, .mmap = rxe_mmap, @@ -1240,6 +1226,9 @@ int rxe_register_device(struct rxe_dev *rxe) ; ib_set_device_ops(dev, &rxe_dev_ops); + err = ib_device_set_netdev(&rxe->ib_dev, rxe->ndev, 1); + if (err) + return err; tfm = crypto_alloc_shash("crc32", 0, 0); if (IS_ERR(tfm)) {