Skip to content

Commit

Permalink
Ray Tracing Pipeline: SBT fixes & refactor examples (#2617)
Browse files Browse the repository at this point in the history
* refactor copy SBT handles

* switch to perspective_rh & refactor build_top_level_acceleration_structure

* buffer alignment
  • Loading branch information
ComfyFluffy authored Jan 29, 2025
1 parent d31b578 commit 6ae10df
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 51 deletions.
18 changes: 9 additions & 9 deletions examples/ray-tracing-auto/scene.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,15 +179,18 @@ impl Scene {

let tlas = unsafe {
build_top_level_acceleration_structure(
blas.clone(),
vec![AccelerationStructureInstance {
acceleration_structure_reference: blas.device_address().into(),
..Default::default()
}],
memory_allocator.clone(),
command_buffer_allocator.clone(),
app.device.clone(),
app.queue.clone(),
)
};

let proj = Mat4::perspective_rh_gl(std::f32::consts::FRAC_PI_2, 4.0 / 3.0, 0.01, 100.0);
let proj = Mat4::perspective_rh(std::f32::consts::FRAC_PI_2, 4.0 / 3.0, 0.01, 100.0);
let view = Mat4::look_at_rh(
Vec3::new(0.0, 0.0, 1.0),
Vec3::new(0.0, 0.0, 0.0),
Expand Down Expand Up @@ -436,16 +439,13 @@ unsafe fn build_acceleration_structure_triangles(
}

unsafe fn build_top_level_acceleration_structure(
acceleration_structure: Arc<AccelerationStructure>,
as_instances: Vec<AccelerationStructureInstance>,
allocator: Arc<dyn MemoryAllocator>,
command_buffer_allocator: Arc<dyn CommandBufferAllocator>,
device: Arc<Device>,
queue: Arc<Queue>,
) -> Arc<AccelerationStructure> {
let as_instance = AccelerationStructureInstance {
acceleration_structure_reference: acceleration_structure.device_address().into(),
..Default::default()
};
let primitive_count = as_instances.len() as u32;

let instance_buffer = Buffer::from_iter(
allocator.clone(),
Expand All @@ -459,7 +459,7 @@ unsafe fn build_top_level_acceleration_structure(
| MemoryTypeFilter::HOST_SEQUENTIAL_WRITE,
..Default::default()
},
[as_instance],
as_instances,
)
.unwrap();

Expand All @@ -471,7 +471,7 @@ unsafe fn build_top_level_acceleration_structure(

build_acceleration_structure_common(
geometries,
1,
primitive_count,
AccelerationStructureType::TopLevel,
allocator,
command_buffer_allocator,
Expand Down
18 changes: 9 additions & 9 deletions examples/ray-tracing/scene.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,18 @@ impl SceneTask {

let tlas = unsafe {
build_top_level_acceleration_structure(
blas.clone(),
vec![AccelerationStructureInstance {
acceleration_structure_reference: blas.device_address().into(),
..Default::default()
}],
memory_allocator.clone(),
command_buffer_allocator.clone(),
app.device.clone(),
app.queue.clone(),
)
};

let proj = Mat4::perspective_rh_gl(std::f32::consts::FRAC_PI_2, 4.0 / 3.0, 0.01, 100.0);
let proj = Mat4::perspective_rh(std::f32::consts::FRAC_PI_2, 4.0 / 3.0, 0.01, 100.0);
let view = Mat4::look_at_rh(
Vec3::new(0.0, 0.0, 1.0),
Vec3::new(0.0, 0.0, 0.0),
Expand Down Expand Up @@ -454,16 +457,13 @@ unsafe fn build_acceleration_structure_triangles(
}

unsafe fn build_top_level_acceleration_structure(
acceleration_structure: Arc<AccelerationStructure>,
as_instances: Vec<AccelerationStructureInstance>,
allocator: Arc<dyn MemoryAllocator>,
command_buffer_allocator: Arc<dyn CommandBufferAllocator>,
device: Arc<Device>,
queue: Arc<Queue>,
) -> Arc<AccelerationStructure> {
let as_instance = AccelerationStructureInstance {
acceleration_structure_reference: acceleration_structure.device_address().into(),
..Default::default()
};
let primitive_count = as_instances.len() as u32;

let instance_buffer = Buffer::from_iter(
allocator.clone(),
Expand All @@ -477,7 +477,7 @@ unsafe fn build_top_level_acceleration_structure(
| MemoryTypeFilter::HOST_SEQUENTIAL_WRITE,
..Default::default()
},
[as_instance],
as_instances,
)
.unwrap();

Expand All @@ -489,7 +489,7 @@ unsafe fn build_top_level_acceleration_structure(

build_acceleration_structure_common(
geometries,
1,
primitive_count,
AccelerationStructureType::TopLevel,
allocator,
command_buffer_allocator,
Expand Down
107 changes: 74 additions & 33 deletions vulkano/src/pipeline/ray_tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,18 @@ use super::{
PipelineShaderStageCreateInfoFields1Vk, PipelineShaderStageCreateInfoFields2Vk,
};
use crate::{
buffer::{Buffer, BufferCreateInfo, BufferUsage, Subbuffer},
buffer::{AllocateBufferError, Buffer, BufferCreateInfo, BufferUsage, Subbuffer},
device::{Device, DeviceOwned, DeviceOwnedDebugWrapper},
instance::InstanceOwnedDebugWrapper,
macros::impl_id_counter,
memory::{
allocator::{align_up, AllocationCreateInfo, MemoryAllocator, MemoryTypeFilter},
allocator::{
align_up, AllocationCreateInfo, DeviceLayout, MemoryAllocator, MemoryTypeFilter,
},
DeviceAlignment,
},
shader::{spirv::ExecutionModel, DescriptorBindingRequirements},
StridedDeviceAddressRegion, Validated, ValidationError, VulkanError, VulkanObject,
DeviceSize, StridedDeviceAddressRegion, Validated, ValidationError, VulkanError, VulkanObject,
};
use foldhash::{HashMap, HashSet};
use smallvec::SmallVec;
Expand Down Expand Up @@ -393,6 +395,19 @@ impl RayTracingPipelineCreateInfo {
..Default::default()
}));
}

let has_raygen = stages.iter().any(|stage| {
stage.entry_point.info().execution_model == ExecutionModel::RayGenerationKHR
});
if !has_raygen {
return Err(Box::new(ValidationError {
context: "stages".into(),
problem: "does not contain a `RayGeneration` shader".into(),
vuids: &["VUID-VkRayTracingPipelineCreateInfoKHR-stage-03425"],
..Default::default()
}));
}

for stage in stages {
stage.validate(device).map_err(|err| {
err.add_context("stages")
Expand Down Expand Up @@ -819,40 +834,53 @@ impl ShaderBindingTable {
allocator: Arc<dyn MemoryAllocator>,
ray_tracing_pipeline: &RayTracingPipeline,
) -> Result<Self, Validated<VulkanError>> {
let mut miss_shader_count: u64 = 0;
let mut hit_shader_count: u64 = 0;
let mut callable_shader_count: u64 = 0;
// VUID-vkCmdTraceRaysKHR-size-04023
// There should be exactly one raygen shader group.
let mut raygen_shader_handle = None;
let mut miss_shader_handles = Vec::new();
let mut hit_shader_handles = Vec::new();
let mut callable_shader_handles = Vec::new();

let handle_data = ray_tracing_pipeline
.device()
.ray_tracing_shader_group_handles(
ray_tracing_pipeline,
0,
ray_tracing_pipeline.groups().len() as u32,
)?;
let mut handle_iter = handle_data.iter();

for group in ray_tracing_pipeline.groups() {
let handle = handle_iter.next().unwrap();
match group {
RayTracingShaderGroupCreateInfo::General { general_shader } => {
match ray_tracing_pipeline.stages()[*general_shader as usize]
.entry_point
.info()
.execution_model
{
ExecutionModel::RayGenerationKHR => {}
ExecutionModel::MissKHR => miss_shader_count += 1,
ExecutionModel::CallableKHR => callable_shader_count += 1,
ExecutionModel::RayGenerationKHR => {
raygen_shader_handle = Some(handle);
}
ExecutionModel::MissKHR => {
miss_shader_handles.push(handle);
}
ExecutionModel::CallableKHR => {
callable_shader_handles.push(handle);
}
_ => {
panic!("Unexpected shader type in general shader group");
}
}
}
RayTracingShaderGroupCreateInfo::ProceduralHit { .. }
| RayTracingShaderGroupCreateInfo::TrianglesHit { .. } => {
hit_shader_count += 1;
hit_shader_handles.push(handle);
}
}
}

let handle_data = ray_tracing_pipeline
.device()
.ray_tracing_shader_group_handles(
ray_tracing_pipeline,
0,
ray_tracing_pipeline.groups().len() as u32,
)?;
let raygen_shader_handle = raygen_shader_handle.expect("no raygen shader group found");

let properties = ray_tracing_pipeline.device().physical_device().properties();
let handle_size_aligned = align_up(
Expand All @@ -873,29 +901,29 @@ impl ShaderBindingTable {
let mut miss = StridedDeviceAddressRegion {
stride: handle_size_aligned,
size: align_up(
handle_size_aligned * miss_shader_count,
handle_size_aligned * miss_shader_handles.len() as u64,
shader_group_base_alignment,
),
device_address: 0,
};
let mut hit = StridedDeviceAddressRegion {
stride: handle_size_aligned,
size: align_up(
handle_size_aligned * hit_shader_count,
handle_size_aligned * hit_shader_handles.len() as u64,
shader_group_base_alignment,
),
device_address: 0,
};
let mut callable = StridedDeviceAddressRegion {
stride: handle_size_aligned,
size: align_up(
handle_size_aligned * callable_shader_count,
handle_size_aligned * callable_shader_handles.len() as u64,
shader_group_base_alignment,
),
device_address: 0,
};

let sbt_buffer = Buffer::new_slice::<u8>(
let sbt_buffer = new_bytes_buffer_with_alignment(
allocator,
BufferCreateInfo {
usage: BufferUsage::TRANSFER_SRC
Expand All @@ -909,6 +937,7 @@ impl ShaderBindingTable {
..Default::default()
},
raygen.size + miss.size + hit.size + callable.size,
shader_group_base_alignment,
)
.expect("todo: raytracing: better error type for buffer errors");

Expand All @@ -920,26 +949,21 @@ impl ShaderBindingTable {
{
let mut sbt_buffer_write = sbt_buffer.write().unwrap();

let mut handle_iter = handle_data.iter();

let handle_size = handle_data.handle_size() as usize;
sbt_buffer_write[..handle_size].copy_from_slice(handle_iter.next().unwrap());
sbt_buffer_write[..handle_size].copy_from_slice(raygen_shader_handle);
let mut offset = raygen.size as usize;
for _ in 0..miss_shader_count {
sbt_buffer_write[offset..offset + handle_size]
.copy_from_slice(handle_iter.next().unwrap());
for handle in miss_shader_handles {
sbt_buffer_write[offset..offset + handle_size].copy_from_slice(handle);
offset += miss.stride as usize;
}
offset = (raygen.size + miss.size) as usize;
for _ in 0..hit_shader_count {
sbt_buffer_write[offset..offset + handle_size]
.copy_from_slice(handle_iter.next().unwrap());
for handle in hit_shader_handles {
sbt_buffer_write[offset..offset + handle_size].copy_from_slice(handle);
offset += hit.stride as usize;
}
offset = (raygen.size + miss.size + hit.size) as usize;
for _ in 0..callable_shader_count {
sbt_buffer_write[offset..offset + handle_size]
.copy_from_slice(handle_iter.next().unwrap());
for handle in callable_shader_handles {
sbt_buffer_write[offset..offset + handle_size].copy_from_slice(handle);
offset += callable.stride as usize;
}
}
Expand All @@ -955,3 +979,20 @@ impl ShaderBindingTable {
})
}
}

fn new_bytes_buffer_with_alignment(
allocator: Arc<dyn MemoryAllocator>,
create_info: BufferCreateInfo,
allocation_info: AllocationCreateInfo,
size: DeviceSize,
alignment: DeviceAlignment,
) -> Result<Subbuffer<[u8]>, Validated<AllocateBufferError>> {
let layout = DeviceLayout::from_size_alignment(size, alignment.as_devicesize()).unwrap();
let buffer = Subbuffer::new(Buffer::new(
allocator,
create_info,
allocation_info,
layout,
)?);
Ok(buffer)
}

0 comments on commit 6ae10df

Please sign in to comment.