Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Portable RT: Fixes for shader-binding-table assembly (sbt data) #1993

Merged
merged 9 commits into from
Jan 30, 2025
74 changes: 57 additions & 17 deletions framework/decode/vulkan_address_replacer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ VulkanAddressReplacer::VulkanAddressReplacer(const VulkanDeviceInfo*
const encode::VulkanDeviceTable* device_table,
const encode::VulkanInstanceTable* instance_table,
const decode::CommonObjectInfoTable& object_table) :
device_table_(device_table)
device_table_(device_table), object_table_(&object_table)
{
GFXRECON_ASSERT(device_info != nullptr && device_table != nullptr && instance_table != nullptr);
physical_device_info_ = object_table.GetVkPhysicalDeviceInfo(device_info->parent_id);
Expand Down Expand Up @@ -354,8 +354,8 @@ void VulkanAddressReplacer::ProcessCmdTraceRays(
}
auto input_addresses = reinterpret_cast<VkDeviceAddress*>(pipeline_context_sbt.input_handle_buffer.mapped_data);

std::unordered_map<const VkStridedDeviceAddressRegionKHR*, uint32_t> num_handles_map;
uint32_t num_addresses = 0;
std::unordered_map<const VkStridedDeviceAddressRegionKHR*, uint32_t> num_addresses_map;
uint32_t num_addresses = 0;

{
const auto handle_size_capture = static_cast<uint32_t>(util::aligned_value(
Expand All @@ -365,10 +365,18 @@ void VulkanAddressReplacer::ProcessCmdTraceRays(
{
if (region != nullptr && region->size != 0 && region->stride != 0)
{
num_handles_map[region] = region->size / region->stride;
num_addresses_map[region] = region->size / capture_ray_properties_->shaderGroupHandleSize;

for (uint32_t offset = 0; offset < region->size; offset += region->stride)
uint32_t capture_handle_size = capture_ray_properties_->shaderGroupHandleSize;
if (region->stride > capture_handle_size)
{
uint32_t payload_size = region->stride - capture_handle_size;
GFXRECON_LOG_DEBUG("Extra data in sbt: %d", payload_size);
}

for (uint32_t offset = 0; offset < region->size; offset += capture_handle_size)
{
// input-address are handles and extra data
input_addresses[num_addresses++] = region->deviceAddress + offset;
}
}
Expand Down Expand Up @@ -430,15 +438,17 @@ void VulkanAddressReplacer::ProcessCmdTraceRays(
uint32_t num_handles_limit = region->size / region->stride;
auto group_size = static_cast<uint32_t>(util::aligned_value(
num_handles_limit * handle_size_aligned, replay_ray_properties_->shaderGroupBaseAlignment));
sbt_offset += group_size;

// adjust group-size
region->size = group_size;
region->stride = handle_size_aligned;
// increase group-size/stride, if required
region->size = std::max<VkDeviceSize>(group_size, region->size);
region->stride = std::max<VkDeviceSize>(handle_size_aligned, region->stride);

sbt_offset += region->size;
}
}
// raygen: stride == size
raygen_sbt->size = raygen_sbt->stride = replay_ray_properties_->shaderGroupBaseAlignment;
raygen_sbt->size = raygen_sbt->stride =
util::aligned_value(raygen_sbt->stride, replay_ray_properties_->shaderGroupBaseAlignment);

if (!create_buffer(shadow_buf_context,
sbt_offset,
Expand All @@ -457,7 +467,7 @@ void VulkanAddressReplacer::ProcessCmdTraceRays(
{
if (region != nullptr && region->size != 0 && region->stride != 0)
{
uint32_t num_handles = num_handles_map[region];
uint32_t num_handles = num_addresses_map[region];

// assign shadow-sbt-address
region->deviceAddress = shadow_buf_context.device_address + sbt_offset;
Expand Down Expand Up @@ -497,7 +507,20 @@ void VulkanAddressReplacer::ProcessCmdTraceRays(
VK_ACCESS_SHADER_READ_BIT);
}

// set previous push-constant data
// set previous compute-pipeline, if any
if (command_buffer_info->bound_pipelines.count(VK_PIPELINE_BIND_POINT_COMPUTE))
{
auto* previous_pipeline = object_table_->GetVkPipelineInfo(
command_buffer_info->bound_pipelines.at(VK_PIPELINE_BIND_POINT_COMPUTE));
GFXRECON_ASSERT(previous_pipeline);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

previous_pipeline MUST be valid?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. at this point we already checked the bound_pipelines map and know a compute-pipeline was previously bound (rare'ish case anyway, since the command-buffer is used to record a raytracing pipeline). so in fact, we only assert that it can also be mapped.

Copy link
Contributor Author

@fabian-lunarg fabian-lunarg Jan 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

double-checked the assumption, I think it's correct. the only caveat I could think of if a compute-pipeline would have been bound and then unbound again. but that's not allowed ("pipeline must be a valid VkPipeline handle")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@antonio-lunarg is your concern resolved?

if (previous_pipeline != nullptr)
{
device_table_->CmdBindPipeline(
command_buffer_info->handle, VK_PIPELINE_BIND_POINT_COMPUTE, previous_pipeline->handle);
}
}

// set previous push-constant data, if any
if (!command_buffer_info->push_constant_data.empty())
{
device_table_->CmdPushConstants(command_buffer_info->handle,
Expand Down Expand Up @@ -578,13 +601,17 @@ void VulkanAddressReplacer::ProcessCmdBuildAccelerationStructuresKHR(
&build_size_info);
}

bool as_buffer_usable = false;

// retrieve VkAccelerationStructureKHR -> VkBuffer -> check/correct size
auto* acceleration_structure_info =
address_tracker.GetAccelerationStructureByHandle(build_geometry_info.dstAccelerationStructure);
GFXRECON_ASSERT(acceleration_structure_info != nullptr);
auto* buffer_info = address_tracker.GetBufferByHandle(acceleration_structure_info->buffer);
bool as_buffer_usable =
buffer_info != nullptr && buffer_info->size >= build_size_info.accelerationStructureSize;
if (acceleration_structure_info != nullptr)
{
auto* buffer_info = address_tracker.GetBufferByHandle(acceleration_structure_info->buffer);
as_buffer_usable =
buffer_info != nullptr && buffer_info->size >= build_size_info.accelerationStructureSize;
}

// determine required size of scratch-buffer
uint32_t scratch_size = build_geometry_info.mode == VK_BUILD_ACCELERATION_STRUCTURE_MODE_BUILD_KHR
Expand Down Expand Up @@ -811,7 +838,20 @@ void VulkanAddressReplacer::ProcessCmdBuildAccelerationStructuresKHR(
VK_ACCESS_SHADER_READ_BIT);
}

// set previous push-constant data
// set previous compute-pipeline, if any
if (command_buffer_info->bound_pipelines.count(VK_PIPELINE_BIND_POINT_COMPUTE))
{
auto* previous_pipeline = object_table_->GetVkPipelineInfo(
command_buffer_info->bound_pipelines.at(VK_PIPELINE_BIND_POINT_COMPUTE));
GFXRECON_ASSERT(previous_pipeline);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@antonio-lunarg is your concern resolved?

if (previous_pipeline != nullptr)
{
device_table_->CmdBindPipeline(
command_buffer_info->handle, VK_PIPELINE_BIND_POINT_COMPUTE, previous_pipeline->handle);
}
}

// set previous push-constant data, if any
if (!command_buffer_info->push_constant_data.empty())
{
device_table_->CmdPushConstants(command_buffer_info->handle,
Expand Down
1 change: 1 addition & 0 deletions framework/decode/vulkan_address_replacer.h
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ class VulkanAddressReplacer
bool swap_acceleration_structure_handle(VkAccelerationStructureKHR& handle);

const encode::VulkanDeviceTable* device_table_ = nullptr;
const decode::CommonObjectInfoTable* object_table_ = nullptr;
VkPhysicalDeviceMemoryProperties memory_properties_ = {};
std::optional<VkPhysicalDeviceRayTracingPipelinePropertiesKHR> capture_ray_properties_{}, replay_ray_properties_{};
std::optional<VkPhysicalDeviceAccelerationStructurePropertiesKHR> replay_acceleration_structure_properties_{};
Expand Down
Loading
Loading