rust: alloc: update VecExt to take allocation flags

We also rename the methods by removing the `try_` prefix since the names
are available due to our usage of the `no_global_oom_handling` config
when building the `alloc` crate.

Reviewed-by: Boqun Feng <boqun.feng@gmail.com>
Signed-off-by: Wedson Almeida Filho <walmeida@microsoft.com>
Reviewed-by: Benno Lossin <benno.lossin@proton.me>
Link: https://lore.kernel.org/r/20240328013603.206764-8-wedsonaf@gmail.com
Signed-off-by: Miguel Ojeda <ojeda@kernel.org>
This commit is contained in:
Wedson Almeida Filho 2024-03-27 22:36:00 -03:00 committed by Miguel Ojeda
parent 08d3f54928
commit 5ab560ce12
6 changed files with 152 additions and 34 deletions

View File

@ -2,47 +2,175 @@
//! Extensions to [`Vec`] for fallible allocations. //! Extensions to [`Vec`] for fallible allocations.
use alloc::{collections::TryReserveError, vec::Vec}; use super::Flags;
use alloc::{alloc::AllocError, vec::Vec};
use core::result::Result; use core::result::Result;
/// Extensions to [`Vec`]. /// Extensions to [`Vec`].
pub trait VecExt<T>: Sized { pub trait VecExt<T>: Sized {
/// Creates a new [`Vec`] instance with at least the given capacity. /// Creates a new [`Vec`] instance with at least the given capacity.
fn try_with_capacity(capacity: usize) -> Result<Self, TryReserveError>; ///
/// # Examples
///
/// ```
/// let v = Vec::<u32>::with_capacity(20, GFP_KERNEL)?;
///
/// assert!(v.capacity() >= 20);
/// # Ok::<(), Error>(())
/// ```
fn with_capacity(capacity: usize, flags: Flags) -> Result<Self, AllocError>;
/// Appends an element to the back of the [`Vec`] instance. /// Appends an element to the back of the [`Vec`] instance.
fn try_push(&mut self, v: T) -> Result<(), TryReserveError>; ///
/// # Examples
///
/// ```
/// let mut v = Vec::new();
/// v.push(1, GFP_KERNEL)?;
/// assert_eq!(&v, &[1]);
///
/// v.push(2, GFP_KERNEL)?;
/// assert_eq!(&v, &[1, 2]);
/// # Ok::<(), Error>(())
/// ```
fn push(&mut self, v: T, flags: Flags) -> Result<(), AllocError>;
/// Pushes clones of the elements of slice into the [`Vec`] instance. /// Pushes clones of the elements of slice into the [`Vec`] instance.
fn try_extend_from_slice(&mut self, other: &[T]) -> Result<(), TryReserveError> ///
/// # Examples
///
/// ```
/// let mut v = Vec::new();
/// v.push(1, GFP_KERNEL)?;
///
/// v.extend_from_slice(&[20, 30, 40], GFP_KERNEL)?;
/// assert_eq!(&v, &[1, 20, 30, 40]);
///
/// v.extend_from_slice(&[50, 60], GFP_KERNEL)?;
/// assert_eq!(&v, &[1, 20, 30, 40, 50, 60]);
/// # Ok::<(), Error>(())
/// ```
fn extend_from_slice(&mut self, other: &[T], flags: Flags) -> Result<(), AllocError>
where where
T: Clone; T: Clone;
/// Ensures that the capacity exceeds the length by at least `additional` elements.
///
/// # Examples
///
/// ```
/// let mut v = Vec::new();
/// v.push(1, GFP_KERNEL)?;
///
/// v.reserve(10, GFP_KERNEL)?;
/// let cap = v.capacity();
/// assert!(cap >= 10);
///
/// v.reserve(10, GFP_KERNEL)?;
/// let new_cap = v.capacity();
/// assert_eq!(new_cap, cap);
///
/// # Ok::<(), Error>(())
/// ```
fn reserve(&mut self, additional: usize, flags: Flags) -> Result<(), AllocError>;
} }
impl<T> VecExt<T> for Vec<T> { impl<T> VecExt<T> for Vec<T> {
fn try_with_capacity(capacity: usize) -> Result<Self, TryReserveError> { fn with_capacity(capacity: usize, flags: Flags) -> Result<Self, AllocError> {
let mut v = Vec::new(); let mut v = Vec::new();
v.try_reserve(capacity)?; <Self as VecExt<_>>::reserve(&mut v, capacity, flags)?;
Ok(v) Ok(v)
} }
fn try_push(&mut self, v: T) -> Result<(), TryReserveError> { fn push(&mut self, v: T, flags: Flags) -> Result<(), AllocError> {
if let Err(retry) = self.push_within_capacity(v) { <Self as VecExt<_>>::reserve(self, 1, flags)?;
self.try_reserve(1)?; let s = self.spare_capacity_mut();
let _ = self.push_within_capacity(retry); s[0].write(v);
}
// SAFETY: We just initialised the first spare entry, so it is safe to increase the length
// by 1. We also know that the new length is <= capacity because of the previous call to
// `reserve` above.
unsafe { self.set_len(self.len() + 1) };
Ok(()) Ok(())
} }
fn try_extend_from_slice(&mut self, other: &[T]) -> Result<(), TryReserveError> fn extend_from_slice(&mut self, other: &[T], flags: Flags) -> Result<(), AllocError>
where where
T: Clone, T: Clone,
{ {
self.try_reserve(other.len())?; <Self as VecExt<_>>::reserve(self, other.len(), flags)?;
for item in other { for (slot, item) in core::iter::zip(self.spare_capacity_mut(), other) {
self.try_push(item.clone())?; slot.write(item.clone());
} }
// SAFETY: We just initialised the `other.len()` spare entries, so it is safe to increase
// the length by the same amount. We also know that the new length is <= capacity because
// of the previous call to `reserve` above.
unsafe { self.set_len(self.len() + other.len()) };
Ok(()) Ok(())
} }
#[cfg(any(test, testlib))]
fn reserve(&mut self, additional: usize, _flags: Flags) -> Result<(), AllocError> {
Vec::reserve(self, additional);
Ok(())
}
#[cfg(not(any(test, testlib)))]
fn reserve(&mut self, additional: usize, flags: Flags) -> Result<(), AllocError> {
let len = self.len();
let cap = self.capacity();
if cap - len >= additional {
return Ok(());
}
if core::mem::size_of::<T>() == 0 {
// The capacity is already `usize::MAX` for SZTs, we can't go higher.
return Err(AllocError);
}
// We know cap is <= `isize::MAX` because `Layout::array` fails if the resulting byte size
// is greater than `isize::MAX`. So the multiplication by two won't overflow.
let new_cap = core::cmp::max(cap * 2, len.checked_add(additional).ok_or(AllocError)?);
let layout = core::alloc::Layout::array::<T>(new_cap).map_err(|_| AllocError)?;
let (ptr, len, cap) = destructure(self);
// SAFETY: `ptr` is valid because it's either NULL or comes from a previous call to
// `krealloc_aligned`. We also verified that the type is not a ZST.
let new_ptr = unsafe { super::allocator::krealloc_aligned(ptr.cast(), layout, flags) };
if new_ptr.is_null() {
// SAFETY: We are just rebuilding the existing `Vec` with no changes.
unsafe { rebuild(self, ptr, len, cap) };
Err(AllocError)
} else {
// SAFETY: `ptr` has been reallocated with the layout for `new_cap` elements. New cap
// is greater than `cap`, so it continues to be >= `len`.
unsafe { rebuild(self, new_ptr.cast::<T>(), len, new_cap) };
Ok(())
}
}
}
#[cfg(not(any(test, testlib)))]
fn destructure<T>(v: &mut Vec<T>) -> (*mut T, usize, usize) {
let mut tmp = Vec::new();
core::mem::swap(&mut tmp, v);
let mut tmp = core::mem::ManuallyDrop::new(tmp);
let len = tmp.len();
let cap = tmp.capacity();
(tmp.as_mut_ptr(), len, cap)
}
/// Rebuilds a `Vec` from a pointer, length, and capacity.
///
/// # Safety
///
/// The same as [`Vec::from_raw_parts`].
#[cfg(not(any(test, testlib)))]
unsafe fn rebuild<T>(v: &mut Vec<T>, ptr: *mut T, len: usize, cap: usize) {
// SAFETY: The safety requirements from this function satisfy those of `from_raw_parts`.
let mut tmp = unsafe { Vec::from_raw_parts(ptr, len, cap) };
core::mem::swap(&mut tmp, v);
} }

