Skip to content

Commit

Permalink
Desugar try blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
HKalbasi committed Mar 19, 2023
1 parent 453ae2e commit f3bd169
Show file tree
Hide file tree
Showing 11 changed files with 214 additions and 83 deletions.
87 changes: 74 additions & 13 deletions crates/hir-def/src/body/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use rustc_hash::FxHashMap;
use smallvec::SmallVec;
use syntax::{
ast::{
self, ArrayExprKind, AstChildren, HasArgList, HasLoopBody, HasName, LiteralKind,
self, ArrayExprKind, AstChildren, BlockExpr, HasArgList, HasLoopBody, HasName, LiteralKind,
SlicePatComponents,
},
AstNode, AstPtr, SyntaxNodePtr,
Expand Down Expand Up @@ -100,6 +100,7 @@ pub(super) fn lower(
_c: Count::new(),
},
expander,
current_try_block: None,
is_lowering_assignee_expr: false,
is_lowering_generator: false,
}
Expand All @@ -113,6 +114,7 @@ struct ExprCollector<'a> {
body: Body,
krate: CrateId,
source_map: BodySourceMap,
current_try_block: Option<LabelId>,
is_lowering_assignee_expr: bool,
is_lowering_generator: bool,
}
Expand Down Expand Up @@ -222,6 +224,10 @@ impl ExprCollector<'_> {
self.source_map.label_map.insert(src, id);
id
}
// FIXME: desugared labels don't have ptr, that's wrong and should be fixed somehow.
fn alloc_label_desugared(&mut self, label: Label) -> LabelId {
self.body.labels.alloc(label)
}
fn make_label(&mut self, label: Label, src: LabelSource) -> LabelId {
let id = self.body.labels.alloc(label);
self.source_map.label_map_back.insert(id, src);
Expand Down Expand Up @@ -259,13 +265,7 @@ impl ExprCollector<'_> {
self.alloc_expr(Expr::Let { pat, expr }, syntax_ptr)
}
ast::Expr::BlockExpr(e) => match e.modifier() {
Some(ast::BlockModifier::Try(_)) => {
self.collect_block_(e, |id, statements, tail| Expr::TryBlock {
id,
statements,
tail,
})
}
Some(ast::BlockModifier::Try(_)) => self.collect_try_block(e),
Some(ast::BlockModifier::Unsafe(_)) => {
self.collect_block_(e, |id, statements, tail| Expr::Unsafe {
id,
Expand Down Expand Up @@ -606,6 +606,59 @@ impl ExprCollector<'_> {
})
}

/// Desugar `try { <stmts>; <expr> }` into `'<new_label>: { <stmts>; ::std::ops::Try::from_output(<expr>) }`,
/// `try { <stmts>; }` into `'<new_label>: { <stmts>; ::std::ops::Try::from_output(()) }`
/// and save the `<new_label>` to use it as a break target for desugaring of the `?` operator.
fn collect_try_block(&mut self, e: BlockExpr) -> ExprId {
let Some(try_from_output) = LangItem::TryTraitFromOutput.path(self.db, self.krate) else {
return self.alloc_expr_desugared(Expr::Missing);
};
let prev_try_block = self.current_try_block.take();
self.current_try_block =
Some(self.alloc_label_desugared(Label { name: Name::generate_new_name() }));
let expr_id = self.collect_block(e);
let callee = self.alloc_expr_desugared(Expr::Path(try_from_output));
let Expr::Block { label, tail, .. } = &mut self.body.exprs[expr_id] else {
unreachable!("It is the output of collect block");
};
*label = self.current_try_block;
let next_tail = match *tail {
Some(tail) => self.alloc_expr_desugared(Expr::Call {
callee,
args: Box::new([tail]),
is_assignee_expr: false,
}),
None => {
let unit = self.alloc_expr_desugared(Expr::Tuple {
exprs: Box::new([]),
is_assignee_expr: false,
});
self.alloc_expr_desugared(Expr::Call {
callee,
args: Box::new([unit]),
is_assignee_expr: false,
})
}
};
let Expr::Block { tail, .. } = &mut self.body.exprs[expr_id] else {
unreachable!("It is the output of collect block");
};
*tail = Some(next_tail);
self.current_try_block = prev_try_block;
expr_id
}

/// Desugar `ast::TryExpr` from: `<expr>?` into:
/// ```ignore (pseudo-rust)
/// match Try::branch(<expr>) {
/// ControlFlow::Continue(val) => val,
/// ControlFlow::Break(residual) =>
/// // If there is an enclosing `try {...}`:
/// break 'catch_target Try::from_residual(residual),
/// // Otherwise:
/// return Try::from_residual(residual),
/// }
/// ```
fn collect_try_operator(&mut self, syntax_ptr: AstPtr<ast::Expr>, e: ast::TryExpr) -> ExprId {
let (try_branch, cf_continue, cf_break, try_from_residual) = 'if_chain: {
if let Some(try_branch) = LangItem::TryTraitBranch.path(self.db, self.krate) {
Expand All @@ -628,7 +681,9 @@ impl ExprCollector<'_> {
Expr::Call { callee: try_branch, args: Box::new([operand]), is_assignee_expr: false },
syntax_ptr.clone(),
);
let continue_binding = self.alloc_binding(name![v1], BindingAnnotation::Unannotated);
let continue_name = Name::generate_new_name();
let continue_binding =
self.alloc_binding(continue_name.clone(), BindingAnnotation::Unannotated);
let continue_bpat =
self.alloc_pat_desugared(Pat::Bind { id: continue_binding, subpat: None });
self.add_definition_to_binding(continue_binding, continue_bpat);
Expand All @@ -639,9 +694,10 @@ impl ExprCollector<'_> {
ellipsis: None,
}),
guard: None,
expr: self.alloc_expr(Expr::Path(Path::from(name![v1])), syntax_ptr.clone()),
expr: self.alloc_expr(Expr::Path(Path::from(continue_name)), syntax_ptr.clone()),
};
let break_binding = self.alloc_binding(name![v1], BindingAnnotation::Unannotated);
let break_name = Name::generate_new_name();
let break_binding = self.alloc_binding(break_name.clone(), BindingAnnotation::Unannotated);
let break_bpat = self.alloc_pat_desugared(Pat::Bind { id: break_binding, subpat: None });
self.add_definition_to_binding(break_binding, break_bpat);
let break_arm = MatchArm {
Expand All @@ -652,13 +708,18 @@ impl ExprCollector<'_> {
}),
guard: None,
expr: {
let x = self.alloc_expr(Expr::Path(Path::from(name![v1])), syntax_ptr.clone());
let x = self.alloc_expr(Expr::Path(Path::from(break_name)), syntax_ptr.clone());
let callee = self.alloc_expr(Expr::Path(try_from_residual), syntax_ptr.clone());
let result = self.alloc_expr(
Expr::Call { callee, args: Box::new([x]), is_assignee_expr: false },
syntax_ptr.clone(),
);
self.alloc_expr(Expr::Return { expr: Some(result) }, syntax_ptr.clone())
if let Some(label) = self.current_try_block {
let label = Some(self.body.labels[label].name.clone());
self.alloc_expr(Expr::Break { expr: Some(result), label }, syntax_ptr.clone())
} else {
self.alloc_expr(Expr::Return { expr: Some(result) }, syntax_ptr.clone())
}
},
};
let arms = Box::new([continue_arm, break_arm]);
Expand Down
3 changes: 0 additions & 3 deletions crates/hir-def/src/body/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -420,9 +420,6 @@ impl<'a> Printer<'a> {
Expr::Unsafe { id: _, statements, tail } => {
self.print_block(Some("unsafe "), statements, tail);
}
Expr::TryBlock { id: _, statements, tail } => {
self.print_block(Some("try "), statements, tail);
}
Expr::Async { id: _, statements, tail } => {
self.print_block(Some("async "), statements, tail);
}
Expand Down
3 changes: 1 addition & 2 deletions crates/hir-def/src/body/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,7 @@ fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope
}
Expr::Unsafe { id, statements, tail }
| Expr::Async { id, statements, tail }
| Expr::Const { id, statements, tail }
| Expr::TryBlock { id, statements, tail } => {
| Expr::Const { id, statements, tail } => {
let mut scope = scopes.new_block_scope(*scope, *id, None);
// Overwrite the old scope for the block expr, so that every block scope can be found
// via the block itself (important for blocks that only contain items, no expressions).
Expand Down
6 changes: 0 additions & 6 deletions crates/hir-def/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,6 @@ pub enum Expr {
tail: Option<ExprId>,
label: Option<LabelId>,
},
TryBlock {
id: BlockId,
statements: Box<[Statement]>,
tail: Option<ExprId>,
},
Async {
id: BlockId,
statements: Box<[Statement]>,
Expand Down Expand Up @@ -310,7 +305,6 @@ impl Expr {
f(*expr);
}
Expr::Block { statements, tail, .. }
| Expr::TryBlock { statements, tail, .. }
| Expr::Unsafe { statements, tail, .. }
| Expr::Async { statements, tail, .. }
| Expr::Const { statements, tail, .. } => {
Expand Down
12 changes: 12 additions & 0 deletions crates/hir-expand/src/name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,18 @@ impl Name {
Name::new_inline("[missing name]")
}

/// Generates a new name which is only equal to itself, by incrementing a counter. Due
/// its implementation, it should not be used in things that salsa considers, like
/// type names or field names, and it should be only used in names of local variables
/// and labels and similar things.
pub fn generate_new_name() -> Name {
use std::sync::atomic::{AtomicUsize, Ordering};
static CNT: AtomicUsize = AtomicUsize::new(0);
let c = CNT.fetch_add(1, Ordering::Relaxed);
// FIXME: Currently a `__RA_generated_name` in user code will break our analysis
Name::new_inline(&format!("__RA_geneated_name_{c}"))
}

/// Returns the tuple index this name represents if it is a tuple field.
pub fn as_tuple_index(&self) -> Option<usize> {
match self.0 {
Expand Down
54 changes: 54 additions & 0 deletions crates/hir-ty/src/consteval/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,42 @@ fn loops() {
"#,
4,
);
check_number(
r#"
const GOAL: u8 = {
let mut x = 0;
loop {
x = x + 1;
if x == 5 {
break x + 2;
}
}
};
"#,
7,
);
check_number(
r#"
const GOAL: u8 = {
'a: loop {
let x = 'b: loop {
let x = 'c: loop {
let x = 'd: loop {
let x = 'e: loop {
break 'd 1;
};
break 2 + x;
};
break 3 + x;
};
break 'a 4 + x;
};
break 5 + x;
}
};
"#,
8,
);
}

