Skip to content

Commit

Permalink
Add basic support for struct DSTs (#504)
Browse files Browse the repository at this point in the history
* Add basic support for struct DSTs

* Add tests

* cleanup tests

* Update with entry changes, address review

* Address review

* Update allocate_const_scalar.stderr

* Add ArrayStride decoration to OpTypeRuntimeArray
  • Loading branch information
Hentropy authored Mar 29, 2021
1 parent c3a3b20 commit 05ce407
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 34 deletions.
94 changes: 76 additions & 18 deletions crates/rustc_codegen_spirv/src/codegen_cx/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ use crate::builder_spirv::SpirvValue;
use crate::spirv_type::SpirvType;
use rspirv::dr::Operand;
use rspirv::spirv::{Decoration, ExecutionModel, FunctionControl, StorageClass, Word};
use rustc_codegen_ssa::traits::BaseTypeMethods;
use rustc_hir as hir;
use rustc_middle::ty::layout::TyAndLayout;
use rustc_middle::ty::layout::{HasParamEnv, TyAndLayout};
use rustc_middle::ty::{Instance, Ty, TyKind};
use rustc_span::Span;
use rustc_target::abi::call::{FnAbi, PassMode};
use rustc_target::abi::LayoutOf;
use rustc_target::abi::{
call::{ArgAbi, ArgAttribute, ArgAttributes, FnAbi, PassMode},
LayoutOf, Size,
};
use std::collections::HashMap;

impl<'tcx> CodegenCx<'tcx> {
Expand All @@ -37,9 +40,27 @@ impl<'tcx> CodegenCx<'tcx> {
};
let fn_hir_id = self.tcx.hir().local_def_id_to_hir_id(local_id);
let body = self.tcx.hir().body(self.tcx.hir().body_owned_by(fn_hir_id));
const EMPTY: ArgAttribute = ArgAttribute::empty();
for (abi, arg) in fn_abi.args.iter().zip(body.params) {
match abi.mode {
PassMode::Direct(_) | PassMode::Indirect { .. } => {}
PassMode::Direct(_)
| PassMode::Indirect { .. }
// plain DST/RTA/VLA
| PassMode::Pair(
ArgAttributes {
pointee_size: Size::ZERO,
..
},
ArgAttributes { regular: EMPTY, .. },
)
// DST struct with fields before the DST member
| PassMode::Pair(
ArgAttributes { .. },
ArgAttributes {
pointee_size: Size::ZERO,
..
},
) => {}
_ => self.tcx.sess.span_err(
arg.span,
&format!("PassMode {:?} invalid for entry point parameter", abi.mode),
Expand All @@ -63,7 +84,7 @@ impl<'tcx> CodegenCx<'tcx> {
self.shader_entry_stub(
self.tcx.def_span(instance.def_id()),
entry_func,
fn_abi,
&fn_abi.args,
body.params,
name,
execution_model,
Expand All @@ -82,7 +103,7 @@ impl<'tcx> CodegenCx<'tcx> {
&self,
span: Span,
entry_func: SpirvValue,
entry_fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
arg_abis: &[ArgAbi<'tcx, Ty<'tcx>>],
hir_params: &[hir::Param<'tcx>],
name: String,
execution_model: ExecutionModel,
Expand All @@ -94,25 +115,22 @@ impl<'tcx> CodegenCx<'tcx> {
}
.def(span, self);
let entry_func_return_type = match self.lookup_type(entry_func.ty) {
SpirvType::Function {
return_type,
arguments: _,
} => return_type,
SpirvType::Function { return_type, .. } => return_type,
other => self.tcx.sess.fatal(&format!(
"Invalid entry_stub type: {}",
other.debug(entry_func.ty, self)
)),
};
let mut decoration_locations = HashMap::new();
// Create OpVariables before OpFunction so they're global instead of local vars.
let declared_params = entry_fn_abi
.args
let declared_params = arg_abis
.iter()
.zip(hir_params)
.map(|(entry_fn_arg, hir_param)| {
self.declare_parameter(entry_fn_arg.layout, hir_param, &mut decoration_locations)
})
.collect::<Vec<_>>();
let len_t = self.type_isize();
let mut emit = self.emit_global();
let fn_id = emit
.begin_function(void, None, FunctionControl::NONE, fn_void_void)
Expand All @@ -121,12 +139,19 @@ impl<'tcx> CodegenCx<'tcx> {
// Adjust any global `OpVariable`s as needed (e.g. loading from `Input`s).
let arguments: Vec<_> = declared_params
.iter()
.zip(&entry_fn_abi.args)
.zip(arg_abis)
.zip(hir_params)
.map(|((&(var, storage_class), entry_fn_arg), hir_param)| {
match entry_fn_arg.layout.ty.kind() {
TyKind::Ref(..) => var,

.flat_map(|((&(var, storage_class), entry_fn_arg), hir_param)| {
let mut dst_len_arg = None;
let arg = match entry_fn_arg.layout.ty.kind() {
TyKind::Ref(_, ty, _) => {
if !ty.is_sized(self.tcx.at(span), self.param_env()) {
dst_len_arg.replace(
self.dst_length_argument(&mut emit, ty, hir_param, len_t, var),
);
}
var
}
_ => match entry_fn_arg.mode {
PassMode::Indirect { .. } => var,
PassMode::Direct(_) => {
Expand All @@ -142,7 +167,8 @@ impl<'tcx> CodegenCx<'tcx> {
}
_ => unreachable!(),
},
}
};
std::iter::once(arg).chain(dst_len_arg)
})
.collect();
emit.function_call(
Expand Down Expand Up @@ -170,6 +196,38 @@ impl<'tcx> CodegenCx<'tcx> {
fn_id
}

fn dst_length_argument(
&self,
emit: &mut std::cell::RefMut<'_, rspirv::dr::Builder>,
ty: Ty<'tcx>,
hir_param: &hir::Param<'tcx>,
len_t: Word,
var: Word,
) -> Word {
match ty.kind() {
TyKind::Adt(adt_def, substs) => {
let (member_idx, field_def) = adt_def.all_fields().enumerate().last().unwrap();
let field_ty = field_def.ty(self.tcx, substs);
if !matches!(field_ty.kind(), TyKind::Slice(..)) {
self.tcx.sess.span_fatal(
hir_param.ty_span,
"DST parameters are currently restricted to a reference to a struct whose last field is a slice.",
)
}
emit.array_length(len_t, None, var, member_idx as u32)
.unwrap()
}
TyKind::Slice(..) | TyKind::Str => self.tcx.sess.span_fatal(
hir_param.ty_span,
"Straight slices are not yet supported, wrap the slice in a newtype.",
),
_ => self
.tcx
.sess
.span_fatal(hir_param.ty_span, "Unsupported parameter type."),
}
}

fn declare_parameter(
&self,
layout: TyAndLayout<'tcx>,
Expand Down
11 changes: 11 additions & 0 deletions crates/rustc_codegen_spirv/src/spirv_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,17 @@ impl SpirvType {
}
Self::RuntimeArray { element } => {
let result = cx.emit_global().type_runtime_array(element);
// ArrayStride decoration wants in *bytes*
let element_size = cx
.lookup_type(element)
.sizeof(cx)
.expect("Element of sized array must be sized")
.bytes();
cx.emit_global().decorate(
result,
Decoration::ArrayStride,
iter::once(Operand::LiteralInt32(element_size as u32)),
);
if cx.kernel_mode {
cx.zombie_with_span(result, def_span, "RuntimeArray in kernel mode");
}
Expand Down
82 changes: 67 additions & 15 deletions crates/spirv-builder/src/test/basic.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{dis_fn, dis_globals, val, val_vulkan};
use super::{dis_entry_fn, dis_fn, dis_globals, val, val_vulkan};
use std::ffi::OsStr;

struct SetEnvVar<'a> {
Expand Down Expand Up @@ -183,20 +183,21 @@ OpEntryPoint Fragment %1 "main"
OpExecutionMode %1 OriginUpperLeft
OpName %2 "test_project::add_decorate"
OpName %3 "test_project::main"
OpDecorate %4 DescriptorSet 0
OpDecorate %4 Binding 0
%5 = OpTypeVoid
%6 = OpTypeFunction %5
%7 = OpTypeInt 32 0
%8 = OpTypePointer Function %7
%9 = OpConstant %7 1
%10 = OpTypeFloat 32
%11 = OpTypeImage %10 2D 0 0 0 1 Unknown
%12 = OpTypeSampledImage %11
%13 = OpTypeRuntimeArray %12
%14 = OpTypePointer UniformConstant %13
%4 = OpVariable %14 UniformConstant
%15 = OpTypePointer UniformConstant %12"#,
OpDecorate %4 ArrayStride 4
OpDecorate %5 DescriptorSet 0
OpDecorate %5 Binding 0
%6 = OpTypeVoid
%7 = OpTypeFunction %6
%8 = OpTypeInt 32 0
%9 = OpTypePointer Function %8
%10 = OpConstant %8 1
%11 = OpTypeFloat 32
%12 = OpTypeImage %11 2D 0 0 0 1 Unknown
%13 = OpTypeSampledImage %12
%4 = OpTypeRuntimeArray %13
%14 = OpTypePointer UniformConstant %4
%5 = OpVariable %14 UniformConstant
%15 = OpTypePointer UniformConstant %13"#,
);
}

Expand Down Expand Up @@ -479,3 +480,54 @@ fn ptr_copy_from_method() {
"#
);
}

#[test]
fn index_user_dst() {
dis_entry_fn(
r#"
#[spirv(fragment)]
pub fn main(
#[spirv(uniform, descriptor_set = 0, binding = 0)] slice: &mut SliceF32,
) {
let float: f32 = slice.rta[0];
let _ = float;
}
pub struct SliceF32 {
rta: [f32],
}
"#,
"main",
r#"%1 = OpFunction %2 None %3
%4 = OpLabel
%5 = OpArrayLength %6 %7 0
%8 = OpCompositeInsert %9 %7 %10 0
%11 = OpCompositeInsert %9 %5 %8 1
%12 = OpAccessChain %13 %7 %14
%15 = OpULessThan %16 %14 %5
OpSelectionMerge %17 None
OpBranchConditional %15 %18 %19
%18 = OpLabel
%20 = OpAccessChain %13 %7 %14
%21 = OpInBoundsAccessChain %22 %20 %14
%23 = OpLoad %24 %21
OpReturn
%19 = OpLabel
OpBranch %25
%25 = OpLabel
OpBranch %26
%26 = OpLabel
%27 = OpPhi %16 %28 %25 %28 %29
OpLoopMerge %30 %29 None
OpBranchConditional %27 %31 %30
%31 = OpLabel
OpBranch %29
%29 = OpLabel
OpBranch %26
%30 = OpLabel
OpUnreachable
%17 = OpLabel
OpUnreachable
OpFunctionEnd"#,
)
}
27 changes: 27 additions & 0 deletions crates/spirv-builder/src/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,33 @@ fn dis_fn(src: &str, func: &str, expect: &str) {
assert_str_eq(expect, &func.disassemble())
}

fn dis_entry_fn(src: &str, func: &str, expect: &str) {
let _lock = global_lock();
let module = read_module(&build(src)).unwrap();
let id = module
.entry_points
.iter()
.find(|inst| inst.operands.last().unwrap().unwrap_literal_string() == func)
.unwrap_or_else(|| {
panic!(
"no entry point with the name `{}` found in:\n{}\n",
func,
module.disassemble()
)
})
.operands[1]
.unwrap_id_ref();
let mut func = module
.functions
.into_iter()
.find(|f| f.def_id().unwrap() == id)
.unwrap();
// Compact to make IDs more stable
compact_ids(&mut func);
use rspirv::binary::Disassemble;
assert_str_eq(expect, &func.disassemble())
}

fn dis_globals(src: &str, expect: &str) {
let _lock = global_lock();
let module = read_module(&build(src)).unwrap();
Expand Down
2 changes: 1 addition & 1 deletion tests/ui/lang/core/ptr/allocate_const_scalar.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ error: pointer has non-null integer address
|
= note: Stack:
allocate_const_scalar::main
Unnamed function ID %4
Unnamed function ID %5

error: invalid binary:0:0 - No OpEntryPoint instruction was found. This is only allowed if the Linkage capability is being used.
|
Expand Down

0 comments on commit 05ce407

Please sign in to comment.