diff --git a/cranelift/isle/isle/src/compile.rs b/cranelift/isle/isle/src/compile.rs index de0a3348722b..23c11df1bbe8 100644 --- a/cranelift/isle/isle/src/compile.rs +++ b/cranelift/isle/isle/src/compile.rs @@ -8,6 +8,6 @@ pub fn compile(defs: &ast::Defs, options: &codegen::CodegenOptions) -> Result Vec { - let inst = InstId(self.insts.len()); - let mut outs = vec![]; - for (i, _arg_ty) in arg_tys.iter().enumerate() { - let val = Value::Pattern { inst, output: i }; - outs.push(val); - } - let arg_tys = arg_tys.iter().cloned().collect(); - self.add_inst(PatternInst::MatchVariant { + let outputs = arg_tys.len(); + let arg_tys = arg_tys.into(); + let inst = self.add_inst(PatternInst::MatchVariant { input, input_ty, arg_tys, variant, }); - outs + (0..outputs) + .map(|output| Value::Pattern { inst, output }) + .collect() } fn add_extract( @@ -310,14 +307,8 @@ impl PatternSequence { infallible: bool, multi: bool, ) -> Vec { - let inst = InstId(self.insts.len()); - let mut outs = vec![]; - for i in 0..output_tys.len() { - let val = Value::Pattern { inst, output: i }; - outs.push(val); - } - let output_tys = output_tys.iter().cloned().collect(); - self.add_inst(PatternInst::Extract { + let outputs = output_tys.len(); + let inst = self.add_inst(PatternInst::Extract { inputs, input_tys, output_tys, @@ -325,7 +316,9 @@ impl PatternSequence { infallible, multi, }); - outs + (0..outputs) + .map(|output| Value::Pattern { inst, output }) + .collect() } fn add_expr_seq(&mut self, seq: ExprSequence, output: Value, output_ty: TypeId) -> Value { @@ -344,7 +337,6 @@ impl PatternSequence { fn gen_pattern( &mut self, input: ValueOrArgs, - typeenv: &TypeEnv, termenv: &TermEnv, pat: &Pattern, vars: &mut StableMap, @@ -356,8 +348,7 @@ impl PatternSequence { if let Some(v) = input.to_value() { vars.insert(var, v); } - let root_term = self.gen_pattern(input, typeenv, termenv, &*subpat, vars); - root_term + self.gen_pattern(input, termenv, subpat, vars); } &Pattern::Var(ty, var) => { // Assert that the value matches the existing bound var. @@ -394,32 +385,15 @@ impl PatternSequence { let arg_tys = &termdata.arg_tys[..]; for (i, subpat) in args.iter().enumerate() { let value = self.add_arg(i, arg_tys[i]); - self.gen_pattern( - ValueOrArgs::Value(value), - typeenv, - termenv, - subpat, - vars, - ); + self.gen_pattern(ValueOrArgs::Value(value), termenv, subpat, vars); } } ValueOrArgs::Value(input) => { // Determine whether the term has an external extractor or not. let termdata = &termenv.terms[term.index()]; - let arg_tys = &termdata.arg_tys[..]; - match &termdata.kind { + let arg_values = match &termdata.kind { TermKind::EnumVariant { variant } => { - let arg_values = - self.add_match_variant(input, ty, arg_tys, *variant); - for (subpat, value) in args.iter().zip(arg_values.into_iter()) { - self.gen_pattern( - ValueOrArgs::Value(value), - typeenv, - termenv, - subpat, - vars, - ); - } + self.add_match_variant(input, ty, &termdata.arg_tys, *variant) } TermKind::Decl { extractor_kind: None, @@ -434,50 +408,36 @@ impl PatternSequence { panic!("Should have been expanded away") } TermKind::Decl { - extractor_kind: Some(ExtractorKind::ExternalExtractor { .. }), + multi, + extractor_kind: + Some(ExtractorKind::ExternalExtractor { infallible, .. }), .. } => { - let ext_sig = termdata.extractor_sig(typeenv).unwrap(); - // Evaluate all `input` args. - let mut inputs = vec![]; - let mut input_tys = vec![]; - let mut output_tys = vec![]; - let mut output_pats = vec![]; - inputs.push(input); - input_tys.push(termdata.ret_ty); - for arg in args { - output_tys.push(arg.ty()); - output_pats.push(arg); - } + let inputs = vec![input]; + let input_tys = vec![termdata.ret_ty]; + let output_tys = args.iter().map(|arg| arg.ty()).collect(); // Invoke the extractor. - let arg_values = self.add_extract( + self.add_extract( inputs, input_tys, output_tys, term, - ext_sig.infallible, - ext_sig.multi, - ); - - for (pat, &val) in output_pats.iter().zip(arg_values.iter()) { - self.gen_pattern( - ValueOrArgs::Value(val), - typeenv, - termenv, - pat, - vars, - ); - } + *infallible && !*multi, + *multi, + ) } + }; + for (pat, val) in args.iter().zip(arg_values) { + self.gen_pattern(ValueOrArgs::Value(val), termenv, pat, vars); } } } } &Pattern::And(_ty, ref children) => { for child in children { - self.gen_pattern(input, typeenv, termenv, child, vars); + self.gen_pattern(input, termenv, child, vars); } } &Pattern::Wildcard(_ty) => { @@ -506,11 +466,10 @@ impl ExprSequence { fn add_create_variant( &mut self, - inputs: &[(Value, TypeId)], + inputs: Vec<(Value, TypeId)>, ty: TypeId, variant: VariantId, ) -> Value { - let inputs = inputs.iter().cloned().collect(); let inst = self.add_inst(ExprInst::CreateVariant { inputs, ty, @@ -521,13 +480,12 @@ impl ExprSequence { fn add_construct( &mut self, - inputs: &[(Value, TypeId)], + inputs: Vec<(Value, TypeId)>, ty: TypeId, term: TermId, infallible: bool, multi: bool, ) -> Value { - let inputs = inputs.iter().cloned().collect(); let inst = self.add_inst(ExprInst::Construct { inputs, ty, @@ -551,7 +509,6 @@ impl ExprSequence { /// term ID, if any. fn gen_expr( &mut self, - typeenv: &TypeEnv, termenv: &TermEnv, expr: &Expr, vars: &StableMap, @@ -567,22 +524,22 @@ impl ExprSequence { } => { let mut vars = vars.clone(); for &(var, _var_ty, ref var_expr) in bindings { - let var_value = self.gen_expr(typeenv, termenv, &*var_expr, &vars); + let var_value = self.gen_expr(termenv, var_expr, &vars); vars.insert(var, var_value); } - self.gen_expr(typeenv, termenv, body, &vars) + self.gen_expr(termenv, body, &vars) } &Expr::Var(_ty, var_id) => vars.get(&var_id).cloned().unwrap(), &Expr::Term(ty, term, ref arg_exprs) => { let termdata = &termenv.terms[term.index()]; - let mut arg_values_tys = vec![]; - for (arg_ty, arg_expr) in termdata.arg_tys.iter().cloned().zip(arg_exprs.iter()) { - arg_values_tys - .push((self.gen_expr(typeenv, termenv, &*arg_expr, &vars), arg_ty)); - } + let arg_values_tys = arg_exprs + .iter() + .map(|arg_expr| self.gen_expr(termenv, arg_expr, vars)) + .zip(termdata.arg_tys.iter().copied()) + .collect(); match &termdata.kind { TermKind::EnumVariant { variant } => { - self.add_create_variant(&arg_values_tys[..], ty, *variant) + self.add_create_variant(arg_values_tys, ty, *variant) } TermKind::Decl { constructor_kind: Some(ConstructorKind::InternalConstructor), @@ -590,7 +547,7 @@ impl ExprSequence { .. } => { self.add_construct( - &arg_values_tys[..], + arg_values_tys, ty, term, /* infallible = */ false, @@ -604,7 +561,7 @@ impl ExprSequence { .. } => { self.add_construct( - &arg_values_tys[..], + arg_values_tys, ty, term, /* infallible = */ !pure, @@ -622,16 +579,13 @@ impl ExprSequence { } /// Build a sequence from a rule. -pub fn lower_rule( - tyenv: &TypeEnv, - termenv: &TermEnv, - rule: RuleId, -) -> (PatternSequence, ExprSequence) { +pub fn lower_rule(termenv: &TermEnv, rule: RuleId) -> (PatternSequence, ExprSequence) { let mut pattern_seq: PatternSequence = Default::default(); let mut expr_seq: ExprSequence = Default::default(); - expr_seq.pos = termenv.rules[rule.index()].pos; let ruledata = &termenv.rules[rule.index()]; + expr_seq.pos = ruledata.pos; + let mut vars = StableMap::new(); let root_term = ruledata .lhs @@ -643,7 +597,6 @@ pub fn lower_rule( // Lower the pattern, starting from the root input value. pattern_seq.gen_pattern( ValueOrArgs::ImplicitTermFromArgs(root_term), - tyenv, termenv, &ruledata.lhs, &mut vars, @@ -653,13 +606,12 @@ pub fn lower_rule( // `PatternInst::Expr` for the sub-exprs (right-hand sides). for iflet in &ruledata.iflets { let mut subexpr_seq: ExprSequence = Default::default(); - let subexpr_ret_value = subexpr_seq.gen_expr(tyenv, termenv, &iflet.rhs, &mut vars); + let subexpr_ret_value = subexpr_seq.gen_expr(termenv, &iflet.rhs, &mut vars); subexpr_seq.add_return(iflet.rhs.ty(), subexpr_ret_value); let pattern_value = pattern_seq.add_expr_seq(subexpr_seq, subexpr_ret_value, iflet.rhs.ty()); pattern_seq.gen_pattern( ValueOrArgs::Value(pattern_value), - tyenv, termenv, &iflet.lhs, &mut vars, @@ -668,7 +620,7 @@ pub fn lower_rule( // Lower the expression, making use of the bound variables // from the pattern. - let rhs_root_val = expr_seq.gen_expr(tyenv, termenv, &ruledata.rhs, &vars); + let rhs_root_val = expr_seq.gen_expr(termenv, &ruledata.rhs, &vars); // Return the root RHS value. let output_ty = ruledata.rhs.ty(); expr_seq.add_return(output_ty, rhs_root_val); diff --git a/cranelift/isle/isle/src/trie.rs b/cranelift/isle/isle/src/trie.rs index 6ac909b3f76a..5fdf09744f57 100644 --- a/cranelift/isle/isle/src/trie.rs +++ b/cranelift/isle/isle/src/trie.rs @@ -1,14 +1,14 @@ //! Trie construction. -use crate::ir::{lower_rule, ExprSequence, PatternInst, PatternSequence}; +use crate::ir::{lower_rule, ExprSequence, PatternInst}; use crate::log; -use crate::sema::{RuleId, TermEnv, TermId, TypeEnv}; +use crate::sema::{TermEnv, TermId}; use std::collections::BTreeMap; /// Construct the tries for each term. -pub fn build_tries(typeenv: &TypeEnv, termenv: &TermEnv) -> BTreeMap { - let mut builder = TermFunctionsBuilder::new(typeenv, termenv); - builder.build(); +pub fn build_tries(termenv: &TermEnv) -> BTreeMap { + let mut builder = TermFunctionsBuilder::default(); + builder.build(termenv); log!("builder: {:?}", builder); builder.finalize() } @@ -280,91 +280,43 @@ impl TrieNode { } } -/// Builder context for one function in generated code corresponding -/// to one root input term. -/// -/// A `TermFunctionBuilder` can correspond to the matching -/// control-flow and operations that we execute either when evaluating -/// *forward* on a term, trying to match left-hand sides against it -/// and transforming it into another term; or *backward* on a term, -/// trying to match another rule's left-hand side against an input to -/// produce the term in question (when the term is used in the LHS of -/// the calling term). -#[derive(Debug)] -struct TermFunctionBuilder { - trie: TrieNode, -} - -impl TermFunctionBuilder { - fn new() -> Self { - TermFunctionBuilder { - trie: TrieNode::Empty, - } - } - - fn add_rule(&mut self, prio: Prio, pattern_seq: PatternSequence, expr_seq: ExprSequence) { - let symbols = pattern_seq - .insts - .into_iter() - .map(|op| TrieSymbol::Match { op }) - .chain(std::iter::once(TrieSymbol::EndOfMatch)); - self.trie.insert(prio, symbols, expr_seq); - } - - fn sort_trie(&mut self) { - self.trie.sort(); - } +#[derive(Debug, Default)] +struct TermFunctionsBuilder { + builders_by_term: BTreeMap, } -#[derive(Debug)] -struct TermFunctionsBuilder<'a> { - typeenv: &'a TypeEnv, - termenv: &'a TermEnv, - builders_by_term: BTreeMap, -} - -impl<'a> TermFunctionsBuilder<'a> { - fn new(typeenv: &'a TypeEnv, termenv: &'a TermEnv) -> Self { - log!("typeenv: {:?}", typeenv); +impl TermFunctionsBuilder { + fn build(&mut self, termenv: &TermEnv) { log!("termenv: {:?}", termenv); - Self { - builders_by_term: BTreeMap::new(), - typeenv, - termenv, - } - } - - fn build(&mut self) { - for rule in 0..self.termenv.rules.len() { - let rule = RuleId(rule); - let prio = self.termenv.rules[rule.index()].prio; - - let (pattern, expr) = lower_rule(self.typeenv, self.termenv, rule); - let root_term = self.termenv.rules[rule.index()].lhs.root_term().unwrap(); + for rule in termenv.rules.iter() { + let (pattern, expr) = lower_rule(termenv, rule.id); + let root_term = rule.lhs.root_term().unwrap(); log!( "build:\n- rule {:?}\n- pattern {:?}\n- expr {:?}", - self.termenv.rules[rule.index()], + rule, pattern, expr ); + + let symbols = pattern + .insts + .into_iter() + .map(|op| TrieSymbol::Match { op }) + .chain(std::iter::once(TrieSymbol::EndOfMatch)); + self.builders_by_term .entry(root_term) - .or_insert_with(|| TermFunctionBuilder::new()) - .add_rule(prio, pattern.clone(), expr.clone()); + .or_insert(TrieNode::Empty) + .insert(rule.prio, symbols, expr); } for builder in self.builders_by_term.values_mut() { - builder.sort_trie(); + builder.sort(); } } fn finalize(self) -> BTreeMap { - let functions_by_term = self - .builders_by_term - .into_iter() - .map(|(term, builder)| (term, builder.trie)) - .collect::>(); - functions_by_term + self.builders_by_term } }