diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
index e32ab948f589..d4575a1d6e99 100644
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -68,10 +68,9 @@ void mem_cgroup_migrate(struct page *oldpage, struct page *newpage,
 struct lruvec *mem_cgroup_zone_lruvec(struct zone *, struct mem_cgroup *);
 struct lruvec *mem_cgroup_page_lruvec(struct page *, struct zone *);
 
-bool __mem_cgroup_same_or_subtree(const struct mem_cgroup *root_memcg,
-				  struct mem_cgroup *memcg);
-bool task_in_mem_cgroup(struct task_struct *task,
-			const struct mem_cgroup *memcg);
+bool mem_cgroup_is_descendant(struct mem_cgroup *memcg,
+			      struct mem_cgroup *root);
+bool task_in_mem_cgroup(struct task_struct *task, struct mem_cgroup *memcg);
 
 extern struct mem_cgroup *try_get_mem_cgroup_from_page(struct page *page);
 extern struct mem_cgroup *mem_cgroup_from_task(struct task_struct *p);
@@ -79,8 +78,8 @@ extern struct mem_cgroup *mem_cgroup_from_task(struct task_struct *p);
 extern struct mem_cgroup *parent_mem_cgroup(struct mem_cgroup *memcg);
 extern struct mem_cgroup *mem_cgroup_from_css(struct cgroup_subsys_state *css);
 
-static inline
-bool mm_match_cgroup(const struct mm_struct *mm, const struct mem_cgroup *memcg)
+static inline bool mm_match_cgroup(struct mm_struct *mm,
+				   struct mem_cgroup *memcg)
 {
 	struct mem_cgroup *task_memcg;
 	bool match = false;
@@ -88,7 +87,7 @@ bool mm_match_cgroup(const struct mm_struct *mm, const struct mem_cgroup *memcg)
 	rcu_read_lock();
 	task_memcg = mem_cgroup_from_task(rcu_dereference(mm->owner));
 	if (task_memcg)
-		match = __mem_cgroup_same_or_subtree(memcg, task_memcg);
+		match = mem_cgroup_is_descendant(task_memcg, memcg);
 	rcu_read_unlock();
 	return match;
 }
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index e5dcebd71dfb..b841bf430179 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -1328,41 +1328,24 @@ void mem_cgroup_update_lru_size(struct lruvec *lruvec, enum lru_list lru,
 	VM_BUG_ON((long)(*lru_size) < 0);
 }
 
-/*
- * Checks whether given mem is same or in the root_mem_cgroup's
- * hierarchy subtree
- */
-bool __mem_cgroup_same_or_subtree(const struct mem_cgroup *root_memcg,
-				  struct mem_cgroup *memcg)
+bool mem_cgroup_is_descendant(struct mem_cgroup *memcg, struct mem_cgroup *root)
 {
-	if (root_memcg == memcg)
+	if (root == memcg)
 		return true;
-	if (!root_memcg->use_hierarchy)
+	if (!root->use_hierarchy)
 		return false;
-	return cgroup_is_descendant(memcg->css.cgroup, root_memcg->css.cgroup);
+	return cgroup_is_descendant(memcg->css.cgroup, root->css.cgroup);
 }
 
-static bool mem_cgroup_same_or_subtree(const struct mem_cgroup *root_memcg,
-				       struct mem_cgroup *memcg)
+bool task_in_mem_cgroup(struct task_struct *task, struct mem_cgroup *memcg)
 {
-	bool ret;
-
-	rcu_read_lock();
-	ret = __mem_cgroup_same_or_subtree(root_memcg, memcg);
-	rcu_read_unlock();
-	return ret;
-}
-
-bool task_in_mem_cgroup(struct task_struct *task,
-			const struct mem_cgroup *memcg)
-{
-	struct mem_cgroup *curr;
+	struct mem_cgroup *task_memcg;
 	struct task_struct *p;
 	bool ret;
 
 	p = find_lock_task_mm(task);
 	if (p) {
-		curr = get_mem_cgroup_from_mm(p->mm);
+		task_memcg = get_mem_cgroup_from_mm(p->mm);
 		task_unlock(p);
 	} else {
 		/*
@@ -1371,18 +1354,12 @@ bool task_in_mem_cgroup(struct task_struct *task,
 		 * killed to prevent needlessly killing additional tasks.
 		 */
 		rcu_read_lock();
-		curr = mem_cgroup_from_task(task);
-		css_get(&curr->css);
+		task_memcg = mem_cgroup_from_task(task);
+		css_get(&task_memcg->css);
 		rcu_read_unlock();
 	}
-	/*
-	 * We should check use_hierarchy of "memcg" not "curr". Because checking
-	 * use_hierarchy of "curr" here make this function true if hierarchy is
-	 * enabled in "curr" and "curr" is a child of "memcg" in *cgroup*
-	 * hierarchy(even if use_hierarchy is disabled in "memcg").
-	 */
-	ret = mem_cgroup_same_or_subtree(memcg, curr);
-	css_put(&curr->css);
+	ret = mem_cgroup_is_descendant(task_memcg, memcg);
+	css_put(&task_memcg->css);
 	return ret;
 }
 
@@ -1467,8 +1444,8 @@ static bool mem_cgroup_under_move(struct mem_cgroup *memcg)
 	if (!from)
 		goto unlock;
 
-	ret = mem_cgroup_same_or_subtree(memcg, from)
-		|| mem_cgroup_same_or_subtree(memcg, to);
+	ret = mem_cgroup_is_descendant(from, memcg) ||
+		mem_cgroup_is_descendant(to, memcg);
 unlock:
 	spin_unlock(&mc.lock);
 	return ret;
@@ -1900,12 +1877,8 @@ static int memcg_oom_wake_function(wait_queue_t *wait,
 	oom_wait_info = container_of(wait, struct oom_wait_info, wait);
 	oom_wait_memcg = oom_wait_info->memcg;
 
-	/*
-	 * Both of oom_wait_info->memcg and wake_memcg are stable under us.
-	 * Then we can use css_is_ancestor without taking care of RCU.
-	 */
-	if (!mem_cgroup_same_or_subtree(oom_wait_memcg, wake_memcg)
-		&& !mem_cgroup_same_or_subtree(wake_memcg, oom_wait_memcg))
+	if (!mem_cgroup_is_descendant(wake_memcg, oom_wait_memcg) &&
+	    !mem_cgroup_is_descendant(oom_wait_memcg, wake_memcg))
 		return 0;
 	return autoremove_wake_function(wait, mode, sync, arg);
 }
@@ -2225,7 +2198,7 @@ static void drain_all_stock(struct mem_cgroup *root_memcg)
 		memcg = stock->cached;
 		if (!memcg || !stock->nr_pages)
 			continue;
-		if (!mem_cgroup_same_or_subtree(root_memcg, memcg))
+		if (!mem_cgroup_is_descendant(memcg, root_memcg))
 			continue;
 		if (!test_and_set_bit(FLUSHING_CACHED_CHARGE, &stock->flags)) {
 			if (cpu == curcpu)
diff --git a/mm/oom_kill.c b/mm/oom_kill.c
index 5340f6b91312..3b014d326151 100644
--- a/mm/oom_kill.c
+++ b/mm/oom_kill.c
@@ -119,7 +119,7 @@ found:
 
 /* return true if the task is not adequate as candidate victim task. */
 static bool oom_unkillable_task(struct task_struct *p,
-		const struct mem_cgroup *memcg, const nodemask_t *nodemask)
+		struct mem_cgroup *memcg, const nodemask_t *nodemask)
 {
 	if (is_global_init(p))
 		return true;
@@ -353,7 +353,7 @@ static struct task_struct *select_bad_process(unsigned int *ppoints,
  * State information includes task's pid, uid, tgid, vm size, rss, nr_ptes,
  * swapents, oom_score_adj value, and name.
  */
-static void dump_tasks(const struct mem_cgroup *memcg, const nodemask_t *nodemask)
+static void dump_tasks(struct mem_cgroup *memcg, const nodemask_t *nodemask)
 {
 	struct task_struct *p;
 	struct task_struct *task;