diff --git a/Cargo.toml b/Cargo.toml index c56b81d9e..438083a40 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -170,6 +170,9 @@ termios = [] # Enable `rustix::mm::*`. mm = [] +# Enable `rustix::numa::*`. +numa = [] + # Enable `rustix::pipe::*`. pipe = [] @@ -194,6 +197,7 @@ all-apis = [ "mm", "mount", "net", + "numa", "param", "pipe", "process", diff --git a/src/backend/linux_raw/conv.rs b/src/backend/linux_raw/conv.rs index 89a174c46..a6b0885dc 100644 --- a/src/backend/linux_raw/conv.rs +++ b/src/backend/linux_raw/conv.rs @@ -818,6 +818,22 @@ impl<'a, Num: ArgNumber> From> for ArgReg<'a, Num> } } +#[cfg(feature = "numa")] +impl<'a, Num: ArgNumber> From for ArgReg<'a, Num> { + #[inline] + fn from(flags: crate::numa::Mode) -> Self { + c_uint(flags.bits()) + } +} + +#[cfg(feature = "numa")] +impl<'a, Num: ArgNumber> From for ArgReg<'a, Num> { + #[inline] + fn from(flags: crate::numa::ModeFlags) -> Self { + c_uint(flags.bits()) + } +} + impl<'a, Num: ArgNumber, T> From<&'a mut MaybeUninit> for ArgReg<'a, Num> { #[inline] fn from(t: &'a mut MaybeUninit) -> Self { diff --git a/src/backend/linux_raw/mod.rs b/src/backend/linux_raw/mod.rs index 0d4e5332d..ad55af5cb 100644 --- a/src/backend/linux_raw/mod.rs +++ b/src/backend/linux_raw/mod.rs @@ -51,6 +51,8 @@ pub(crate) mod mount; pub(crate) mod mount; // for deprecated mount functions in "fs" #[cfg(feature = "net")] pub(crate) mod net; +#[cfg(feature = "numa")] +pub(crate) mod numa; #[cfg(any( feature = "param", feature = "process", diff --git a/src/backend/linux_raw/numa/mod.rs b/src/backend/linux_raw/numa/mod.rs new file mode 100644 index 000000000..1e0181a99 --- /dev/null +++ b/src/backend/linux_raw/numa/mod.rs @@ -0,0 +1,2 @@ +pub(crate) mod syscalls; +pub(crate) mod types; diff --git a/src/backend/linux_raw/numa/syscalls.rs b/src/backend/linux_raw/numa/syscalls.rs new file mode 100644 index 000000000..d6644ccee --- /dev/null +++ b/src/backend/linux_raw/numa/syscalls.rs @@ -0,0 +1,92 @@ +//! linux_raw syscalls supporting `rustix::numa`. +//! +//! # Safety +//! +//! See the `rustix::backend` module documentation for details. + +#![allow(unsafe_code)] +#![allow(clippy::undocumented_unsafe_blocks)] + +use super::types::{Mode, ModeFlags}; + +use crate::backend::c; +use crate::backend::conv::{c_uint, pass_usize, ret, zero}; +use crate::io; +use core::mem::MaybeUninit; + +/// # Safety +/// +/// `mbind` is primarily unsafe due to the `addr` parameter, as anything +/// working with memory pointed to by raw pointers is unsafe. +#[inline] +pub(crate) unsafe fn mbind( + addr: *mut c::c_void, + length: usize, + mode: Mode, + nodemask: &[u64], + flags: ModeFlags, +) -> io::Result<()> { + ret(syscall!( + __NR_mbind, + addr, + pass_usize(length), + mode, + nodemask.as_ptr(), + pass_usize(nodemask.len() * u64::BITS as usize), + flags + )) +} + +/// # Safety +/// +/// `set_mempolicy` is primarily unsafe due to the `addr` parameter, +/// as anything working with memory pointed to by raw pointers is +/// unsafe. +#[inline] +pub(crate) unsafe fn set_mempolicy(mode: Mode, nodemask: &[u64]) -> io::Result<()> { + ret(syscall!( + __NR_set_mempolicy, + mode, + nodemask.as_ptr(), + pass_usize(nodemask.len() * u64::BITS as usize) + )) +} + +/// # Safety +/// +/// `get_mempolicy` is primarily unsafe due to the `addr` parameter, +/// as anything working with memory pointed to by raw pointers is +/// unsafe. +#[inline] +pub(crate) unsafe fn get_mempolicy_node(addr: *mut c::c_void) -> io::Result { + let mut mode = MaybeUninit::::uninit(); + + ret(syscall!( + __NR_get_mempolicy, + &mut mode, + zero(), + zero(), + addr, + c_uint(linux_raw_sys::general::MPOL_F_NODE | linux_raw_sys::general::MPOL_F_ADDR) + ))?; + + Ok(mode.assume_init()) +} + +#[inline] +pub(crate) fn get_mempolicy_next_node() -> io::Result { + let mut mode = MaybeUninit::::uninit(); + + unsafe { + ret(syscall!( + __NR_get_mempolicy, + &mut mode, + zero(), + zero(), + zero(), + c_uint(linux_raw_sys::general::MPOL_F_NODE) + ))?; + + Ok(mode.assume_init()) + } +} diff --git a/src/backend/linux_raw/numa/types.rs b/src/backend/linux_raw/numa/types.rs new file mode 100644 index 000000000..88991c7d5 --- /dev/null +++ b/src/backend/linux_raw/numa/types.rs @@ -0,0 +1,52 @@ +use bitflags::bitflags; + +bitflags! { + /// `MPOL_*` and `MPOL_F_*` flags for use with [`mbind`]. + /// + /// [`mbind`]: crate::io::mbind + #[repr(transparent)] + #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] + pub struct Mode: u32 { + /// `MPOL_F_STATIC_NODES` + const STATIC_NODES = linux_raw_sys::general::MPOL_F_STATIC_NODES; + /// `MPOL_F_RELATIVE_NODES` + const RELATIVE_NODES = linux_raw_sys::general::MPOL_F_RELATIVE_NODES; + /// `MPOL_F_NUMA_BALANCING` + const NUMA_BALANCING = linux_raw_sys::general::MPOL_F_NUMA_BALANCING; + + /// `MPOL_DEFAULT` + const DEFAULT = linux_raw_sys::general::MPOL_DEFAULT as u32; + /// `MPOL_PREFERRED` + const PREFERRED = linux_raw_sys::general::MPOL_PREFERRED as u32; + /// `MPOL_BIND` + const BIND = linux_raw_sys::general::MPOL_BIND as u32; + /// `MPOL_INTERLEAVE` + const INTERLEAVE = linux_raw_sys::general::MPOL_INTERLEAVE as u32; + /// `MPOL_LOCAL` + const LOCAL = linux_raw_sys::general::MPOL_LOCAL as u32; + /// `MPOL_PREFERRED_MANY` + const PREFERRED_MANY = linux_raw_sys::general::MPOL_PREFERRED_MANY as u32; + + /// + const _ = !0; + } +} + +bitflags! { + /// `MPOL_MF_*` flags for use with [`mbind`]. + /// + /// [`mbind`]: crate::io::mbind + #[repr(transparent)] + #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] + pub struct ModeFlags: u32 { + /// `MPOL_MF_STRICT` + const STRICT = linux_raw_sys::general::MPOL_MF_STRICT; + /// `MPOL_MF_MOVE` + const MOVE = linux_raw_sys::general::MPOL_MF_MOVE; + /// `MPOL_MF_MOVE_ALL` + const MOVE_ALL = linux_raw_sys::general::MPOL_MF_MOVE_ALL; + + /// + const _ = !0; + } +} diff --git a/src/lib.rs b/src/lib.rs index 170e3f958..66da372ea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -218,6 +218,10 @@ pub mod mount; #[cfg(feature = "net")] #[cfg_attr(doc_cfg, doc(cfg(feature = "net")))] pub mod net; +#[cfg(linux_kernel)] +#[cfg(feature = "numa")] +#[cfg_attr(doc_cfg, doc(cfg(feature = "numa")))] +pub mod numa; #[cfg(not(any(windows, target_os = "espidf")))] #[cfg(feature = "param")] #[cfg_attr(doc_cfg, doc(cfg(feature = "param")))] diff --git a/src/numa/mod.rs b/src/numa/mod.rs new file mode 100644 index 000000000..92e928733 --- /dev/null +++ b/src/numa/mod.rs @@ -0,0 +1,108 @@ +//! The `numa` API. +//! +//! # Safety +//! +//! `mbind` and related functions manipulate raw pointers and have special +//! semantics and are wildly unsafe. +#![allow(unsafe_code)] + +use crate::{backend, io}; +use core::ffi::c_void; + +pub use backend::numa::types::{Mode, ModeFlags}; + +/// `mbind(addr, len, mode, nodemask)`-Set memory policy for a memory range. +/// +/// # Safety +/// +/// This function operates on raw pointers, but it should only be used +/// on memory which the caller owns. +/// +/// # References +/// - [Linux] +/// +/// [Linux]: https://man7.org/linux/man-pages/man2/mbind.2.html +#[cfg(linux_kernel)] +#[inline] +pub unsafe fn mbind( + addr: *mut c_void, + len: usize, + mode: Mode, + nodemask: &[u64], + flags: ModeFlags, +) -> io::Result<()> { + backend::numa::syscalls::mbind(addr, len, mode, nodemask, flags) +} + +/// `set_mempolicy(mode, nodemask)`-Set default NUMA memory policy for +/// a thread and its children. +/// +/// # Safety +/// +/// This function operates on raw pointers, but it should only be used +/// on memory which the caller owns. +/// +/// # References +/// - [Linux] +/// +/// [Linux]: https://man7.org/linux/man-pages/man2/set_mempolicy.2.html +#[cfg(linux_kernel)] +#[inline] +pub unsafe fn set_mempolicy(mode: Mode, nodemask: &[u64]) -> io::Result<()> { + backend::numa::syscalls::set_mempolicy(mode, nodemask) +} + +/// `get_mempolicy_node(addr)`-Return the node ID of the node on which +/// the address addr is allocated. +/// +/// If flags specifies both MPOL_F_NODE and MPOL_F_ADDR, +/// get_mempolicy() will return the node ID of the node on which the +/// address addr is allocated into the location pointed to by mode. +/// If no page has yet been allocated for the specified address, +/// get_mempolicy() will allocate a page as if the thread had +/// performed a read (load) access to that address, and return the ID +/// of the node where that page was allocated. +/// +/// # Safety +/// +/// This function operates on raw pointers, but it should only be used +/// on memory which the caller owns. +/// +/// # References +/// - [Linux] +/// +/// [Linux]: https://man7.org/linux/man-pages/man2/get_mempolicy.2.html +#[cfg(linux_kernel)] +#[inline] +pub unsafe fn get_mempolicy_node(addr: *mut c_void) -> io::Result { + backend::numa::syscalls::get_mempolicy_node(addr) +} + +/// `get_mempolicy_next_node(addr)`-Return node ID of the next node +/// that will be used for interleaving of internal kernel pages +/// allocated on behalf of the thread. +/// +/// If flags specifies MPOL_F_NODE, but not MPOL_F_ADDR, and the +/// thread's current policy is MPOL_INTERLEAVE, then get_mempolicy() +/// will return in the location pointed to by a non-NULL mode +/// argument, the node ID of the next node that will be used for +/// interleaving of internal kernel pages allocated on behalf of the +/// thread. These allocations include pages for memory-mapped files +/// in process memory ranges mapped using the mmap(2) call with the +/// MAP_PRIVATE flag for read accesses, and in memory ranges mapped +/// with the MAP_SHARED flag for all accesses. +/// +/// # Safety +/// +/// This function operates on raw pointers, but it should only be used +/// on memory which the caller owns. +/// +/// # References +/// - [Linux] +/// +/// [Linux]: https://man7.org/linux/man-pages/man2/get_mempolicy.2.html +#[cfg(linux_kernel)] +#[inline] +pub unsafe fn get_mempolicy_next_node() -> io::Result { + backend::numa::syscalls::get_mempolicy_next_node() +} diff --git a/tests/numa/main.rs b/tests/numa/main.rs new file mode 100644 index 000000000..4470c31ba --- /dev/null +++ b/tests/numa/main.rs @@ -0,0 +1,40 @@ +#[cfg(all(feature = "mm", feature = "fs"))] +#[test] +fn test_mbind() { + let size = 8192; + + unsafe { + let vaddr = rustix::mm::mmap_anonymous( + std::ptr::null_mut(), + size, + rustix::mm::ProtFlags::READ | rustix::mm::ProtFlags::WRITE, + rustix::mm::MapFlags::PRIVATE, + ) + .unwrap(); + + vaddr.cast::().write(100); + + let mask = &[1]; + rustix::numa::mbind( + vaddr, + size, + rustix::numa::Mode::BIND | rustix::numa::Mode::STATIC_NODES, + mask, + rustix::numa::ModeFlags::empty(), + ) + .unwrap(); + + rustix::numa::get_mempolicy_node(vaddr).unwrap(); + + match rustix::numa::get_mempolicy_next_node() { + Err(rustix::io::Errno::INVAL) => (), + _ => panic!( + "rustix::numa::get_mempolicy_next_node() should return EINVAL for MPOL_DEFAULT" + ), + } + + rustix::numa::set_mempolicy(rustix::numa::Mode::INTERLEAVE, mask).unwrap(); + + rustix::numa::get_mempolicy_next_node().unwrap(); + } +}