Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ray Tracing Pipeline: SBT fixes & refactor examples #2617

Merged
merged 3 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions examples/ray-tracing-auto/scene.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,15 +179,18 @@ impl Scene {

let tlas = unsafe {
build_top_level_acceleration_structure(
blas.clone(),
vec![AccelerationStructureInstance {
acceleration_structure_reference: blas.device_address().into(),
..Default::default()
}],
memory_allocator.clone(),
command_buffer_allocator.clone(),
app.device.clone(),
app.queue.clone(),
)
};

let proj = Mat4::perspective_rh_gl(std::f32::consts::FRAC_PI_2, 4.0 / 3.0, 0.01, 100.0);
let proj = Mat4::perspective_rh(std::f32::consts::FRAC_PI_2, 4.0 / 3.0, 0.01, 100.0);
let view = Mat4::look_at_rh(
Vec3::new(0.0, 0.0, 1.0),
Vec3::new(0.0, 0.0, 0.0),
Expand Down Expand Up @@ -436,16 +439,13 @@ unsafe fn build_acceleration_structure_triangles(
}

unsafe fn build_top_level_acceleration_structure(
acceleration_structure: Arc<AccelerationStructure>,
as_instances: Vec<AccelerationStructureInstance>,
allocator: Arc<dyn MemoryAllocator>,
command_buffer_allocator: Arc<dyn CommandBufferAllocator>,
device: Arc<Device>,
queue: Arc<Queue>,
) -> Arc<AccelerationStructure> {
let as_instance = AccelerationStructureInstance {
acceleration_structure_reference: acceleration_structure.device_address().into(),
..Default::default()
};
let primitive_count = as_instances.len() as u32;

let instance_buffer = Buffer::from_iter(
allocator.clone(),
Expand All @@ -459,7 +459,7 @@ unsafe fn build_top_level_acceleration_structure(
| MemoryTypeFilter::HOST_SEQUENTIAL_WRITE,
..Default::default()
},
[as_instance],
as_instances,
)
.unwrap();

Expand All @@ -471,7 +471,7 @@ unsafe fn build_top_level_acceleration_structure(

build_acceleration_structure_common(
geometries,
1,
primitive_count,
AccelerationStructureType::TopLevel,
allocator,
command_buffer_allocator,
Expand Down
18 changes: 9 additions & 9 deletions examples/ray-tracing/scene.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,18 @@ impl SceneTask {

let tlas = unsafe {
build_top_level_acceleration_structure(
blas.clone(),
vec![AccelerationStructureInstance {
acceleration_structure_reference: blas.device_address().into(),
..Default::default()
}],
memory_allocator.clone(),
command_buffer_allocator.clone(),
app.device.clone(),
app.queue.clone(),
)
};

let proj = Mat4::perspective_rh_gl(std::f32::consts::FRAC_PI_2, 4.0 / 3.0, 0.01, 100.0);
let proj = Mat4::perspective_rh(std::f32::consts::FRAC_PI_2, 4.0 / 3.0, 0.01, 100.0);
let view = Mat4::look_at_rh(
Vec3::new(0.0, 0.0, 1.0),
Vec3::new(0.0, 0.0, 0.0),
Expand Down Expand Up @@ -454,16 +457,13 @@ unsafe fn build_acceleration_structure_triangles(
}

unsafe fn build_top_level_acceleration_structure(
acceleration_structure: Arc<AccelerationStructure>,
as_instances: Vec<AccelerationStructureInstance>,
allocator: Arc<dyn MemoryAllocator>,
command_buffer_allocator: Arc<dyn CommandBufferAllocator>,
device: Arc<Device>,
queue: Arc<Queue>,
) -> Arc<AccelerationStructure> {
let as_instance = AccelerationStructureInstance {
acceleration_structure_reference: acceleration_structure.device_address().into(),
..Default::default()
};
let primitive_count = as_instances.len() as u32;

let instance_buffer = Buffer::from_iter(
allocator.clone(),
Expand All @@ -477,7 +477,7 @@ unsafe fn build_top_level_acceleration_structure(
| MemoryTypeFilter::HOST_SEQUENTIAL_WRITE,
..Default::default()
},
[as_instance],
as_instances,
)
.unwrap();

Expand All @@ -489,7 +489,7 @@ unsafe fn build_top_level_acceleration_structure(

build_acceleration_structure_common(
geometries,
1,
primitive_count,
AccelerationStructureType::TopLevel,
allocator,
command_buffer_allocator,
Expand Down
107 changes: 74 additions & 33 deletions vulkano/src/pipeline/ray_tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,18 @@ use super::{
PipelineShaderStageCreateInfoFields1Vk, PipelineShaderStageCreateInfoFields2Vk,
};
use crate::{
buffer::{Buffer, BufferCreateInfo, BufferUsage, Subbuffer},
buffer::{AllocateBufferError, Buffer, BufferCreateInfo, BufferUsage, Subbuffer},
device::{Device, DeviceOwned, DeviceOwnedDebugWrapper},
instance::InstanceOwnedDebugWrapper,
macros::impl_id_counter,
memory::{
allocator::{align_up, AllocationCreateInfo, MemoryAllocator, MemoryTypeFilter},
allocator::{
align_up, AllocationCreateInfo, DeviceLayout, MemoryAllocator, MemoryTypeFilter,
},
DeviceAlignment,
},
shader::{spirv::ExecutionModel, DescriptorBindingRequirements},
StridedDeviceAddressRegion, Validated, ValidationError, VulkanError, VulkanObject,
DeviceSize, StridedDeviceAddressRegion, Validated, ValidationError, VulkanError, VulkanObject,
};
use foldhash::{HashMap, HashSet};
use smallvec::SmallVec;
Expand Down Expand Up @@ -393,6 +395,19 @@ impl RayTracingPipelineCreateInfo {
..Default::default()
}));
}

let has_raygen = stages.iter().any(|stage| {
stage.entry_point.info().execution_model == ExecutionModel::RayGenerationKHR
});
if !has_raygen {
return Err(Box::new(ValidationError {
context: "stages".into(),
problem: "does not contain a `RayGeneration` shader".into(),
vuids: &["VUID-VkRayTracingPipelineCreateInfoKHR-stage-03425"],
..Default::default()
}));
}

for stage in stages {
stage.validate(device).map_err(|err| {
err.add_context("stages")
Expand Down Expand Up @@ -819,40 +834,53 @@ impl ShaderBindingTable {
allocator: Arc<dyn MemoryAllocator>,
ray_tracing_pipeline: &RayTracingPipeline,
) -> Result<Self, Validated<VulkanError>> {
let mut miss_shader_count: u64 = 0;
let mut hit_shader_count: u64 = 0;
let mut callable_shader_count: u64 = 0;
// VUID-vkCmdTraceRaysKHR-size-04023
// There should be exactly one raygen shader group.
let mut raygen_shader_handle = None;
let mut miss_shader_handles = Vec::new();
let mut hit_shader_handles = Vec::new();
let mut callable_shader_handles = Vec::new();

let handle_data = ray_tracing_pipeline
.device()
.ray_tracing_shader_group_handles(
ray_tracing_pipeline,
0,
ray_tracing_pipeline.groups().len() as u32,
)?;
let mut handle_iter = handle_data.iter();

for group in ray_tracing_pipeline.groups() {
let handle = handle_iter.next().unwrap();
match group {
RayTracingShaderGroupCreateInfo::General { general_shader } => {
match ray_tracing_pipeline.stages()[*general_shader as usize]
.entry_point
.info()
.execution_model
{
ExecutionModel::RayGenerationKHR => {}
ExecutionModel::MissKHR => miss_shader_count += 1,
ExecutionModel::CallableKHR => callable_shader_count += 1,
ExecutionModel::RayGenerationKHR => {
raygen_shader_handle = Some(handle);
}
ExecutionModel::MissKHR => {
miss_shader_handles.push(handle);
}
ExecutionModel::CallableKHR => {
callable_shader_handles.push(handle);
}
_ => {
panic!("Unexpected shader type in general shader group");
}
}
}
RayTracingShaderGroupCreateInfo::ProceduralHit { .. }
| RayTracingShaderGroupCreateInfo::TrianglesHit { .. } => {
hit_shader_count += 1;
hit_shader_handles.push(handle);
}
}
}

let handle_data = ray_tracing_pipeline
.device()
.ray_tracing_shader_group_handles(
ray_tracing_pipeline,
0,
ray_tracing_pipeline.groups().len() as u32,
)?;
let raygen_shader_handle = raygen_shader_handle.expect("no raygen shader group found");

let properties = ray_tracing_pipeline.device().physical_device().properties();
let handle_size_aligned = align_up(
Expand All @@ -873,29 +901,29 @@ impl ShaderBindingTable {
let mut miss = StridedDeviceAddressRegion {
stride: handle_size_aligned,
size: align_up(
handle_size_aligned * miss_shader_count,
handle_size_aligned * miss_shader_handles.len() as u64,
shader_group_base_alignment,
),
device_address: 0,
};
let mut hit = StridedDeviceAddressRegion {
stride: handle_size_aligned,
size: align_up(
handle_size_aligned * hit_shader_count,
handle_size_aligned * hit_shader_handles.len() as u64,
shader_group_base_alignment,
),
device_address: 0,
};
let mut callable = StridedDeviceAddressRegion {
stride: handle_size_aligned,
size: align_up(
handle_size_aligned * callable_shader_count,
handle_size_aligned * callable_shader_handles.len() as u64,
shader_group_base_alignment,
),
device_address: 0,
};

let sbt_buffer = Buffer::new_slice::<u8>(
let sbt_buffer = new_bytes_buffer_with_alignment(
allocator,
BufferCreateInfo {
usage: BufferUsage::TRANSFER_SRC
Expand All @@ -909,6 +937,7 @@ impl ShaderBindingTable {
..Default::default()
},
raygen.size + miss.size + hit.size + callable.size,
shader_group_base_alignment,
)
.expect("todo: raytracing: better error type for buffer errors");

Expand All @@ -920,26 +949,21 @@ impl ShaderBindingTable {
{
let mut sbt_buffer_write = sbt_buffer.write().unwrap();

let mut handle_iter = handle_data.iter();

let handle_size = handle_data.handle_size() as usize;
sbt_buffer_write[..handle_size].copy_from_slice(handle_iter.next().unwrap());
sbt_buffer_write[..handle_size].copy_from_slice(raygen_shader_handle);
let mut offset = raygen.size as usize;
for _ in 0..miss_shader_count {
sbt_buffer_write[offset..offset + handle_size]
.copy_from_slice(handle_iter.next().unwrap());
for handle in miss_shader_handles {
sbt_buffer_write[offset..offset + handle_size].copy_from_slice(handle);
offset += miss.stride as usize;
}
offset = (raygen.size + miss.size) as usize;
for _ in 0..hit_shader_count {
sbt_buffer_write[offset..offset + handle_size]
.copy_from_slice(handle_iter.next().unwrap());
for handle in hit_shader_handles {
sbt_buffer_write[offset..offset + handle_size].copy_from_slice(handle);
offset += hit.stride as usize;
}
offset = (raygen.size + miss.size + hit.size) as usize;
for _ in 0..callable_shader_count {
sbt_buffer_write[offset..offset + handle_size]
.copy_from_slice(handle_iter.next().unwrap());
for handle in callable_shader_handles {
sbt_buffer_write[offset..offset + handle_size].copy_from_slice(handle);
offset += callable.stride as usize;
}
}
Expand All @@ -955,3 +979,20 @@ impl ShaderBindingTable {
})
}
}

fn new_bytes_buffer_with_alignment(
allocator: Arc<dyn MemoryAllocator>,
create_info: BufferCreateInfo,
allocation_info: AllocationCreateInfo,
size: DeviceSize,
alignment: DeviceAlignment,
) -> Result<Subbuffer<[u8]>, Validated<AllocateBufferError>> {
let layout = DeviceLayout::from_size_alignment(size, alignment.as_devicesize()).unwrap();
let buffer = Subbuffer::new(Buffer::new(
allocator,
create_info,
allocation_info,
layout,
)?);
Ok(buffer)
}