Skip to content

Commit

Permalink
Suggest using deref in patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
uellenberg committed Dec 13, 2024
1 parent e217f94 commit 831f454
Show file tree
Hide file tree
Showing 12 changed files with 544 additions and 64 deletions.
16 changes: 16 additions & 0 deletions compiler/rustc_hir/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2020,6 +2020,22 @@ pub fn is_range_literal(expr: &Expr<'_>) -> bool {
}
}

/// Checks if the specified expression needs parentheses for prefix
/// or postfix suggestions to be valid.
/// For example, `a + b` requires parentheses to suggest `&(a + b)`,
/// but just `a` does not.
/// Similarly, `(a + b).c()` also requires parentheses.
/// This should not be used for other types of suggestions.
pub fn expr_needs_parens(expr: &Expr<'_>) -> bool {
match expr.kind {
// parenthesize if needed (Issue #46756)
ExprKind::Cast(_, _) | ExprKind::Binary(_, _, _) => true,
// parenthesize borrows of range literals (Issue #54505)
_ if is_range_literal(expr) => true,
_ => false,
}
}

#[derive(Debug, Clone, Copy, HashStable_Generic)]
pub enum ExprKind<'hir> {
/// Allow anonymous constants from an inline `const` block
Expand Down
17 changes: 3 additions & 14 deletions compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use rustc_hir::def::{CtorKind, CtorOf, DefKind, Res};
use rustc_hir::lang_items::LangItem;
use rustc_hir::{
Arm, CoroutineDesugaring, CoroutineKind, CoroutineSource, Expr, ExprKind, GenericBound, HirId,
Node, Path, QPath, Stmt, StmtKind, TyKind, WherePredicateKind,
Node, Path, QPath, Stmt, StmtKind, TyKind, WherePredicateKind, expr_needs_parens,
};
use rustc_hir_analysis::collect::suggest_impl_trait;
use rustc_hir_analysis::hir_ty_lowering::HirTyLowerer;
Expand All @@ -35,7 +35,6 @@ use tracing::{debug, instrument};

use super::FnCtxt;
use crate::fn_ctxt::rustc_span::BytePos;
use crate::hir::is_range_literal;
use crate::method::probe;
use crate::method::probe::{IsSuggestion, Mode, ProbeScope};
use crate::{errors, fluent_generated as fluent};
Expand Down Expand Up @@ -2648,7 +2647,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
}

