Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HAL methods to update descriptor set #179

Merged
merged 2 commits into from
Jul 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions piet-gpu-hal/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,36 @@ pub trait Device: Sized {
builder.build(self, pipeline)
}

/// Update a descriptor in a descriptor set.
///
/// The index is the same as the binding number in Vulkan.
///
/// # Safety
///
/// The descriptor set must not be used in any in-flight command buffer. The index must be valid.
/// The resource type must match that at descriptor set creation time.
unsafe fn update_buffer_descriptor(
&self,
ds: &mut Self::DescriptorSet,
index: u32,
buf: &Self::Buffer,
);

/// Update a descriptor in a descriptor set.
///
/// The index is the same as the binding number in Vulkan.
///
/// # Safety
///
/// The descriptor set must not be used in any in-flight command buffer. The index must be valid.
/// The resource type must match that at descriptor set creation time.
unsafe fn update_image_descriptor(
&self,
ds: &mut Self::DescriptorSet,
index: u32,
image: &Self::Image,
);

fn create_cmd_buf(&self) -> Result<Self::CmdBuf, Error>;

/// If the command buffer was submitted, it must complete before this is called.
Expand Down
39 changes: 33 additions & 6 deletions piet-gpu-hal/src/dx12.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ use raw_window_handle::{HasRawWindowHandle, RawWindowHandle};

use smallvec::SmallVec;

use crate::{BindType, BufferUsage, Error, GpuInfo, ImageLayout, MapMode, WorkgroupLimits, ImageFormat, ComputePassDescriptor};
use crate::{
BindType, BufferUsage, ComputePassDescriptor, Error, GpuInfo, ImageFormat, ImageLayout,
MapMode, WorkgroupLimits,
};

