diff --git a/tools/lib/bpf/xsk.c b/tools/lib/bpf/xsk.c
index 86c1b61017f6..8ebd810931ab 100644
--- a/tools/lib/bpf/xsk.c
+++ b/tools/lib/bpf/xsk.c
@@ -586,15 +586,21 @@ int xsk_socket__create(struct xsk_socket **xsk_ptr, const char *ifname,
 	if (!umem || !xsk_ptr || !rx || !tx)
 		return -EFAULT;
 
-	if (umem->refcount) {
-		pr_warn("Error: shared umems not supported by libbpf.\n");
-		return -EBUSY;
-	}
-
 	xsk = calloc(1, sizeof(*xsk));
 	if (!xsk)
 		return -ENOMEM;
 
+	err = xsk_set_xdp_socket_config(&xsk->config, usr_config);
+	if (err)
+		goto out_xsk_alloc;
+
+	if (umem->refcount &&
+	    !(xsk->config.libbpf_flags & XSK_LIBBPF_FLAGS__INHIBIT_PROG_LOAD)) {
+		pr_warn("Error: shared umems not supported by libbpf supplied XDP program.\n");
+		err = -EBUSY;
+		goto out_xsk_alloc;
+	}
+
 	if (umem->refcount++ > 0) {
 		xsk->fd = socket(AF_XDP, SOCK_RAW, 0);
 		if (xsk->fd < 0) {
@@ -616,10 +622,6 @@ int xsk_socket__create(struct xsk_socket **xsk_ptr, const char *ifname,
 	memcpy(xsk->ifname, ifname, IFNAMSIZ - 1);
 	xsk->ifname[IFNAMSIZ - 1] = '\0';
 
-	err = xsk_set_xdp_socket_config(&xsk->config, usr_config);
-	if (err)
-		goto out_socket;
-
 	if (rx) {
 		err = setsockopt(xsk->fd, SOL_XDP, XDP_RX_RING,
 				 &xsk->config.rx_size,
@@ -687,7 +689,12 @@ int xsk_socket__create(struct xsk_socket **xsk_ptr, const char *ifname,
 	sxdp.sxdp_family = PF_XDP;
 	sxdp.sxdp_ifindex = xsk->ifindex;
 	sxdp.sxdp_queue_id = xsk->queue_id;
-	sxdp.sxdp_flags = xsk->config.bind_flags;
+	if (umem->refcount > 1) {
+		sxdp.sxdp_flags = XDP_SHARED_UMEM;
+		sxdp.sxdp_shared_umem_fd = umem->fd;
+	} else {
+		sxdp.sxdp_flags = xsk->config.bind_flags;
+	}
 
 	err = bind(xsk->fd, (struct sockaddr *)&sxdp, sizeof(sxdp));
 	if (err) {