diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index e234c21a5e6c..fc10620967c7 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -903,14 +903,20 @@ struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *root,
 		if (prev && reclaim->generation != iter->generation)
 			goto out_unlock;
 
-		do {
+		while (1) {
 			pos = READ_ONCE(iter->position);
+			if (!pos || css_tryget(&pos->css))
+				break;
 			/*
-			 * A racing update may change the position and
-			 * put the last reference, hence css_tryget(),
-			 * or retry to see the updated position.
+			 * css reference reached zero, so iter->position will
+			 * be cleared by ->css_released. However, we should not
+			 * rely on this happening soon, because ->css_released
+			 * is called from a work queue, and by busy-waiting we
+			 * might block it. So we clear iter->position right
+			 * away.
 			 */
-		} while (pos && !css_tryget(&pos->css));
+			(void)cmpxchg(&iter->position, pos, NULL);
+		}
 	}
 
 	if (pos)
@@ -956,17 +962,13 @@ struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *root,
 	}
 
 	if (reclaim) {
-		if (cmpxchg(&iter->position, pos, memcg) == pos) {
-			if (memcg)
-				css_get(&memcg->css);
-			if (pos)
-				css_put(&pos->css);
-		}
-
 		/*
-		 * pairs with css_tryget when dereferencing iter->position
-		 * above.
+		 * The position could have already been updated by a competing
+		 * thread, so check that the value hasn't changed since we read
+		 * it to avoid reclaiming from the same cgroup twice.
 		 */
+		(void)cmpxchg(&iter->position, pos, memcg);
+
 		if (pos)
 			css_put(&pos->css);
 
@@ -999,6 +1001,28 @@ void mem_cgroup_iter_break(struct mem_cgroup *root,
 		css_put(&prev->css);
 }
 
+static void invalidate_reclaim_iterators(struct mem_cgroup *dead_memcg)
+{
+	struct mem_cgroup *memcg = dead_memcg;
+	struct mem_cgroup_reclaim_iter *iter;
+	struct mem_cgroup_per_zone *mz;
+	int nid, zid;
+	int i;
+
+	while ((memcg = parent_mem_cgroup(memcg))) {
+		for_each_node(nid) {
+			for (zid = 0; zid < MAX_NR_ZONES; zid++) {
+				mz = &memcg->nodeinfo[nid]->zoneinfo[zid];
+				for (i = 0; i <= DEF_PRIORITY; i++) {
+					iter = &mz->iter[i];
+					cmpxchg(&iter->position,
+						dead_memcg, NULL);
+				}
+			}
+		}
+	}
+}
+
 /*
  * Iteration constructs for visiting all cgroups (under a tree).  If
  * loops are exited prematurely (break), mem_cgroup_iter_break() must
@@ -4324,6 +4348,13 @@ static void mem_cgroup_css_offline(struct cgroup_subsys_state *css)
 	wb_memcg_offline(memcg);
 }
 
+static void mem_cgroup_css_released(struct cgroup_subsys_state *css)
+{
+	struct mem_cgroup *memcg = mem_cgroup_from_css(css);
+
+	invalidate_reclaim_iterators(memcg);
+}
+
 static void mem_cgroup_css_free(struct cgroup_subsys_state *css)
 {
 	struct mem_cgroup *memcg = mem_cgroup_from_css(css);
@@ -5185,6 +5216,7 @@ struct cgroup_subsys memory_cgrp_subsys = {
 	.css_alloc = mem_cgroup_css_alloc,
 	.css_online = mem_cgroup_css_online,
 	.css_offline = mem_cgroup_css_offline,
+	.css_released = mem_cgroup_css_released,
 	.css_free = mem_cgroup_css_free,
 	.css_reset = mem_cgroup_css_reset,
 	.can_attach = mem_cgroup_can_attach,