Skip to content

Commit

Permalink
Fix get_ray_tracing_pipeline_shader_group_handles scoping
Browse files Browse the repository at this point in the history
  • Loading branch information
marc0246 committed Jan 28, 2025
1 parent a648f7a commit 2aed2f5
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 125 deletions.
118 changes: 0 additions & 118 deletions vulkano/src/device/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ use crate::{
instance::{Instance, InstanceOwned, InstanceOwnedDebugWrapper},
macros::{impl_id_counter, vulkan_bitflags},
memory::{ExternalMemoryHandleType, MemoryFdProperties, MemoryRequirements},
pipeline::ray_tracing::RayTracingPipeline,
Requires, RequiresAllOf, RequiresOneOf, Validated, ValidationError, Version, VulkanError,
VulkanObject,
};
Expand Down Expand Up @@ -1305,96 +1304,6 @@ impl Device {

Ok(())
}

pub fn ray_tracing_shader_group_handles(
&self,
ray_tracing_pipeline: &RayTracingPipeline,
first_group: u32,
group_count: u32,
) -> Result<ShaderGroupHandlesData, Validated<VulkanError>> {
self.validate_ray_tracing_pipeline_properties(
ray_tracing_pipeline,
first_group,
group_count,
)?;

Ok(unsafe {
self.ray_tracing_shader_group_handles_unchecked(
ray_tracing_pipeline,
first_group,
group_count,
)
}?)
}

fn validate_ray_tracing_pipeline_properties(
&self,
ray_tracing_pipeline: &RayTracingPipeline,
first_group: u32,
group_count: u32,
) -> Result<(), Box<ValidationError>> {
if !self.enabled_features().ray_tracing_pipeline
|| self
.physical_device()
.properties()
.shader_group_handle_size
.is_none()
{
Err(Box::new(ValidationError {
problem: "device property `shader_group_handle_size` is empty".into(),
requires_one_of: RequiresOneOf(&[RequiresAllOf(&[Requires::DeviceFeature(
"ray_tracing_pipeline",
)])]),
..Default::default()
}))?;
};

if (first_group + group_count) as usize > ray_tracing_pipeline.groups().len() {
Err(Box::new(ValidationError {
problem: "the sum of `first_group` and `group_count` must be less than or equal \
to the number of shader groups in the pipeline"
.into(),
vuids: &["VUID-vkGetRayTracingShaderGroupHandlesKHR-firstGroup-02419"],
..Default::default()
}))?
}

// TODO: VUID-vkGetRayTracingShaderGroupHandlesKHR-pipeline-07828

Ok(())
}

#[cfg_attr(not(feature = "document_unchecked"), doc(hidden))]
pub unsafe fn ray_tracing_shader_group_handles_unchecked(
&self,
ray_tracing_pipeline: &RayTracingPipeline,
first_group: u32,
group_count: u32,
) -> Result<ShaderGroupHandlesData, VulkanError> {
let handle_size = self
.physical_device()
.properties()
.shader_group_handle_size
.unwrap();

let mut data = vec![0u8; (handle_size * group_count) as usize];
let fns = self.fns();
unsafe {
(fns.khr_ray_tracing_pipeline
.get_ray_tracing_shader_group_handles_khr)(
self.handle,
ray_tracing_pipeline.handle(),
first_group,
group_count,
data.len(),
data.as_mut_ptr().cast(),
)
}
.result()
.map_err(VulkanError::from)?;

Ok(ShaderGroupHandlesData { data, handle_size })
}
}

impl Debug for Device {
Expand Down Expand Up @@ -2225,33 +2134,6 @@ impl<T> Deref for DeviceOwnedDebugWrapper<T> {
}
}

/// Holds the data returned by [`Device::ray_tracing_shader_group_handles`].
#[derive(Clone, Debug)]
pub struct ShaderGroupHandlesData {
data: Vec<u8>,
handle_size: u32,
}

impl ShaderGroupHandlesData {
#[inline]
pub fn data(&self) -> &[u8] {
&self.data
}

#[inline]
pub fn handle_size(&self) -> u32 {
self.handle_size
}
}

