Skip to content

Commit

Permalink
[mtl] reusable compute passes
Browse files Browse the repository at this point in the history
  • Loading branch information
kvark committed Aug 2, 2018
1 parent d674137 commit c5561fa
Showing 1 changed file with 135 additions and 118 deletions.
253 changes: 135 additions & 118 deletions src/backend/metal/src/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ enum PassDoor<'a> {
/// A helper temporary object that consumes state-setting commands only
/// applicable to a render pass currently encoded.
enum PreRender<'a> {
Immediate(&'a metal::RenderCommandEncoder),
Immediate(&'a metal::RenderCommandEncoderRef),
Deferred(&'a mut Vec<soft::RenderCommand<soft::Own>>),
Void,
}
Expand All @@ -759,12 +759,29 @@ impl<'a> PreRender<'a> {
PreRender::Void => (),
}
}

fn issue_many<'b, I>(&mut self, commands: I)
where
I: Iterator<Item = soft::RenderCommand<&'b soft::Own>>
{
match *self {
PreRender::Immediate(encoder) => {
for com in commands {
exec_render(encoder, com);
}
}
PreRender::Deferred(ref mut list) => {
list.extend(commands.map(soft::RenderCommand::own))
}
PreRender::Void => {}
}
}
}

/// A helper temporary object that consumes state-setting commands only
/// applicable to a compute pass currently encoded.
enum PreCompute<'a> {
Immediate(&'a metal::ComputeCommandEncoder),
Immediate(&'a metal::ComputeCommandEncoderRef),
Deferred(&'a mut Vec<soft::ComputeCommand<soft::Own>>),
Void,
}
Expand All @@ -777,6 +794,23 @@ impl<'a> PreCompute<'a> {
PreCompute::Void => (),
}
}

fn issue_many<'b, I>(&mut self, commands: I)
where
I: Iterator<Item = soft::ComputeCommand<&'b soft::Own>>
{
match *self {
PreCompute::Immediate(encoder) => {
for com in commands {
exec_compute(encoder, com);
}
}
PreCompute::Deferred(ref mut list) => {
list.extend(commands.map(soft::ComputeCommand::own))
}
PreCompute::Void => {}
}
}
}

