Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inline asm!: support writing _ in lieu of return types, for basic inference. #376

Merged
merged 4 commits into from
Jan 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 157 additions & 7 deletions crates/rustc_codegen_spirv/src/builder/spirv_asm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::builder_spirv::SpirvValue;
use crate::spirv_type::SpirvType;

use super::Builder;
use crate::codegen_cx::CodegenCx;
use rspirv::dr;
use rspirv::grammar::{LogicalOperand, OperandKind, OperandQuantifier};
use rspirv::spirv::{
Expand Down Expand Up @@ -131,8 +132,15 @@ impl<'a, 'tcx> AsmBuilderMethods<'tcx> for Builder<'a, 'tcx> {
}

let mut id_map = HashMap::new();
let mut id_to_type_map = HashMap::new();
for operand in operands {
if let InlineAsmOperandRef::In { reg: _, value } = operand {
let value = value.immediate();
id_to_type_map.insert(value.def(self), value.ty);
}
}
for line in tokens {
self.codegen_asm(&mut id_map, line.into_iter());
self.codegen_asm(&mut id_map, &mut id_to_type_map, line.into_iter());
}
}
}
Expand Down Expand Up @@ -248,6 +256,10 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
self.err("OpTypeArray in asm! is not supported yet");
return;
}
Op::TypeSampledImage => SpirvType::SampledImage {
image_type: inst.operands[0].unwrap_id_ref(),
}
.def(self.span(), self),
_ => {
self.emit()
.insert_into_block(dr::InsertPoint::End, inst)
Expand All @@ -265,6 +277,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
fn codegen_asm<'a>(
&mut self,
id_map: &mut HashMap<&'a str, Word>,
id_to_type_map: &mut HashMap<Word, Word>,
mut tokens: impl Iterator<Item = Token<'a, 'cx, 'tcx>>,
) where
'cx: 'a,
Expand Down Expand Up @@ -339,7 +352,10 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
result_id,
operands: vec![],
};
self.parse_operands(id_map, tokens, &mut instruction);
self.parse_operands(id_map, id_to_type_map, tokens, &mut instruction);
if let Some(result_type) = instruction.result_type {
id_to_type_map.insert(instruction.result_id.unwrap(), result_type);
}
self.insert_inst(id_map, instruction);
if let Some(OutRegister::Place(place)) = out_register {
self.emit()
Expand All @@ -356,13 +372,15 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
fn parse_operands<'a>(
&mut self,
id_map: &mut HashMap<&'a str, Word>,
id_to_type_map: &HashMap<Word, Word>,
mut tokens: impl Iterator<Item = Token<'a, 'cx, 'tcx>>,
instruction: &mut dr::Instruction,
) where
'cx: 'a,
'tcx: 'a,
{
let mut saw_id_result = false;
let mut need_result_type_infer = false;
for &LogicalOperand { kind, quantifier } in instruction.class.operands {
if kind == OperandKind::IdResult {
assert_eq!(quantifier, OperandQuantifier::One);
Expand All @@ -375,6 +393,22 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
saw_id_result = true;
continue;
}
if kind == OperandKind::IdResultType {
assert_eq!(quantifier, OperandQuantifier::One);
if let Some(token) = tokens.next() {
if let Token::Word("_") = token {
need_result_type_infer = true;
} else if let Some(id) = self.parse_id_in(id_map, token) {
instruction.result_type = Some(id);
}
} else {
self.err(&format!(
"instruction {} expects a result type",
instruction.class.opname
));
}
continue;
}
match quantifier {
OperandQuantifier::One => {
if !self.parse_one_operand(id_map, instruction, kind, &mut tokens) {
Expand Down Expand Up @@ -406,6 +440,125 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
instruction.class.opname
));
}

if need_result_type_infer {
assert!(instruction.result_type.is_none());

match self.infer_result_type(id_to_type_map, instruction) {
Some(result_type) => instruction.result_type = Some(result_type),
None => self.err(&format!(
"instruction {} cannot have its result type inferred",
instruction.class.opname
)),
}
}
}

fn infer_result_type(
&self,
id_to_type_map: &HashMap<Word, Word>,
instruction: &dr::Instruction,
) -> Option<Word> {
use crate::spirv_type_constraints::{instruction_signatures, InstSig, TyListPat, TyPat};

struct Mismatch;

/// Recursively match `ty` against `pat`, returning one of:
/// * `Ok(None)`: `pat` matched but contained no type variables
/// * `Ok(Some(var))`: `pat` matched and `var` is the type variable
/// * `Err(Mismatch)`: `pat` didn't match or isn't supported right now
fn apply_ty_pat(
cx: &CodegenCx<'_>,
pat: &TyPat<'_>,
ty: Word,
) -> Result<Option<Word>, Mismatch> {
match pat {
TyPat::Any => Ok(None),
&TyPat::T => Ok(Some(ty)),
TyPat::Either(a, b) => {
apply_ty_pat(cx, a, ty).or_else(|Mismatch| apply_ty_pat(cx, b, ty))
}
_ => match (pat, cx.lookup_type(ty)) {
(TyPat::Void, SpirvType::Void) => Ok(None),
(TyPat::Pointer(pat), SpirvType::Pointer { pointee: ty, .. })
| (TyPat::Vector(pat), SpirvType::Vector { element: ty, .. })
| (
TyPat::Vector4(pat),
SpirvType::Vector {
element: ty,
count: 4,
},
)
| (
TyPat::Image(pat),
SpirvType::Image {
sampled_type: ty, ..
},
)
| (TyPat::SampledImage(pat), SpirvType::SampledImage { image_type: ty }) => {
apply_ty_pat(cx, pat, ty)
}
_ => Err(Mismatch),
},
}
}

// FIXME(eddyb) try multiple signatures until one fits.
let mut sig = match instruction_signatures(instruction.class.opcode)? {
[sig @ InstSig {
output: Some(_), ..
}] => *sig,
_ => return None,
};

let mut combined_var = None;

let mut ids = instruction.operands.iter().filter_map(|o| o.id_ref_any());
while let TyListPat::Cons { first: pat, suffix } = *sig.inputs {
let &ty = id_to_type_map.get(&ids.next()?)?;
match apply_ty_pat(self, pat, ty) {
Ok(Some(var)) => match combined_var {
Some(combined_var) => {
// FIXME(eddyb) this could use some error reporting
// (it's a type mismatch), although we could also
// just use the first type and let validation take
// care of the mismatch
if var != combined_var {
return None;
}
}
None => combined_var = Some(var),
},
Ok(None) => {}
Err(Mismatch) => return None,
}
sig.inputs = suffix;
}
match sig.inputs {
TyListPat::Any => {}
TyListPat::Nil => {
if ids.next().is_some() {
return None;
}
}
_ => return None,
}

let var = combined_var?;
match sig.output.unwrap() {
&TyPat::T => Some(var),
TyPat::Vector4(&TyPat::T) => Some(
SpirvType::Vector {
element: var,
count: 4,
}
.def(self.span(), self),
),
TyPat::SampledImage(&TyPat::T) => {
Some(SpirvType::SampledImage { image_type: var }.def(self.span(), self))
}
_ => None,
}
}

fn check_reg(&mut self, span: Span, reg: &InlineAsmRegOrRegClass) {
Expand Down Expand Up @@ -668,12 +821,9 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
Token::Typeof(_, _, _) => None,
};
match (kind, word) {
(OperandKind::IdResultType, _) => {
if let Some(id) = self.parse_id_in(id_map, token) {
inst.result_type = Some(id)
}
(OperandKind::IdResultType, _) | (OperandKind::IdResult, _) => {
bug!("should be handled by parse_operands")
}
(OperandKind::IdResult, _) => bug!("should be handled by parse_operands"),
(OperandKind::IdMemorySemantics, _) => {
if let Some(id) = self.parse_id_in(id_map, token) {
inst.operands.push(dr::Operand::IdMemorySemantics(id))
Expand Down
1 change: 1 addition & 0 deletions crates/rustc_codegen_spirv/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ mod decorations;
mod link;
mod linker;
mod spirv_type;
mod spirv_type_constraints;
mod symbols;

use builder::Builder;
Expand Down
Loading