View File

@ -6,10 +6,7 @@
use crate::str::CStr; use crate::str::CStr;
use alloc::{ use alloc::alloc::{AllocError, LayoutError};
alloc::{AllocError, LayoutError},
collections::TryReserveError,
};
use core::convert::From; use core::convert::From;
use core::fmt; use core::fmt;
@ -192,12 +189,6 @@ impl From<Utf8Error> for Error {
} }
} }
impl From<TryReserveError> for Error {
fn from(_: TryReserveError) -> Error {
code::ENOMEM
}
}
impl From<LayoutError> for Error { impl From<LayoutError> for Error {
fn from(_: LayoutError) -> Error { fn from(_: LayoutError) -> Error {
code::ENOMEM code::ENOMEM

View File

@ -18,7 +18,6 @@
#![feature(new_uninit)] #![feature(new_uninit)]
#![feature(receiver_trait)] #![feature(receiver_trait)]
#![feature(unsize)] #![feature(unsize)]
#![feature(vec_push_within_capacity)]
// Ensure conditional compilation based on the kernel configuration works; // Ensure conditional compilation based on the kernel configuration works;
// otherwise we may silently break things like initcall handling. // otherwise we may silently break things like initcall handling.

View File

@ -2,7 +2,7 @@
//! String representations. //! String representations.
use crate::alloc::vec_ext::VecExt; use crate::alloc::{flags::*, vec_ext::VecExt};
use alloc::alloc::AllocError; use alloc::alloc::AllocError;
use alloc::vec::Vec; use alloc::vec::Vec;
use core::fmt::{self, Write}; use core::fmt::{self, Write};
@ -807,7 +807,7 @@ impl CString {
let size = f.bytes_written(); let size = f.bytes_written();
// Allocate a vector with the required number of bytes, and write to it. // Allocate a vector with the required number of bytes, and write to it.
let mut buf = Vec::try_with_capacity(size)?; let mut buf = <Vec<_> as VecExt<_>>::with_capacity(size, GFP_KERNEL)?;
// SAFETY: The buffer stored in `buf` is at least of size `size` and is valid for writes. // SAFETY: The buffer stored in `buf` is at least of size `size` and is valid for writes.
let mut f = unsafe { Formatter::from_buffer(buf.as_mut_ptr(), size) }; let mut f = unsafe { Formatter::from_buffer(buf.as_mut_ptr(), size) };
f.write_fmt(args)?; f.write_fmt(args)?;
@ -856,7 +856,7 @@ impl<'a> TryFrom<&'a CStr> for CString {
fn try_from(cstr: &'a CStr) -> Result<CString, AllocError> { fn try_from(cstr: &'a CStr) -> Result<CString, AllocError> {
let mut buf = Vec::new(); let mut buf = Vec::new();
buf.try_extend_from_slice(cstr.as_bytes_with_nul()) <Vec<_> as VecExt<_>>::extend_from_slice(&mut buf, cstr.as_bytes_with_nul(), GFP_KERNEL)
.map_err(|_| AllocError)?; .map_err(|_| AllocError)?;
// INVARIANT: The `CStr` and `CString` types have the same invariants for // INVARIANT: The `CStr` and `CString` types have the same invariants for

View File

@ -157,11 +157,11 @@ impl ForeignOwnable for () {
/// let mut vec = /// let mut vec =
/// ScopeGuard::new_with_data(Vec::new(), |v| pr_info!("vec had {} elements\n", v.len())); /// ScopeGuard::new_with_data(Vec::new(), |v| pr_info!("vec had {} elements\n", v.len()));
/// ///
/// vec.try_push(10u8)?; /// vec.push(10u8, GFP_KERNEL)?;
/// if arg { /// if arg {
/// return Ok(()); /// return Ok(());
/// } /// }
/// vec.try_push(20u8)?; /// vec.push(20u8, GFP_KERNEL)?;
/// Ok(()) /// Ok(())
/// } /// }
/// ///

View File

@ -22,9 +22,9 @@ impl kernel::Module for RustMinimal {
pr_info!("Am I built-in? {}\n", !cfg!(MODULE)); pr_info!("Am I built-in? {}\n", !cfg!(MODULE));
let mut numbers = Vec::new(); let mut numbers = Vec::new();
numbers.try_push(72)?; numbers.push(72, GFP_KERNEL)?;
numbers.try_push(108)?; numbers.push(108, GFP_KERNEL)?;
numbers.try_push(200)?; numbers.push(200, GFP_KERNEL)?;
Ok(RustMinimal { numbers }) Ok(RustMinimal { numbers })
} }