use self::{
descriptor::{CpuHeapRefOwned, DescriptorPool, GpuHeapRefOwned},
Expand Down Expand Up @@ -322,7 +325,12 @@ impl crate::backend::Device for Dx12Device {
Ok(())
}

unsafe fn create_image2d(&self, width: u32, height: u32, format: ImageFormat) -> Result<Self::Image, Error> {
unsafe fn create_image2d(
&self,
width: u32,
height: u32,
format: ImageFormat,
) -> Result<Self::Image, Error> {
let format = match format {
ImageFormat::A8 => winapi::shared::dxgiformat::DXGI_FORMAT_R8_UNORM,
ImageFormat::Rgba8 => winapi::shared::dxgiformat::DXGI_FORMAT_R8G8B8A8_UNORM,
Expand Down Expand Up @@ -391,10 +399,7 @@ impl crate::backend::Device for Dx12Device {
std::ptr::copy_nonoverlapping(mapped, buf.as_mut_ptr() as *mut u8, size);
self.unmap_buffer(&pool.buf, 0, size as u64, MapMode::Read)?;
let tsp = (self.ts_freq as f64).recip();
let result = buf
.iter()
.map(|ts| *ts as f64 * tsp)
.collect();
let result = buf.iter().map(|ts| *ts as f64 * tsp).collect();
Ok(result)
}

Expand Down Expand Up @@ -569,6 +574,28 @@ impl crate::backend::Device for Dx12Device {
DescriptorSetBuilder::default()
}

unsafe fn update_buffer_descriptor(
&self,
ds: &mut Self::DescriptorSet,
index: u32,
buf: &Self::Buffer,
) {
let src_cpu_ref = buf.cpu_ref.as_ref().unwrap().handle();
ds.gpu_ref
.copy_one_descriptor(&self.device, src_cpu_ref, index);
}

unsafe fn update_image_descriptor(
&self,
ds: &mut Self::DescriptorSet,
index: u32,
image: &Self::Image,
) {
let src_cpu_ref = image.cpu_ref.as_ref().unwrap().handle();
ds.gpu_ref
.copy_one_descriptor(&self.device, src_cpu_ref, index);
}

unsafe fn create_sampler(&self, _params: crate::SamplerParams) -> Result<Self::Sampler, Error> {
todo!()
}
Expand Down
16 changes: 15 additions & 1 deletion piet-gpu-hal/src/dx12/descriptor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ pub struct GpuHeapRefOwned {
heap_ref: GpuHeapRef,
cpu_handle: D3D12_CPU_DESCRIPTOR_HANDLE,
gpu_handle: D3D12_GPU_DESCRIPTOR_HANDLE,
increment_size: u32,
free_list: Weak<Mutex<DescriptorFreeList>>,
}

Expand Down Expand Up @@ -137,10 +138,13 @@ impl DescriptorPool {

pub fn alloc_gpu(&mut self, device: &Device, n: u32) -> Result<GpuHeapRefOwned, Error> {
let free_list = &self.free_list;
let heap_type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV;
let increment_size = unsafe { device.get_descriptor_increment_size(heap_type) };
let mk_owned = |heap_ref, cpu_handle, gpu_handle| GpuHeapRefOwned {
heap_ref,
cpu_handle,
gpu_handle,
increment_size,
free_list: Arc::downgrade(free_list),
};
let mut free_list = free_list.lock().unwrap();
Expand All @@ -158,7 +162,6 @@ impl DescriptorPool {
}
unsafe {
let size = n.max(GPU_CHUNK_SIZE).next_power_of_two();
let heap_type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV;
let desc = D3D12_DESCRIPTOR_HEAP_DESC {
Type: heap_type,
NumDescriptors: size,
Expand Down Expand Up @@ -246,6 +249,17 @@ impl GpuHeapRefOwned {
D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV,
);
}

pub unsafe fn copy_one_descriptor(
&self,
device: &Device,
src: D3D12_CPU_DESCRIPTOR_HANDLE,
index: u32,
) {
let mut dst = self.cpu_handle;
dst.ptr += (index * self.increment_size) as usize;
device.copy_one_descriptor(dst, src, D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
}
}

impl Deref for CpuHeapRefOwned {
Expand Down
10 changes: 10 additions & 0 deletions piet-gpu-hal/src/dx12/wrappers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,16 @@ impl Device {
);
}

pub unsafe fn copy_one_descriptor(
&self,
dst: d3d12::D3D12_CPU_DESCRIPTOR_HANDLE,
src: d3d12::D3D12_CPU_DESCRIPTOR_HANDLE,
descriptor_heap_type: d3d12::D3D12_DESCRIPTOR_HEAP_TYPE,
) {
self.0
.CopyDescriptorsSimple(1, dst, src, descriptor_heap_type);
}

pub unsafe fn create_compute_pipeline_state(
&self,
compute_pipeline_desc: &d3d12::D3D12_COMPUTE_PIPELINE_STATE_DESC,
Expand Down
24 changes: 24 additions & 0 deletions piet-gpu-hal/src/hub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,30 @@ impl Session {
DescriptorSetBuilder(self.0.device.descriptor_set_builder())
}

/// Update a buffer in a descriptor set.
pub unsafe fn update_buffer_descriptor(
&self,
ds: &mut DescriptorSet,
index: u32,
buffer: &Buffer,
) {
self.0
.device
.update_buffer_descriptor(ds, index, &buffer.0.buffer)
}

/// Update an image in a descriptor set.
pub unsafe fn update_image_descriptor(
&self,
ds: &mut DescriptorSet,
index: u32,
image: &Image,
) {
self.0
.device
.update_image_descriptor(ds, index, &image.0.image)
}

/// Create a query pool for timestamp queries.
pub fn create_query_pool(&self, n_queries: u32) -> Result<QueryPool, Error> {
self.0.device.create_query_pool(n_queries)
Expand Down
18 changes: 18 additions & 0 deletions piet-gpu-hal/src/metal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,24 @@ impl crate::backend::Device for MtlDevice {
DescriptorSetBuilder::default()
}

unsafe fn update_buffer_descriptor(
&self,
ds: &mut Self::DescriptorSet,
index: u32,
buf: &Self::Buffer,
) {
ds.buffers[index as usize] = buf.clone();
}

unsafe fn update_image_descriptor(
&self,
ds: &mut Self::DescriptorSet,
index: u32,
image: &Self::Image,
) {
ds.images[index as usize - ds.buffers.len()] = image.clone();
}

fn create_cmd_buf(&self) -> Result<Self::CmdBuf, Error> {
let cmd_queue = self.cmd_queue.lock().unwrap();
// A discussion about autorelease pools.
Expand Down
26 changes: 26 additions & 0 deletions piet-gpu-hal/src/mux.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,32 @@ impl Device {
}
}

pub unsafe fn update_buffer_descriptor(
&self,
ds: &mut DescriptorSet,
index: u32,
buffer: &Buffer,
) {
mux_match! { self;
Device::Vk(d) => d.update_buffer_descriptor(ds.vk_mut(), index, buffer.vk()),
Device::Dx12(d) => d.update_buffer_descriptor(ds.dx12_mut(), index, buffer.dx12()),
Device::Mtl(d) => d.update_buffer_descriptor(ds.mtl_mut(), index, buffer.mtl()),
}
}

pub unsafe fn update_image_descriptor(
&self,
ds: &mut DescriptorSet,
index: u32,
image: &Image,
) {
mux_match! { self;
Device::Vk(d) => d.update_image_descriptor(ds.vk_mut(), index, image.vk()),
Device::Dx12(d) => d.update_image_descriptor(ds.dx12_mut(), index, image.dx12()),
Device::Mtl(d) => d.update_image_descriptor(ds.mtl_mut(), index, image.mtl()),
}
}

pub fn create_cmd_buf(&self) -> Result<CmdBuf, Error> {
mux_match! { self;
Device::Vk(d) => d.create_cmd_buf().map(CmdBuf::Vk),
Expand Down
70 changes: 60 additions & 10 deletions piet-gpu-hal/src/vulkan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ use std::os::raw::c_char;
use std::sync::Arc;

use ash::extensions::{ext::DebugUtils, khr};
use ash::{vk, Device, Entry, Instance};
use ash::vk::DebugUtilsLabelEXT;
use ash::{vk, Device, Entry, Instance};

use smallvec::SmallVec;

use crate::backend::Device as DeviceTrait;
use crate::{
BindType, BufferUsage, Error, GpuInfo, ImageFormat, ImageLayout, MapMode, SamplerParams, SubgroupSize,
WorkgroupLimits, ComputePassDescriptor,
BindType, BufferUsage, ComputePassDescriptor, Error, GpuInfo, ImageFormat, ImageLayout,
MapMode, SamplerParams, SubgroupSize, WorkgroupLimits,
};

pub struct VkInstance {
Expand Down Expand Up @@ -320,7 +320,10 @@ impl VkInstance {
let queue_index = 0;
let queue = device.get_device_queue(qfi, queue_index);

let device = Arc::new(RawDevice { device, dbg_loader: self.dbg_loader.clone() });
let device = Arc::new(RawDevice {
device,
dbg_loader: self.dbg_loader.clone(),
});

let props = self.instance.get_physical_device_properties(pdevice);
let timestamp_period = props.limits.timestamp_period;
Expand Down Expand Up @@ -536,7 +539,12 @@ impl crate::backend::Device for VkDevice {
Ok(())
}

unsafe fn create_image2d(&self, width: u32, height: u32, format: ImageFormat) -> Result<Self::Image, Error> {
unsafe fn create_image2d(
&self,
width: u32,
height: u32,
format: ImageFormat,
) -> Result<Self::Image, Error> {
let device = &self.device.device;
let extent = vk::Extent3D {
width,
Expand Down Expand Up @@ -720,6 +728,49 @@ impl crate::backend::Device for VkDevice {
}
}

unsafe fn update_buffer_descriptor(
&self,
ds: &mut Self::DescriptorSet,
index: u32,
buf: &Self::Buffer,
) {
let device = &self.device.device;
device.update_descriptor_sets(
&[vk::WriteDescriptorSet::builder()
.dst_set(ds.descriptor_set)
.dst_binding(index)
.descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
.buffer_info(&[vk::DescriptorBufferInfo::builder()
.buffer(buf.buffer)
.offset(0)
.range(vk::WHOLE_SIZE)
.build()])
.build()],
&[],
);
}

unsafe fn update_image_descriptor(
&self,
ds: &mut Self::DescriptorSet,
index: u32,
image: &Self::Image,
) {
let device = &self.device.device;
device.update_descriptor_sets(
&[vk::WriteDescriptorSet::builder()
.dst_set(ds.descriptor_set)
.dst_binding(index)
.descriptor_type(vk::DescriptorType::STORAGE_IMAGE)
.image_info(&[vk::DescriptorImageInfo::builder()
.image_view(image.image_view)
.image_layout(vk::ImageLayout::GENERAL)
.build()])
.build()],
&[],
);
}

fn create_cmd_buf(&self) -> Result<CmdBuf, Error> {
unsafe {
let device = &self.device.device;
Expand Down Expand Up @@ -773,10 +824,7 @@ impl crate::backend::Device for VkDevice {
let flags = vk::QueryResultFlags::TYPE_64 | vk::QueryResultFlags::WAIT;
device.get_query_pool_results(pool.pool, 0, pool.n_queries, &mut buf, flags)?;
let tsp = self.timestamp_period as f64 * 1e-9;
let result = buf
.iter()
.map(|ts| *ts as f64 * tsp)
.collect();
let result = buf.iter().map(|ts| *ts as f64 * tsp).collect();
Ok(result)
}

Expand Down Expand Up @@ -1129,7 +1177,9 @@ impl crate::backend::CmdBuf<VkDevice> for CmdBuf {
unsafe fn begin_debug_label(&mut self, label: &str) {
if let Some(utils) = &self.device.dbg_loader {
let label_cstr = CString::new(label).unwrap();
let label_ext = DebugUtilsLabelEXT::builder().label_name(&label_cstr).build();
let label_ext = DebugUtilsLabelEXT::builder()
.label_name(&label_cstr)
.build();
utils.cmd_begin_debug_utils_label(self.cmd_buf, &label_ext);
}
}
Expand Down