Skip to content
This repository has been archived by the owner on Jan 29, 2025. It is now read-only.

[spv-in] Convert conditional backedges to break if. #2290

Merged
merged 1 commit into from
May 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/front/spv/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -597,15 +597,19 @@ impl<'function> BlockContext<'function> {
crate::Span::default(),
)
}
super::BodyFragment::Loop { body, continuing } => {
super::BodyFragment::Loop {
body,
continuing,
break_if,
} => {
let body = lower_impl(blocks, bodies, body);
let continuing = lower_impl(blocks, bodies, continuing);

block.push(
crate::Statement::Loop {
body,
continuing,
break_if: None,
break_if,
},
crate::Span::default(),
)
Expand Down
121 changes: 104 additions & 17 deletions src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,11 @@ enum BodyFragment {
Loop {
body: BodyIndex,
continuing: BodyIndex,

/// If the SPIR-V loop's back-edge branch is conditional, this is the
/// expression that must be `false` for the back-edge to be taken, with
/// `true` being for the "loop merge" (which breaks out of the loop).
break_if: Option<Handle<crate::Expression>>,
eddyb marked this conversation as resolved.
Show resolved Hide resolved
},
Switch {
selector: Handle<crate::Expression>,
Expand Down Expand Up @@ -429,7 +434,7 @@ struct PhiExpression {
expressions: Vec<(spirv::Word, spirv::Word)>,
}

#[derive(Debug)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum MergeBlockInformation {
LoopMerge,
LoopContinue,
Expand Down Expand Up @@ -3114,35 +3119,121 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
get_expr_handle!(condition_id, lexp)
};

// HACK(eddyb) Naga doesn't seem to have this helper,
// so it's declared on the fly here for convenience.
#[derive(Copy, Clone)]
struct BranchTarget {
label_id: spirv::Word,
merge_info: Option<MergeBlockInformation>,
}
let branch_target = |label_id| BranchTarget {
label_id,
merge_info: ctx.mergers.get(&label_id).copied(),
};

let true_target = branch_target(self.next()?);
let false_target = branch_target(self.next()?);

// Consume branch weights
for _ in 4..inst.wc {
let _ = self.next()?;
}

// Handle `OpBranchConditional`s used at the end of a loop
// body's "continuing" section as a "conditional backedge",
// i.e. a `do`-`while` condition, or `break if` in WGSL.

// HACK(eddyb) this has to go to the parent *twice*, because
// `OpLoopMerge` left the "continuing" section nested in the
// loop body in terms of `parent`, but not `BodyFragment`.
let parent_body_idx = ctx.bodies[body_idx].parent;
let parent_parent_body_idx = ctx.bodies[parent_body_idx].parent;
match ctx.bodies[parent_parent_body_idx].data[..] {
// The `OpLoopMerge`'s `continuing` block and the loop's
// backedge block may not be the same, but they'll both
// belong to the same body.
[.., BodyFragment::Loop {
body: loop_body_idx,
continuing: loop_continuing_idx,
break_if: ref mut break_if_slot @ None,
}] if body_idx == loop_continuing_idx => {
eddyb marked this conversation as resolved.
Show resolved Hide resolved
// Try both orderings of break-vs-backedge, because
// SPIR-V is symmetrical here, unlike WGSL `break if`.
let break_if_cond = [true, false].into_iter().find_map(|true_breaks| {
let (break_candidate, backedge_candidate) = if true_breaks {
(true_target, false_target)
} else {
(false_target, true_target)
};

if break_candidate.merge_info
!= Some(MergeBlockInformation::LoopMerge)
{
return None;
}

// HACK(eddyb) since Naga doesn't explicitly track
// backedges, this is checking for the outcome of
// `OpLoopMerge` below (even if it looks weird).
let backedge_candidate_is_backedge =
backedge_candidate.merge_info.is_none()
&& ctx.body_for_label.get(&backedge_candidate.label_id)
== Some(&loop_body_idx);
if !backedge_candidate_is_backedge {
return None;
}

Some(if true_breaks {
condition
} else {
ctx.expressions.append(
crate::Expression::Unary {
op: crate::UnaryOperator::Not,
expr: condition,
},
span,
)
})
});

if let Some(break_if_cond) = break_if_cond {
*break_if_slot = Some(break_if_cond);

// This `OpBranchConditional` ends the "continuing"
// section of the loop body as normal, with the
// `break if` condition having been stashed above.
break None;
}
}
_ => {}
}

block.extend(emitter.finish(ctx.expressions));
ctx.blocks.insert(block_id, block);
let body = &mut ctx.bodies[body_idx];
body.data.push(BodyFragment::BlockId(block_id));

let true_id = self.next()?;
let false_id = self.next()?;

let same_target = true_id == false_id;
let same_target = true_target.label_id == false_target.label_id;

// Start a body block for the `accept` branch.
let accept = ctx.bodies.len();
let mut accept_block = Body::with_parent(body_idx);

// If the `OpBranchConditional`target is somebody else's
// If the `OpBranchConditional` target is somebody else's
// merge or continue block, then put a `Break` or `Continue`
// statement in this new body block.
if let Some(info) = ctx.mergers.get(&true_id) {
if let Some(info) = true_target.merge_info {
merger(
match same_target {
true => &mut ctx.bodies[body_idx],
false => &mut accept_block,
},
info,
&info,
)
} else {
// Note the body index for the block we're branching to.
let prev = ctx.body_for_label.insert(
true_id,
true_target.label_id,
match same_target {
true => body_idx,
false => accept,
Expand All @@ -3161,10 +3252,10 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let reject = ctx.bodies.len();
let mut reject_block = Body::with_parent(body_idx);

if let Some(info) = ctx.mergers.get(&false_id) {
merger(&mut reject_block, info)
if let Some(info) = false_target.merge_info {
merger(&mut reject_block, &info)
} else {
let prev = ctx.body_for_label.insert(false_id, reject);
let prev = ctx.body_for_label.insert(false_target.label_id, reject);
debug_assert!(prev.is_none());
}

Expand All @@ -3177,11 +3268,6 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
reject,
});

// Consume branch weights
for _ in 4..inst.wc {
let _ = self.next()?;
}

return Ok(());
}
Op::Switch => {
Expand Down Expand Up @@ -3351,6 +3437,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
parent_body.data.push(BodyFragment::Loop {
body: loop_body_idx,
continuing: continue_idx,
break_if: None,
});
body_idx = loop_body_idx;
}
Expand Down
Binary file added tests/in/spv/do-while.spv
Binary file not shown.
64 changes: 64 additions & 0 deletions tests/in/spv/do-while.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
;; Ensure that `do`-`while`-style loops, with conditional backedges, are properly
;; supported, via `break if` (as `continuing { ... if c { break; } }` is illegal).
;;
;; The SPIR-V below was compiled from this GLSL fragment shader:
;; ```glsl
;; #version 450
;;
;; void f(bool cond) {
;; do {} while(cond);
;; }
;;
;; void main() {
;; f(false);
;; }
;; ```

OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main"
OpExecutionMode %main OriginUpperLeft
OpSource GLSL 450
OpName %main "main"
OpName %f_b1_ "f(b1;"
OpName %cond "cond"
OpName %param "param"
%void = OpTypeVoid
%3 = OpTypeFunction %void
%bool = OpTypeBool
%_ptr_Function_bool = OpTypePointer Function %bool
%8 = OpTypeFunction %void %_ptr_Function_bool
%false = OpConstantFalse %bool

%main = OpFunction %void None %3
%5 = OpLabel
%param = OpVariable %_ptr_Function_bool Function
OpStore %param %false
%19 = OpFunctionCall %void %f_b1_ %param
OpReturn
OpFunctionEnd

%f_b1_ = OpFunction %void None %8
%cond = OpFunctionParameter %_ptr_Function_bool

%11 = OpLabel
OpBranch %12

%12 = OpLabel
OpLoopMerge %14 %15 None
OpBranch %13

%13 = OpLabel
OpBranch %15

;; This is the "continuing" block, and it contains a conditional branch between
;; the backedge (back to the loop header) and the loop merge ("break") target.
%15 = OpLabel
%16 = OpLoad %bool %cond
OpBranchConditional %16 %12 %14

%14 = OpLabel
OpReturn

OpFunctionEnd
33 changes: 33 additions & 0 deletions tests/out/glsl/do-while.main.Fragment.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#version 310 es

precision highp float;
precision highp int;


void fb1_(inout bool cond) {
bool loop_init = true;
while(true) {
if (!loop_init) {
bool _e6 = cond;
bool unnamed = !(_e6);
if (unnamed) {
break;
}
}
loop_init = false;
continue;
}
return;
}

void main_1() {
bool param = false;
param = false;
fb1_(param);
return;
}

void main() {
main_1();
}

31 changes: 31 additions & 0 deletions tests/out/hlsl/do-while.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

void fb1_(inout bool cond)
{
bool loop_init = true;
while(true) {
if (!loop_init) {
bool _expr6 = cond;
bool unnamed = !(_expr6);
if (unnamed) {
break;
}
}
loop_init = false;
continue;
}
return;
}

void main_1()
{
bool param = (bool)0;

param = false;
fb1_(param);
return;
}

void main()
{
main_1();
}
3 changes: 3 additions & 0 deletions tests/out/hlsl/do-while.hlsl.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
vertex=()
fragment=(main:ps_5_1 )
compute=()
37 changes: 37 additions & 0 deletions tests/out/msl/do-while.msl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// language: metal2.0
#include <metal_stdlib>
#include <simd/simd.h>

using metal::uint;


void fb1_(
thread bool& cond
) {
bool loop_init = true;
while(true) {
if (!loop_init) {
bool _e6 = cond;
bool unnamed = !(_e6);
if (!(cond)) {
break;
}
}
loop_init = false;
continue;
}
return;
}

void main_1(
) {
bool param = {};
param = false;
fb1_(param);
return;
}

fragment void main_(
) {
main_1();
}
Loading