Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check bounds of config space for MMIO transport #167

Merged
merged 5 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/aarch64/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ extern "C" fn main(x0: u64, x1: u64, x2: u64, x3: u64) {
debug!("Found VirtIO MMIO device at {:?}", region);

let header = NonNull::new(region.starting_address as *mut VirtIOHeader).unwrap();
match unsafe { MmioTransport::new(header) } {
match unsafe { MmioTransport::new(header, region.size.unwrap()) } {
Err(e) => warn!("Error creating VirtIO MMIO transport: {}", e),
Ok(transport) => {
info!(
Expand Down
2 changes: 1 addition & 1 deletion examples/riscv/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ fn virtio_probe(node: FdtNode) {
node.compatible().map(Compatible::first),
);
let header = NonNull::new(vaddr as *mut VirtIOHeader).unwrap();
match unsafe { MmioTransport::new(header) } {
match unsafe { MmioTransport::new(header, size) } {
Err(e) => warn!("Error creating VirtIO MMIO transport: {}", e),
Ok(transport) => {
info!(
Expand Down
5 changes: 4 additions & 1 deletion examples/x86_64/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ default = ["tcp"]
[dependencies]
log = "0.4.17"
spin = "0.9"
x86_64 = "0.14"
x86_64 = { version = "0.14.12", default-features = false, features = [
"instructions",
"abi_x86_interrupt",
] }
uart_16550 = "0.2"
linked_list_allocator = "0.10"
lazy_static = { version = "1.4.0", features = ["spin_no_std"] }
Expand Down
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
//! use core::ptr::NonNull;
//! use virtio_drivers::transport::mmio::{MmioTransport, VirtIOHeader};
//!
//! # fn example(mmio_device_address: usize) {
//! # fn example(mmio_device_address: usize, mmio_size: usize) {
//! let header = NonNull::new(mmio_device_address as *mut VirtIOHeader).unwrap();
//! let transport = unsafe { MmioTransport::new(header) }.unwrap();
//! let transport = unsafe { MmioTransport::new(header, mmio_size) }.unwrap();
//! # }
//! ```
//!
Expand Down
24 changes: 18 additions & 6 deletions src/queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
pub fn new<T: Transport>(
transport: &mut T,
idx: u16,
indirect: bool,

Check warning on line 80 in src/queue.rs

View workflow job for this annotation

GitHub Actions / build

unused variable: `indirect`

Check warning on line 80 in src/queue.rs

View workflow job for this annotation

GitHub Actions / build

unused variable: `indirect`
event_idx: bool,
) -> Result<Self> {
#[allow(clippy::let_unit_value)]
Expand Down Expand Up @@ -995,7 +995,9 @@
#[test]
fn queue_too_big() {
let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
let mut transport =
unsafe { MmioTransport::new(NonNull::from(&mut header), size_of::<VirtIOHeader>()) }
.unwrap();
assert_eq!(
VirtQueue::<FakeHal, 8>::new(&mut transport, 0, false, false).unwrap_err(),
Error::InvalidParam
Expand All @@ -1005,7 +1007,9 @@
#[test]
fn queue_already_used() {
let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
let mut transport =
unsafe { MmioTransport::new(NonNull::from(&mut header), size_of::<VirtIOHeader>()) }
.unwrap();
VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
assert_eq!(
VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap_err(),
Expand All @@ -1016,7 +1020,9 @@
#[test]
fn add_empty() {
let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
let mut transport =
unsafe { MmioTransport::new(NonNull::from(&mut header), size_of::<VirtIOHeader>()) }
.unwrap();
let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
assert_eq!(
unsafe { queue.add(&[], &mut []) }.unwrap_err(),
Expand All @@ -1027,7 +1033,9 @@
#[test]
fn add_too_many() {
let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
let mut transport =
unsafe { MmioTransport::new(NonNull::from(&mut header), size_of::<VirtIOHeader>()) }
.unwrap();
let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
assert_eq!(queue.available_desc(), 4);
assert_eq!(
Expand All @@ -1039,7 +1047,9 @@
#[test]
fn add_buffers() {
let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
let mut transport =
unsafe { MmioTransport::new(NonNull::from(&mut header), size_of::<VirtIOHeader>()) }
.unwrap();
let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
assert_eq!(queue.available_desc(), 4);

Expand Down Expand Up @@ -1102,7 +1112,9 @@
use core::ptr::slice_from_raw_parts;

let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
let mut transport =
unsafe { MmioTransport::new(NonNull::from(&mut header), size_of::<VirtIOHeader>()) }
.unwrap();
let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, true, false).unwrap();
assert_eq!(queue.available_desc(), 4);

Expand Down
69 changes: 47 additions & 22 deletions src/transport/mmio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use core::{
mem::{align_of, size_of},
ptr::NonNull,
};
use zerocopy::{FromBytes, Immutable, IntoBytes};

const MAGIC_VALUE: u32 = 0x7472_6976;
pub(crate) const LEGACY_VERSION: u32 = 1;
Expand Down Expand Up @@ -61,6 +62,9 @@ pub enum MmioError {
/// The header reports a device ID of 0.
#[error("Device ID was zero")]
ZeroDeviceId,
/// The MMIO region size was smaller than the header size we expect.
#[error("MMIO region too small")]
MmioRegionTooSmall,
}

/// MMIO Device Register Interface, both legacy and modern.
Expand Down Expand Up @@ -263,6 +267,8 @@ impl VirtIOHeader {
pub struct MmioTransport {
header: NonNull<VirtIOHeader>,
version: MmioVersion,
/// The size in bytes of the config space.
config_space_size: usize,
}

impl MmioTransport {
Expand All @@ -272,16 +278,23 @@ impl MmioTransport {
/// # Safety
/// `header` must point to a properly aligned valid VirtIO MMIO region, which must remain valid
/// for the lifetime of the transport that is returned.
pub unsafe fn new(header: NonNull<VirtIOHeader>) -> Result<Self, MmioError> {
pub unsafe fn new(header: NonNull<VirtIOHeader>, mmio_size: usize) -> Result<Self, MmioError> {
let magic = volread!(header, magic);
if magic != MAGIC_VALUE {
return Err(MmioError::BadMagic(magic));
}
if volread!(header, device_id) == 0 {
return Err(MmioError::ZeroDeviceId);
}
let Some(config_space_size) = mmio_size.checked_sub(CONFIG_SPACE_OFFSET) else {
return Err(MmioError::MmioRegionTooSmall);
};
let version = volread!(header, version).try_into()?;
Ok(Self { header, version })
Ok(Self {
header,
version,
config_space_size,
})
}

/// Gets the version of the VirtIO MMIO transport.
Expand Down Expand Up @@ -484,40 +497,52 @@ impl Transport for MmioTransport {
}
}

fn read_config_space<T>(&self, offset: usize) -> Result<T, Error> {
fn read_config_space<T: FromBytes>(&self, offset: usize) -> Result<T, Error> {
assert!(align_of::<T>() <= 4,
"Driver expected config space alignment of {} bytes, but VirtIO only guarantees 4 byte alignment.",
align_of::<T>());
assert!(offset % align_of::<T>() == 0);

// SAFETY: The caller of `MmioTransport::new` guaranteed that the header pointer was valid,
// which includes the config space.
unsafe {
Ok(self
.header
.cast::<T>()
.byte_add(CONFIG_SPACE_OFFSET)
.byte_add(offset)
.read_volatile())
if self.config_space_size < offset + size_of::<T>() {
Err(Error::ConfigSpaceTooSmall)
} else {
// SAFETY: The caller of `MmioTransport::new` guaranteed that the header pointer was valid,
// which includes the config space.
unsafe {
Ok(self
.header
.cast::<T>()
.byte_add(CONFIG_SPACE_OFFSET)
.byte_add(offset)
.read_volatile())
}
}
}

fn write_config_space<T>(&mut self, offset: usize, value: T) -> Result<(), Error> {
fn write_config_space<T: IntoBytes + Immutable>(
&mut self,
offset: usize,
value: T,
) -> Result<(), Error> {
assert!(align_of::<T>() <= 4,
"Driver expected config space alignment of {} bytes, but VirtIO only guarantees 4 byte alignment.",
align_of::<T>());
assert!(offset % align_of::<T>() == 0);

// SAFETY: The caller of `MmioTransport::new` guaranteed that the header pointer was valid,
// which includes the config space.
unsafe {
self.header
.cast::<T>()
.byte_add(CONFIG_SPACE_OFFSET)
.byte_add(offset)
.write_volatile(value);
if self.config_space_size < offset + size_of::<T>() {
Err(Error::ConfigSpaceTooSmall)
} else {
// SAFETY: The caller of `MmioTransport::new` guaranteed that the header pointer was valid,
// which includes the config space.
unsafe {
self.header
.cast::<T>()
.byte_add(CONFIG_SPACE_OFFSET)
.byte_add(offset)
.write_volatile(value);
}
Ok(())
}
Ok(())
}
}

Expand Down
8 changes: 6 additions & 2 deletions src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use bitflags::{bitflags, Flags};
use core::{fmt::Debug, ops::BitAnd};
use log::debug;
pub use some::SomeTransport;
use zerocopy::{FromBytes, IntoBytes};
use zerocopy::{FromBytes, Immutable, IntoBytes};

/// A VirtIO transport layer.
pub trait Transport {
Expand Down Expand Up @@ -105,7 +105,11 @@ pub trait Transport {
fn read_config_space<T: FromBytes>(&self, offset: usize) -> Result<T>;

/// Writes a value to the device config space.
fn write_config_space<T: IntoBytes>(&mut self, offset: usize, value: T) -> Result<()>;
fn write_config_space<T: IntoBytes + Immutable>(
&mut self,
offset: usize,
value: T,
) -> Result<()>;
}

bitflags! {
Expand Down
15 changes: 9 additions & 6 deletions src/transport/pci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use core::{
mem::{align_of, size_of},
ptr::{addr_of_mut, NonNull},
};
use zerocopy::{FromBytes, Immutable, IntoBytes};

/// The PCI vendor ID for VirtIO devices.
const VIRTIO_VENDOR_ID: u16 = 0x1af4;
Expand Down Expand Up @@ -325,7 +326,7 @@ impl Transport for PciTransport {
isr_status & 0x3 != 0
}

fn read_config_space<T>(&self, offset: usize) -> Result<T, Error> {
fn read_config_space<T: FromBytes>(&self, offset: usize) -> Result<T, Error> {
assert!(align_of::<T>() <= 4,
"Driver expected config space alignment of {} bytes, but VirtIO only guarantees 4 byte alignment.",
align_of::<T>());
Expand All @@ -338,15 +339,18 @@ impl Transport for PciTransport {
// SAFETY: If we have a config space pointer it must be valid for its length, and we just
// checked that the offset and size of the access was within the length.
unsafe {
// TODO: Use NonNull::as_non_null_ptr once it is stable.
Ok((config_space.as_ptr() as *mut T)
Ok((config_space.as_ptr().cast::<T>())
.byte_add(offset)
.read_volatile())
}
}
}

fn write_config_space<T>(&mut self, offset: usize, value: T) -> Result<(), Error> {
fn write_config_space<T: IntoBytes + Immutable>(
&mut self,
offset: usize,
value: T,
) -> Result<(), Error> {
assert!(align_of::<T>() <= 4,
"Driver expected config space alignment of {} bytes, but VirtIO only guarantees 4 byte alignment.",
align_of::<T>());
Expand All @@ -359,8 +363,7 @@ impl Transport for PciTransport {
// SAFETY: If we have a config space pointer it must be valid for its length, and we just
// checked that the offset and size of the access was within the length.
unsafe {
// TODO: Use NonNull::as_non_null_ptr once it is stable.
(config_space.as_ptr() as *mut T)
(config_space.as_ptr().cast::<T>())
.byte_add(offset)
.write_volatile(value);
}
Expand Down
8 changes: 6 additions & 2 deletions src/transport/some.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use zerocopy::{FromBytes, IntoBytes};
use zerocopy::{FromBytes, Immutable, IntoBytes};

use super::{mmio::MmioTransport, pci::PciTransport, DeviceStatus, DeviceType, Transport};
use crate::{PhysAddr, Result};
Expand Down Expand Up @@ -130,7 +130,11 @@ impl Transport for SomeTransport {
}
}

fn write_config_space<T: IntoBytes>(&mut self, offset: usize, value: T) -> Result<()> {
fn write_config_space<T: IntoBytes + Immutable>(
&mut self,
offset: usize,
value: T,
) -> Result<()> {
match self {
Self::Mmio(mmio) => mmio.write_config_space(offset, value),
Self::Pci(pci) => pci.write_config_space(offset, value),
Expand Down
Loading