#[test]
Expand Down Expand Up @@ -1019,6 +1055,24 @@ fn try_operator() {
);
}

#[test]
fn try_block() {
check_number(
r#"
//- minicore: option, try
const fn g(x: Option<i32>, y: Option<i32>) -> i32 {
let r = try { x? * y? };
match r {
Some(k) => k,
None => 5,
}
}
const GOAL: i32 = g(Some(10), Some(20)) + g(Some(30), None) + g(None, Some(40)) + g(None, None);
"#,
215,
);
}

#[test]
fn or_pattern() {
check_number(
Expand Down
4 changes: 0 additions & 4 deletions crates/hir-ty/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1025,10 +1025,6 @@ impl<'a> InferenceContext<'a> {
self.resolve_lang_item(lang)?.as_trait()
}

fn resolve_ops_try_output(&self) -> Option<TypeAliasId> {
self.resolve_output_on(self.resolve_lang_trait(LangItem::Try)?)
}

fn resolve_ops_neg_output(&self) -> Option<TypeAliasId> {
self.resolve_output_on(self.resolve_lang_trait(LangItem::Neg)?)
}
Expand Down
20 changes: 0 additions & 20 deletions crates/hir-ty/src/infer/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,26 +159,6 @@ impl<'a> InferenceContext<'a> {
})
.1
}
Expr::TryBlock { id: _, statements, tail } => {
// The type that is returned from the try block
let try_ty = self.table.new_type_var();
if let Some(ty) = expected.only_has_type(&mut self.table) {
self.unify(&try_ty, &ty);
}

// The ok-ish type that is expected from the last expression
let ok_ty =
self.resolve_associated_type(try_ty.clone(), self.resolve_ops_try_output());

self.infer_block(
tgt_expr,
statements,
*tail,
None,
&Expectation::has_type(ok_ty.clone()),
);
try_ty
}
Expr::Async { id: _, statements, tail } => {
let ret_ty = self.table.new_type_var();
let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe);
Expand Down
1 change: 0 additions & 1 deletion crates/hir-ty/src/infer/mutability.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ impl<'a> InferenceContext<'a> {
}
Expr::Let { pat, expr } => self.infer_mut_expr(*expr, self.pat_bound_mutability(*pat)),
Expr::Block { id: _, statements, tail, label: _ }
| Expr::TryBlock { id: _, statements, tail }
| Expr::Async { id: _, statements, tail }
| Expr::Const { id: _, statements, tail }
| Expr::Unsafe { id: _, statements, tail } => {
Expand Down
Loading

0 comments on commit f3bd169

Please sign in to comment.