diff --git a/net/rds/message.c b/net/rds/message.c index 116cf87ccb89..c36edbb92c34 100644 --- a/net/rds/message.c +++ b/net/rds/message.c @@ -333,14 +333,14 @@ struct rds_message *rds_message_map_pages(unsigned long *page_addrs, unsigned in return rm; } -int rds_message_copy_from_user(struct rds_message *rm, struct iov_iter *from, - bool zcopy) +int rds_message_zcopy_from_user(struct rds_message *rm, struct iov_iter *from) { - unsigned long to_copy, nbytes; unsigned long sg_off; struct scatterlist *sg; int ret = 0; int length = iov_iter_count(from); + int total_copied = 0; + struct sk_buff *skb; rm->m_inc.i_hdr.h_len = cpu_to_be32(iov_iter_count(from)); @@ -350,54 +350,66 @@ int rds_message_copy_from_user(struct rds_message *rm, struct iov_iter *from, sg = rm->data.op_sg; sg_off = 0; /* Dear gcc, sg->page will be null from kzalloc. */ - if (zcopy) { - int total_copied = 0; - struct sk_buff *skb; + skb = alloc_skb(0, GFP_KERNEL); + if (!skb) + return -ENOMEM; + BUILD_BUG_ON(sizeof(skb->cb) < max_t(int, sizeof(struct rds_znotifier), + sizeof(struct rds_zcopy_cookies))); + rm->data.op_mmp_znotifier = RDS_ZCOPY_SKB(skb); + if (mm_account_pinned_pages(&rm->data.op_mmp_znotifier->z_mmp, + length)) { + ret = -ENOMEM; + goto err; + } + while (iov_iter_count(from)) { + struct page *pages; + size_t start; + ssize_t copied; - skb = alloc_skb(0, GFP_KERNEL); - if (!skb) - return -ENOMEM; - BUILD_BUG_ON(sizeof(skb->cb) < - max_t(int, sizeof(struct rds_znotifier), - sizeof(struct rds_zcopy_cookies))); - rm->data.op_mmp_znotifier = RDS_ZCOPY_SKB(skb); - if (mm_account_pinned_pages(&rm->data.op_mmp_znotifier->z_mmp, - length)) { - ret = -ENOMEM; + copied = iov_iter_get_pages(from, &pages, PAGE_SIZE, + 1, &start); + if (copied < 0) { + struct mmpin *mmp; + int i; + + for (i = 0; i < rm->data.op_nents; i++) + put_page(sg_page(&rm->data.op_sg[i])); + mmp = &rm->data.op_mmp_znotifier->z_mmp; + mm_unaccount_pinned_pages(mmp); + ret = -EFAULT; goto err; } - while (iov_iter_count(from)) { - struct page *pages; - size_t start; - ssize_t copied; - - copied = iov_iter_get_pages(from, &pages, PAGE_SIZE, - 1, &start); - if (copied < 0) { - struct mmpin *mmp; - int i; - - for (i = 0; i < rm->data.op_nents; i++) - put_page(sg_page(&rm->data.op_sg[i])); - mmp = &rm->data.op_mmp_znotifier->z_mmp; - mm_unaccount_pinned_pages(mmp); - ret = -EFAULT; - goto err; - } - total_copied += copied; - iov_iter_advance(from, copied); - length -= copied; - sg_set_page(sg, pages, copied, start); - rm->data.op_nents++; - sg++; - } - WARN_ON_ONCE(length != 0); - return ret; + total_copied += copied; + iov_iter_advance(from, copied); + length -= copied; + sg_set_page(sg, pages, copied, start); + rm->data.op_nents++; + sg++; + } + WARN_ON_ONCE(length != 0); + return ret; err: - consume_skb(skb); - rm->data.op_mmp_znotifier = NULL; - return ret; - } /* zcopy */ + consume_skb(skb); + rm->data.op_mmp_znotifier = NULL; + return ret; +} + +int rds_message_copy_from_user(struct rds_message *rm, struct iov_iter *from, + bool zcopy) +{ + unsigned long to_copy, nbytes; + unsigned long sg_off; + struct scatterlist *sg; + int ret = 0; + + rm->m_inc.i_hdr.h_len = cpu_to_be32(iov_iter_count(from)); + + /* now allocate and copy in the data payload. */ + sg = rm->data.op_sg; + sg_off = 0; /* Dear gcc, sg->page will be null from kzalloc. */ + + if (zcopy) + return rds_message_zcopy_from_user(rm, from); while (iov_iter_count(from)) { if (!sg_page(sg)) {