let make_sugg = |expr: &Expr<'_>, span: Span, sugg: &str| {
if self.needs_parentheses(expr) {
if expr_needs_parens(expr) {
(
vec![
(span.shrink_to_lo(), format!("{prefix}{sugg}(")),
Expand Down Expand Up @@ -2861,7 +2860,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
return None;
}

if self.needs_parentheses(expr) {
if expr_needs_parens(expr) {
return Some((
vec![
(span, format!("{suggestion}(")),
Expand Down Expand Up @@ -2902,16 +2901,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
false
}

fn needs_parentheses(&self, expr: &hir::Expr<'_>) -> bool {
match expr.kind {
// parenthesize if needed (Issue #46756)
hir::ExprKind::Cast(_, _) | hir::ExprKind::Binary(_, _, _) => true,
// parenthesize borrows of range literals (Issue #54505)
_ if is_range_literal(expr) => true,
_ => false,
}
}

pub(crate) fn suggest_cast(
&self,
err: &mut Diag<'_>,
Expand Down
30 changes: 28 additions & 2 deletions compiler/rustc_hir_typeck/src/pat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@ use rustc_errors::{
};
use rustc_hir::def::{CtorKind, DefKind, Res};
use rustc_hir::pat_util::EnumerateAndAdjustIterator;
use rustc_hir::{self as hir, BindingMode, ByRef, HirId, LangItem, Mutability, Pat, PatKind};
use rustc_hir::{
self as hir, BindingMode, ByRef, ExprKind, HirId, LangItem, Mutability, Pat, PatKind,
expr_needs_parens,
};
use rustc_infer::infer;
use rustc_middle::traits::PatternOriginExpr;
use rustc_middle::ty::{self, Ty, TypeVisitableExt};
use rustc_middle::{bug, span_bug};
use rustc_session::lint::builtin::NON_EXHAUSTIVE_OMITTED_PATTERNS;
Expand Down Expand Up @@ -94,10 +98,32 @@ struct PatInfo<'a, 'tcx> {

impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
fn pattern_cause(&self, ti: &TopInfo<'tcx>, cause_span: Span) -> ObligationCause<'tcx> {
// If origin_expr exists, then expected represents the type of origin_expr.
// If span also exists, then span == origin_expr.span (although it doesn't need to exist).
// In that case, we can peel away references from both and treat them
// as the same.
let origin_expr_info = ti.origin_expr.map(|mut cur_expr| {
let mut count = 0;

// cur_ty may have more layers of references than cur_expr.
// We can only make suggestions about cur_expr, however, so we'll
// use that as our condition for stopping.
while let ExprKind::AddrOf(.., inner) = &cur_expr.kind {
cur_expr = inner;
count += 1;
}

PatternOriginExpr {
peeled_span: cur_expr.span,
peeled_count: count,
peeled_prefix_suggestion_parentheses: expr_needs_parens(cur_expr),
}
});

let code = ObligationCauseCode::Pattern {
span: ti.span,
root_ty: ti.expected,
origin_expr: ti.origin_expr.is_some(),
origin_expr: origin_expr_info,
};
self.cause(cause_span, code)
}
Expand Down
23 changes: 21 additions & 2 deletions compiler/rustc_middle/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,8 @@ pub enum ObligationCauseCode<'tcx> {
span: Option<Span>,
/// The root expected type induced by a scrutinee or type expression.
root_ty: Ty<'tcx>,
/// Whether the `Span` came from an expression or a type expression.
origin_expr: bool,
/// Information about the `Span`, if it came from an expression, otherwise `None`.
origin_expr: Option<PatternOriginExpr>,
},

/// Computing common supertype in an if expression
Expand Down Expand Up @@ -530,6 +530,25 @@ pub struct MatchExpressionArmCause<'tcx> {
pub tail_defines_return_position_impl_trait: Option<LocalDefId>,
}

/// Information about the origin expression of a pattern, relevant to diagnostics.
/// Fields here refer to the scrutinee of a pattern.
/// If the scrutinee isn't given in the diagnostic, then this won't exist.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[derive(TypeFoldable, TypeVisitable, HashStable, TyEncodable, TyDecodable)]
pub struct PatternOriginExpr {
/// A span representing the scrutinee expression, with all leading references
/// peeled from the expression.
/// Only references in the expression are peeled - if the expression refers to a variable
/// whose type is a reference, then that reference is kept because it wasn't created
/// in the expression.
pub peeled_span: Span,
/// The number of references that were peeled to produce `peeled_span`.
pub peeled_count: usize,
/// Does the peeled expression need to be wrapped in parentheses for
/// a prefix suggestion (i.e., dereference) to be valid.
pub peeled_prefix_suggestion_parentheses: bool,
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[derive(TypeFoldable, TypeVisitable, HashStable, TyEncodable, TyDecodable)]
pub struct IfExpressionCause<'tcx> {
Expand Down
115 changes: 94 additions & 21 deletions compiler/rustc_trait_selection/src/error_reporting/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,11 @@ use rustc_hir::{self as hir};
use rustc_macros::extension;
use rustc_middle::bug;
use rustc_middle::dep_graph::DepContext;
use rustc_middle::traits::PatternOriginExpr;
use rustc_middle::ty::error::{ExpectedFound, TypeError, TypeErrorToStringExt};
use rustc_middle::ty::print::{PrintError, PrintTraitRefExt as _, with_forced_trimmed_paths};
use rustc_middle::ty::{
self, List, Region, Ty, TyCtxt, TypeFoldable, TypeSuperVisitable, TypeVisitable,
self, List, ParamEnv, Region, Ty, TyCtxt, TypeFoldable, TypeSuperVisitable, TypeVisitable,
TypeVisitableExt,
};
use rustc_span::def_id::LOCAL_CRATE;
Expand All @@ -77,7 +78,7 @@ use crate::error_reporting::TypeErrCtxt;
use crate::errors::{ObligationCauseFailureCode, TypeErrorAdditionalDiags};
use crate::infer;
use crate::infer::relate::{self, RelateResult, TypeRelation};
use crate::infer::{InferCtxt, TypeTrace, ValuePairs};
use crate::infer::{InferCtxt, InferCtxtExt as _, TypeTrace, ValuePairs};
use crate::solve::deeply_normalize_for_diagnostics;
use crate::traits::{
IfExpressionCause, MatchExpressionArmCause, ObligationCause, ObligationCauseCode,
Expand Down Expand Up @@ -433,38 +434,71 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
cause: &ObligationCause<'tcx>,
exp_found: Option<ty::error::ExpectedFound<Ty<'tcx>>>,
terr: TypeError<'tcx>,
param_env: Option<ParamEnv<'tcx>>,
) {
match *cause.code() {
ObligationCauseCode::Pattern { origin_expr: true, span: Some(span), root_ty } => {
let ty = self.resolve_vars_if_possible(root_ty);
if !matches!(ty.kind(), ty::Infer(ty::InferTy::TyVar(_) | ty::InferTy::FreshTy(_)))
{
ObligationCauseCode::Pattern {
origin_expr: Some(origin_expr),
span: Some(span),
root_ty,
} => {
let expected_ty = self.resolve_vars_if_possible(root_ty);
if !matches!(
expected_ty.kind(),
ty::Infer(ty::InferTy::TyVar(_) | ty::InferTy::FreshTy(_))
) {
// don't show type `_`
if span.desugaring_kind() == Some(DesugaringKind::ForLoop)
&& let ty::Adt(def, args) = ty.kind()
&& let ty::Adt(def, args) = expected_ty.kind()
&& Some(def.did()) == self.tcx.get_diagnostic_item(sym::Option)
{
err.span_label(
span,
format!("this is an iterator with items of type `{}`", args.type_at(0)),
);
} else {
err.span_label(span, format!("this expression has type `{ty}`"));
err.span_label(span, format!("this expression has type `{expected_ty}`"));
}
}
if let Some(ty::error::ExpectedFound { found, .. }) = exp_found
&& ty.boxed_ty() == Some(found)
&& let Ok(snippet) = self.tcx.sess.source_map().span_to_snippet(span)
&& let Ok(mut peeled_snippet) =
self.tcx.sess.source_map().span_to_snippet(origin_expr.peeled_span)
{
err.span_suggestion(
span,
"consider dereferencing the boxed value",
format!("*{snippet}"),
Applicability::MachineApplicable,
);
// Parentheses are needed for cases like as casts.
// We use the peeled_span for deref suggestions.
// It's also safe to use for box, since box only triggers if there
// wasn't a reference to begin with.
if origin_expr.peeled_prefix_suggestion_parentheses {
peeled_snippet = format!("({peeled_snippet})");
}

// Try giving a box suggestion first, as it is a special case of the
// deref suggestion.
if expected_ty.boxed_ty() == Some(found) {
err.span_suggestion_verbose(
span,
"consider dereferencing the boxed value",
format!("*{peeled_snippet}"),
Applicability::MachineApplicable,
);
} else if let Some(param_env) = param_env
&& let Some(prefix) = self.should_deref_suggestion_on_mismatch(
param_env,
found,
expected_ty,
origin_expr,
)
{
err.span_suggestion_verbose(
span,
"consider dereferencing to access the inner value using the Deref trait",
format!("{prefix}{peeled_snippet}"),
Applicability::MaybeIncorrect,
);
}
}
}
ObligationCauseCode::Pattern { origin_expr: false, span: Some(span), .. } => {
ObligationCauseCode::Pattern { origin_expr: None, span: Some(span), .. } => {
err.span_label(span, "expected due to this");
}
ObligationCauseCode::BlockTailExpression(
Expand Down Expand Up @@ -618,6 +652,45 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
}
}

/// Determines whether deref_to == <deref_from as Deref>::Target, and if so,
/// returns a prefix that should be added to deref_from as a suggestion.
fn should_deref_suggestion_on_mismatch(
&self,
param_env: ParamEnv<'tcx>,
deref_to: Ty<'tcx>,
deref_from: Ty<'tcx>,
origin_expr: PatternOriginExpr,
) -> Option<String> {
// origin_expr contains stripped away versions of our expression.
// We'll want to use that to avoid suggesting things like *&x.
// However, the type that we have access to hasn't been stripped away,
// so we need to ignore the first n dereferences, where n is the number
// that's been stripped away in origin_expr.

// Find a way to autoderef from deref_from to deref_to.
let Some((num_derefs, (after_deref_ty, _))) = (self.autoderef_steps)(deref_from)
.into_iter()
.enumerate()
.find(|(_, (ty, _))| self.infcx.can_eq(param_env, *ty, deref_to))
else {
return None;
};

if num_derefs <= origin_expr.peeled_count {
return None;
}

let deref_part = "*".repeat(num_derefs - origin_expr.peeled_count);

// If the user used a reference in the original expression, they probably
// want the suggestion to still give a reference.
if deref_from.is_ref() && !after_deref_ty.is_ref() {
Some(format!("&{deref_part}"))
} else {
Some(deref_part)
}
}

/// Given that `other_ty` is the same as a type argument for `name` in `sub`, populate `value`
/// highlighting `name` and every type argument that isn't at `pos` (which is `other_ty`), and
/// populate `other_value` with `other_ty`.
Expand Down Expand Up @@ -1406,8 +1479,8 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
Variable(ty::error::ExpectedFound<Ty<'a>>),
Fixed(&'static str),
}
let (expected_found, exp_found, is_simple_error, values) = match values {
None => (None, Mismatch::Fixed("type"), false, None),
let (expected_found, exp_found, is_simple_error, values, param_env) = match values {
None => (None, Mismatch::Fixed("type"), false, None, None),
Some(ty::ParamEnvAnd { param_env, value: values }) => {
let mut values = self.resolve_vars_if_possible(values);
if self.next_trait_solver() {
Expand Down Expand Up @@ -1459,7 +1532,7 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
diag.downgrade_to_delayed_bug();
return;
};
(Some(vals), exp_found, is_simple_error, Some(values))
(Some(vals), exp_found, is_simple_error, Some(values), Some(param_env))
}
};

Expand Down Expand Up @@ -1791,7 +1864,7 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {

// It reads better to have the error origin as the final
// thing.
self.note_error_origin(diag, cause, exp_found, terr);
self.note_error_origin(diag, cause, exp_found, terr, param_env);

debug!(?diag);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> {
(Some(ty), _) if self.same_type_modulo_infer(ty, exp_found.found) => match cause.code()
{
ObligationCauseCode::Pattern { span: Some(then_span), origin_expr, .. } => {
origin_expr.then_some(ConsiderAddingAwait::FutureSugg {
origin_expr.is_some().then_some(ConsiderAddingAwait::FutureSugg {
span: then_span.shrink_to_hi(),
})
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ use rustc_hir::def_id::DefId;
use rustc_hir::intravisit::Visitor;
use rustc_hir::lang_items::LangItem;
use rustc_hir::{
CoroutineDesugaring, CoroutineKind, CoroutineSource, Expr, HirId, Node, is_range_literal,
CoroutineDesugaring, CoroutineKind, CoroutineSource, Expr, HirId, Node, expr_needs_parens,
is_range_literal,
};
use rustc_infer::infer::{BoundRegionConversionTime, DefineOpaqueTypes, InferCtxt, InferOk};
use rustc_middle::hir::map;
Expand Down Expand Up @@ -1391,13 +1392,7 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
let Some(expr) = expr_finder.result else {
return false;
};
let needs_parens = match expr.kind {
// parenthesize if needed (Issue #46756)
hir::ExprKind::Cast(_, _) | hir::ExprKind::Binary(_, _, _) => true,
// parenthesize borrows of range literals (Issue #54505)
_ if is_range_literal(expr) => true,
_ => false,
};
let needs_parens = expr_needs_parens(expr);

let span = if needs_parens { span } else { span.shrink_to_lo() };
let suggestions = if !needs_parens {
Expand Down
5 changes: 5 additions & 0 deletions tests/ui/closures/2229_closure_analysis/issue-118144.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ LL | V(x) = func_arg;
| ^^^^ -------- this expression has type `&mut V`
| |
| expected `&mut V`, found `V`
|
help: consider dereferencing to access the inner value using the Deref trait
|
LL | V(x) = &*func_arg;
| ~~~~~~~~~~

error: aborting due to 1 previous error

Expand Down
Loading

0 comments on commit 831f454

Please sign in to comment.