impl ShaderGroupHandlesData {
/// Returns an iterator over the handles in the data.
#[inline]
pub fn iter(&self) -> impl ExactSizeIterator<Item = &[u8]> {
self.data().chunks_exact(self.handle_size as usize)
}
}

#[cfg(test)]
mod tests {
use crate::device::{
Expand Down
102 changes: 95 additions & 7 deletions vulkano/src/pipeline/ray_tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,72 @@ impl RayTracingPipeline {
pub fn flags(&self) -> PipelineCreateFlags {
self.flags
}

/// Retrieves the opaque handles of shaders in the ray tracing pipeline.
///
/// Handles for `group_count` groups are retrieved, starting at `first_group`. The group
/// indices correspond to the [`groups`] the pipeline was created with.
pub fn group_handles(
&self,
first_group: u32,
group_count: u32,
) -> Result<ShaderGroupHandlesData, Validated<VulkanError>> {
self.validate_group_handles(first_group, group_count)?;

Ok(unsafe { self.group_handles_unchecked(first_group, group_count) }?)
}

fn validate_group_handles(
&self,
first_group: u32,
group_count: u32,
) -> Result<(), Box<ValidationError>> {
if (first_group + group_count) as usize > self.groups().len() {
Err(Box::new(ValidationError {
problem: "the sum of `first_group` and `group_count` must be less than or equal \
to the number of shader groups in the pipeline"
.into(),
vuids: &["VUID-vkGetRayTracingShaderGroupHandlesKHR-firstGroup-02419"],
..Default::default()
}))?
}

// TODO: VUID-vkGetRayTracingShaderGroupHandlesKHR-pipeline-07828

Ok(())
}

#[cfg_attr(not(feature = "document_unchecked"), doc(hidden))]
pub unsafe fn group_handles_unchecked(
&self,
first_group: u32,
group_count: u32,
) -> Result<ShaderGroupHandlesData, VulkanError> {
let handle_size = self
.device()
.physical_device()
.properties()
.shader_group_handle_size
.unwrap();

let mut data = vec![0u8; (handle_size * group_count) as usize];
let fns = self.device().fns();
unsafe {
(fns.khr_ray_tracing_pipeline
.get_ray_tracing_shader_group_handles_khr)(
self.device().handle(),
self.handle(),
first_group,
group_count,
data.len(),
data.as_mut_ptr().cast(),
)
}
.result()
.map_err(VulkanError::from)?;

Ok(ShaderGroupHandlesData { data, handle_size })
}
}

impl Pipeline for RayTracingPipeline {
Expand Down Expand Up @@ -817,6 +883,33 @@ pub(crate) struct RayTracingPipelineCreateInfoFields3Vk {
pub(crate) stages_fields2_vk: SmallVec<[PipelineShaderStageCreateInfoFields2Vk; 5]>,
}

/// Holds the data returned by [`RayTracingPipeline::group_handles`].
#[derive(Clone, Debug)]
pub struct ShaderGroupHandlesData {
data: Vec<u8>,
handle_size: u32,
}

impl ShaderGroupHandlesData {
/// Returns the opaque handle data as one blob.
#[inline]
pub fn data(&self) -> &[u8] {
&self.data
}

/// Returns the shader group handle size of the device.
#[inline]
pub fn handle_size(&self) -> u32 {
self.handle_size
}

/// Returns an iterator over the data of each group handle.
#[inline]
pub fn iter(&self) -> impl ExactSizeIterator<Item = &[u8]> {
self.data().chunks_exact(self.handle_size as usize)
}
}

/// An object that holds the strided addresses of the shader groups in a shader binding table.
#[derive(Debug, Clone)]
pub struct ShaderBindingTableAddresses {
Expand Down Expand Up @@ -870,13 +963,8 @@ impl ShaderBindingTable {
}
}

let handle_data = ray_tracing_pipeline
.device()
.ray_tracing_shader_group_handles(
ray_tracing_pipeline,
0,
ray_tracing_pipeline.groups().len() as u32,
)?;
let handle_data =
ray_tracing_pipeline.group_handles(0, ray_tracing_pipeline.groups().len() as u32)?;

let properties = ray_tracing_pipeline.device().physical_device().properties();
let handle_size_aligned = align_up(
Expand Down

0 comments on commit 2aed2f5

Please sign in to comment.