impl CommandSink {
Expand Down Expand Up @@ -805,17 +839,9 @@ impl CommandSink {
where
I: Iterator<Item = soft::RenderCommand<&'a soft::Own>>,
{
match self.pre_render() {
PreRender::Immediate(encoder) => {
for command in commands {
exec_render(encoder, command);
}
}
PreRender::Deferred(ref mut list) => {
list.extend(commands.into_iter().map(soft::RenderCommand::own))
}
PreRender::Void => panic!("Not in render encoding state!"),
}
let mut pre = self.pre_render();
debug_assert!(!pre.is_void());
pre.issue_many(commands);
}

/// Issue provided blit commands. This function doesn't expect an active blit pass,
Expand All @@ -824,22 +850,21 @@ impl CommandSink {
where
I: Iterator<Item = soft::BlitCommand>,
{
match *self {
enum PreBlit<'b> {
Immediate(&'b metal::BlitCommandEncoderRef),
Deferred(&'b mut Vec<soft::BlitCommand>),
}

let pre = match *self {
CommandSink::Immediate { encoder_state: EncoderState::Blit(ref encoder), .. } => {
for command in commands {
exec_blit(encoder, command);
}
PreBlit::Immediate(encoder)
}
CommandSink::Immediate { ref cmd_buffer, ref mut encoder_state, ref mut num_passes, .. } => {
*num_passes += 1;
encoder_state.end();
let encoder = cmd_buffer.new_blit_command_encoder().to_owned();

for command in commands {
exec_blit(&encoder, command);
}

*encoder_state = EncoderState::Blit(encoder);
let encoder = cmd_buffer.new_blit_command_encoder();
*encoder_state = EncoderState::Blit(encoder.to_owned());
PreBlit::Immediate(encoder)
}
CommandSink::Deferred { ref mut is_encoding, ref mut journal } => {
*is_encoding = true;
Expand All @@ -848,19 +873,33 @@ impl CommandSink {
journal.stop();
journal.passes.push((soft::Pass::Blit, journal.blit_commands.len() .. 0));
}
journal.blit_commands.extend(commands);
PreBlit::Deferred(&mut journal.blit_commands)
}
CommandSink::Remote { pass: Some(EncodePass::Blit(ref mut list)), .. } => {
list.extend(commands);
PreBlit::Deferred(list)
}
CommandSink::Remote { ref queue, ref cmd_buffer, ref mut pass, ref mut capacity, .. } => {
if let Some(pass) = pass.take() {
pass.update(capacity);
pass.schedule(queue, cmd_buffer);
}
let mut list = Vec::with_capacity(capacity.blit);
list.extend(commands);
*pass = Some(EncodePass::Blit(list));
match *pass {
Some(EncodePass::Blit(ref mut list)) => PreBlit::Deferred(list),
_ => unreachable!()
}
}
};

match pre {
PreBlit::Immediate(encoder) => {
for com in commands {
exec_blit(encoder, com);
}
}
PreBlit::Deferred(list) => {
list.extend(commands);
}
}
}
Expand All @@ -885,21 +924,46 @@ impl CommandSink {
}
}

/// Issue provided compute commands, expecting that we are encoding a compute pass.
fn compute_commands<'a, I>(&mut self, commands: I)
where
I: Iterator<Item = soft::ComputeCommand<&'a soft::Own>>,
{
match self.pre_compute() {
PreCompute::Immediate(ref encoder) => {
for command in commands {
exec_compute(encoder, command);
}
/// Switch the active encoder to compute.
/// Second returned value is `true` if the switch has just happened.
fn switch_compute(&mut self) -> (PreCompute, bool) {
match *self {
CommandSink::Immediate { encoder_state: EncoderState::Compute(ref encoder), .. } => {
(PreCompute::Immediate(encoder), false)
}
PreCompute::Deferred(ref mut list) => {
list.extend(commands.into_iter().map(soft::ComputeCommand::own));
CommandSink::Immediate { ref cmd_buffer, ref mut encoder_state, ref mut num_passes, .. } => {
*num_passes += 1;
encoder_state.end();
let encoder = cmd_buffer.new_compute_command_encoder();
*encoder_state = EncoderState::Compute(encoder.to_owned());
(PreCompute::Immediate(encoder), true)
}
CommandSink::Deferred { ref mut is_encoding, ref mut journal } => {
*is_encoding = true;
let switch = if let Some(&(soft::Pass::Compute, _)) = journal.passes.last() {
false
} else {
journal.stop();
journal.passes.push((soft::Pass::Compute, journal.compute_commands.len() .. 0));
true
};
(PreCompute::Deferred(&mut journal.compute_commands), switch)
}
CommandSink::Remote { pass: Some(EncodePass::Compute(ref mut list)), .. } => {
(PreCompute::Deferred(list), false)
}
CommandSink::Remote { ref queue, ref cmd_buffer, ref mut pass, ref mut capacity, .. } => {
if let Some(pass) = pass.take() {
pass.update(capacity);
pass.schedule(queue, cmd_buffer);
}
let mut list = Vec::with_capacity(capacity.compute);
*pass = Some(EncodePass::Compute(list));
match *pass {
Some(EncodePass::Compute(ref mut list)) => (PreCompute::Deferred(list), true),
_ => unreachable!()
}
}
PreCompute::Void => panic!("Not in compute encoding state!"),
}
}

Expand All @@ -921,6 +985,21 @@ impl CommandSink {
}
}

fn quick_compute<'a, I>(&mut self, label: &str, commands: I)
where
I: Iterator<Item = soft::ComputeCommand<&'a soft::Own>>
{
{
let (mut pre, switch) = self.switch_compute();
assert!(switch);
pre.issue_many(commands);
if let PreCompute::Immediate(encoder) = pre {
encoder.set_label(label);
}
}
self.stop_encoding();
}

fn begin_render_pass<'a, I>(
&mut self,
door: PassDoor,
Expand Down Expand Up @@ -970,55 +1049,6 @@ impl CommandSink {
}
}
}

fn begin_compute_pass<'a, I>(
&mut self,
door: PassDoor,
init_commands: I,
) where
I: Iterator<Item = soft::ComputeCommand<&'a soft::Own>>,
{
self.stop_encoding();

match *self {
CommandSink::Immediate { ref cmd_buffer, ref mut encoder_state, ref mut num_passes, .. } => {
*num_passes += 1;
autoreleasepool(|| {
let encoder = cmd_buffer.new_compute_command_encoder();
for command in init_commands {
exec_compute(encoder, command);
}
match door {
PassDoor::Open => {
*encoder_state = EncoderState::Compute(encoder.to_owned());
}
PassDoor::Closed { label } => {
encoder.set_label(label);
encoder.end_encoding();
}
}
})
}
CommandSink::Deferred { ref mut is_encoding, ref mut journal } => {
let mut range = journal.compute_commands.len() .. 0;
journal.compute_commands.extend(init_commands.map(soft::ComputeCommand::own));
match door {
PassDoor::Open => *is_encoding = true,
PassDoor::Closed {..} => range.end = journal.compute_commands.len(),
};
journal.passes.push((soft::Pass::Compute, range))
}
CommandSink::Remote { ref queue, ref cmd_buffer, ref mut pass, ref capacity, .. } => {
let mut list = Vec::with_capacity(capacity.compute);
list.extend(init_commands.map(soft::ComputeCommand::own));
let new_pass = EncodePass::Compute(list);
match door {
PassDoor::Open => *pass = Some(new_pass),
PassDoor::Closed { .. } => new_pass.schedule(queue, cmd_buffer),
}
}
}
}
}

#[derive(Clone, Copy, Debug)]
Expand Down Expand Up @@ -1880,10 +1910,7 @@ impl com::RawCommandBuffer<Backend> for CommandBuffer {
},
];

inner.sink().begin_compute_pass(
PassDoor::Closed { label: "fill_buffer" },
commands.iter().cloned(),
);
inner.sink().quick_compute("fill_buffer", commands.iter().cloned());
}

fn update_buffer(
Expand Down Expand Up @@ -3110,40 +3137,34 @@ impl com::RawCommandBuffer<Backend> for CommandBuffer {
}

fn dispatch(&mut self, count: WorkGroupCount) {
let init_commands = self.state.make_compute_commands();
let mut inner = self.inner.borrow_mut();
let (mut pre, init) = inner.sink().switch_compute();
if init {
pre.issue_many(self.state.make_compute_commands());
}

let command = soft::ComputeCommand::Dispatch {
pre.issue(soft::ComputeCommand::Dispatch {
wg_size: self.state.work_group_size,
wg_count: MTLSize {
width: count[0] as _,
height: count[1] as _,
depth: count[2] as _,
},
};

let mut inner = self.inner.borrow_mut();
let sink = inner.sink();
//TODO: re-use compute encoders
sink.begin_compute_pass(PassDoor::Open, init_commands);
sink.compute_commands(iter::once(command));
sink.stop_encoding();
});
}

fn dispatch_indirect(&mut self, buffer: &native::Buffer, offset: buffer::Offset) {
let init_commands = self.state.make_compute_commands();
let mut inner = self.inner.borrow_mut();
let (mut pre, init) = inner.sink().switch_compute();
if init {
pre.issue_many(self.state.make_compute_commands());
}

let command = soft::ComputeCommand::DispatchIndirect {
pre.issue(soft::ComputeCommand::DispatchIndirect {
wg_size: self.state.work_group_size,
buffer: BufferPtr(buffer.raw.as_ptr()),
offset,
};

let mut inner = self.inner.borrow_mut();
let sink = inner.sink();
//TODO: re-use compute encoders
sink.begin_compute_pass(PassDoor::Open, init_commands);
sink.compute_commands(iter::once(command));
sink.stop_encoding();
});
}

fn copy_buffer<T>(
Expand Down Expand Up @@ -3212,12 +3233,8 @@ impl com::RawCommandBuffer<Backend> for CommandBuffer {
if !blit_commands.is_empty() {
sink.blit_commands(blit_commands.into_iter());
}

if compute_commands.len() > 1 { // first is bind PSO
sink.begin_compute_pass(
PassDoor::Closed { label: "copy_buffer" },
compute_commands.into_iter(),
);
sink.quick_compute("copy_buffer", compute_commands.into_iter());
}
}

Expand Down

0 comments on commit c5561fa

Please sign in to comment.