diff --git a/include/linux/kthread.h b/include/linux/kthread.h
index 0714b24c0e45..22ccf9dee177 100644
--- a/include/linux/kthread.h
+++ b/include/linux/kthread.h
@@ -49,8 +49,6 @@ extern int tsk_fork_get_node(struct task_struct *tsk);
  * can be queued and flushed using queue/flush_kthread_work()
  * respectively.  Queued kthread_works are processed by a kthread
  * running kthread_worker_fn().
- *
- * A kthread_work can't be freed while it is executing.
  */
 struct kthread_work;
 typedef void (*kthread_work_func_t)(struct kthread_work *work);
@@ -59,15 +57,14 @@ struct kthread_worker {
 	spinlock_t		lock;
 	struct list_head	work_list;
 	struct task_struct	*task;
+	struct kthread_work	*current_work;
 };
 
 struct kthread_work {
 	struct list_head	node;
 	kthread_work_func_t	func;
 	wait_queue_head_t	done;
-	atomic_t		flushing;
-	int			queue_seq;
-	int			done_seq;
+	struct kthread_worker	*worker;
 };
 
 #define KTHREAD_WORKER_INIT(worker)	{				\
@@ -79,7 +76,6 @@ struct kthread_work {
 	.node = LIST_HEAD_INIT((work).node),				\
 	.func = (fn),							\
 	.done = __WAIT_QUEUE_HEAD_INITIALIZER((work).done),		\
-	.flushing = ATOMIC_INIT(0),					\
 	}
 
 #define DEFINE_KTHREAD_WORKER(worker)					\
diff --git a/kernel/kthread.c b/kernel/kthread.c
index 4bfbff36d447..b579af57ea10 100644
--- a/kernel/kthread.c
+++ b/kernel/kthread.c
@@ -360,16 +360,12 @@ repeat:
 					struct kthread_work, node);
 		list_del_init(&work->node);
 	}
+	worker->current_work = work;
 	spin_unlock_irq(&worker->lock);
 
 	if (work) {
 		__set_current_state(TASK_RUNNING);
 		work->func(work);
-		smp_wmb();	/* wmb worker-b0 paired with flush-b1 */
-		work->done_seq = work->queue_seq;
-		smp_mb();	/* mb worker-b1 paired with flush-b0 */
-		if (atomic_read(&work->flushing))
-			wake_up_all(&work->done);
 	} else if (!freezing(current))
 		schedule();
 
@@ -386,7 +382,7 @@ static void insert_kthread_work(struct kthread_worker *worker,
 	lockdep_assert_held(&worker->lock);
 
 	list_add_tail(&work->node, pos);
-	work->queue_seq++;
+	work->worker = worker;
 	if (likely(worker->task))
 		wake_up_process(worker->task);
 }
@@ -436,25 +432,35 @@ static void kthread_flush_work_fn(struct kthread_work *work)
  */
 void flush_kthread_work(struct kthread_work *work)
 {
-	int seq = work->queue_seq;
+	struct kthread_flush_work fwork = {
+		KTHREAD_WORK_INIT(fwork.work, kthread_flush_work_fn),
+		COMPLETION_INITIALIZER_ONSTACK(fwork.done),
+	};
+	struct kthread_worker *worker;
+	bool noop = false;
 
-	atomic_inc(&work->flushing);
+retry:
+	worker = work->worker;
+	if (!worker)
+		return;
 
-	/*
-	 * mb flush-b0 paired with worker-b1, to make sure either
-	 * worker sees the above increment or we see done_seq update.
-	 */
-	smp_mb__after_atomic_inc();
+	spin_lock_irq(&worker->lock);
+	if (work->worker != worker) {
+		spin_unlock_irq(&worker->lock);
+		goto retry;
+	}
 
-	/* A - B <= 0 tests whether B is in front of A regardless of overflow */
-	wait_event(work->done, seq - work->done_seq <= 0);
-	atomic_dec(&work->flushing);
+	if (!list_empty(&work->node))
+		insert_kthread_work(worker, &fwork.work, work->node.next);
+	else if (worker->current_work == work)
+		insert_kthread_work(worker, &fwork.work, worker->work_list.next);
+	else
+		noop = true;
 
-	/*
-	 * rmb flush-b1 paired with worker-b0, to make sure our caller
-	 * sees every change made by work->func().
-	 */
-	smp_mb__after_atomic_dec();
+	spin_unlock_irq(&worker->lock);
+
+	if (!noop)
+		wait_for_completion(&fwork.done);
 }
 EXPORT_SYMBOL_GPL(flush_kthread_work);