Skip to content

Commit

Permalink
Auto merge of #129931 - DianQK:match-br-copy, r=<try>
Browse files Browse the repository at this point in the history
Merge these copy statements that simplified the canonical enum clone method by GVN

This is blocked by #128299.
  • Loading branch information
bors committed Sep 14, 2024
2 parents 5e3ede2 + 622247a commit c324112
Show file tree
Hide file tree
Showing 32 changed files with 741 additions and 280 deletions.
3 changes: 1 addition & 2 deletions compiler/rustc_mir_transform/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -594,8 +594,6 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
// Now, we need to shrink the generated MIR.
&ref_prop::ReferencePropagation,
&sroa::ScalarReplacementOfAggregates,
&match_branches::MatchBranchSimplification,
// inst combine is after MatchBranchSimplification to clean up Ne(_1, false)
&multiple_return_terminators::MultipleReturnTerminators,
// After simplifycfg, it allows us to discover new opportunities for peephole
// optimizations.
Expand All @@ -604,6 +602,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&dead_store_elimination::DeadStoreElimination::Initial,
&gvn::GVN,
&simplify::SimplifyLocals::AfterGVN,
&match_branches::MatchBranchSimplification,
&dataflow_const_prop::DataflowConstProp,
&single_use_consts::SingleUseConsts,
&o1(simplify_branches::SimplifyConstCondition::AfterConstProp),
Expand Down
221 changes: 220 additions & 1 deletion compiler/rustc_mir_transform/src/match_branches.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
use std::iter;
use std::{iter, usize};

use rustc_const_eval::const_eval::mk_eval_cx_for_const_val;
use rustc_index::bit_set::BitSet;
use rustc_index::IndexSlice;
use rustc_middle::mir::patch::MirPatch;
use rustc_middle::mir::*;
use rustc_middle::ty;
use rustc_middle::ty::layout::{IntegerExt, TyAndLayout};
use rustc_middle::ty::util::Discr;
use rustc_middle::ty::{ParamEnv, ScalarInt, Ty, TyCtxt};
use rustc_mir_dataflow::impls::{borrowed_locals, MaybeTransitiveLiveLocals};
use rustc_mir_dataflow::Analysis;
use rustc_target::abi::Integer;
use rustc_type_ir::TyKind::*;

Expand Down Expand Up @@ -48,6 +54,10 @@ impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification {
should_cleanup = true;
continue;
}
if simplify_to_copy(tcx, body, bb_idx, param_env).is_some() {
should_cleanup = true;
continue;
}
}

if should_cleanup {
Expand Down Expand Up @@ -519,3 +529,212 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
}
}
}

/// This is primarily used to merge these copy statements that simplified the canonical enum clone method by GVN.
/// The GVN simplified
/// ```ignore (syntax-highlighting-only)
/// match a {
/// Foo::A(x) => Foo::A(*x),
/// Foo::B => Foo::B
/// }
/// ```
/// to
/// ```ignore (syntax-highlighting-only)
/// match a {
/// Foo::A(_x) => a, // copy a
/// Foo::B => Foo::B
/// }
/// ```
/// This function will simplify into a copy statement.
fn simplify_to_copy<'tcx>(
tcx: TyCtxt<'tcx>,
body: &mut Body<'tcx>,
switch_bb_idx: BasicBlock,
param_env: ParamEnv<'tcx>,
) -> Option<()> {
// To save compile time, only consider the first BB has a switch terminator.
if switch_bb_idx != START_BLOCK {
return None;
}
let bbs = &body.basic_blocks;
// Check if the copy source matches the following pattern.
// _2 = discriminant(*_1); // "*_1" is the expected the copy source.
// switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1];
let &Statement {
kind: StatementKind::Assign(box (discr_place, Rvalue::Discriminant(expected_src_place))),
..
} = bbs[switch_bb_idx].statements.last()?
else {
return None;
};
let expected_src_ty = expected_src_place.ty(body.local_decls(), tcx);
if !expected_src_ty.ty.is_enum() || expected_src_ty.variant_index.is_some() {
return None;
}
// To save compile time, only consider the copy source is assigned to the return place.
let expected_dest_place = Place::return_place();
let expected_dest_ty = expected_dest_place.ty(body.local_decls(), tcx);
if expected_dest_ty.ty != expected_src_ty.ty || expected_dest_ty.variant_index.is_some() {
return None;
}
let targets = match bbs[switch_bb_idx].terminator().kind {
TerminatorKind::SwitchInt { ref discr, ref targets, .. }
if discr.place() == Some(discr_place) =>
{
targets
}
_ => return None,
};
// We require that the possible target blocks all be distinct.
if !targets.is_distinct() {
return None;
}
if !bbs[targets.otherwise()].is_empty_unreachable() {
return None;
}
// Check that destinations are identical, and if not, then don't optimize this block.
let mut target_iter = targets.iter();
let first_terminator_kind = &bbs[target_iter.next().unwrap().1].terminator().kind;
if !target_iter
.all(|(_, other_target)| first_terminator_kind == &bbs[other_target].terminator().kind)
{
return None;
}

let borrowed_locals = borrowed_locals(body);
let mut live = None;

for (index, target_bb) in targets.iter() {
let stmts = &bbs[target_bb].statements;
if stmts.is_empty() {
return None;
}
if let [Statement { kind: StatementKind::Assign(box (place, rvalue)), .. }] =
bbs[target_bb].statements.as_slice()
{
let dest_ty = place.ty(body.local_decls(), tcx);
if dest_ty.ty != expected_src_ty.ty || dest_ty.variant_index.is_some() {
return None;
}
let ty::Adt(def, _) = dest_ty.ty.kind() else {
return None;
};
if expected_dest_place != *place {
return None;
}
match rvalue {
// Check if `_3 = const Foo::B` can be transformed to `_3 = copy *_1`.
Rvalue::Use(Operand::Constant(box constant))
if let Const::Val(const_, ty) = constant.const_ =>
{
let (ecx, op) =
mk_eval_cx_for_const_val(tcx.at(constant.span), param_env, const_, ty)?;
let variant = ecx.read_discriminant(&op).ok()?;
if !def.variants()[variant].fields.is_empty() {
return None;
}
let Discr { val, .. } = ty.discriminant_for_variant(tcx, variant)?;
if val != index {
return None;
}
}
Rvalue::Use(Operand::Copy(src_place)) if *src_place == expected_src_place => {}
// Check if `_3 = Foo::B` can be transformed to `_3 = copy *_1`.
Rvalue::Aggregate(box AggregateKind::Adt(_, variant_index, _, _, None), fields)
if fields.is_empty()
&& let Some(Discr { val, .. }) =
expected_src_ty.ty.discriminant_for_variant(tcx, *variant_index)
&& val == index => {}
_ => return None,
}
} else {
// If the BB contains more than one statement, we have to check if these statements can be ignored.
let mut lived_stmts: BitSet<usize> =
BitSet::new_filled(bbs[target_bb].statements.len());
let mut expected_copy_stmt = None;
for (statement_index, statement) in bbs[target_bb].statements.iter().enumerate().rev() {
let loc = Location { block: target_bb, statement_index };
if let StatementKind::Assign(assign) = &statement.kind {
if !assign.1.is_safe_to_remove() {
return None;
}
}
match &statement.kind {
StatementKind::Assign(box (place, _))
| StatementKind::SetDiscriminant { place: box place, .. }
| StatementKind::Deinit(box place) => {
if place.is_indirect() || borrowed_locals.contains(place.local) {
return None;
}
let live = live.get_or_insert_with(|| {
MaybeTransitiveLiveLocals::new(&borrowed_locals)
.into_engine(tcx, body)
.iterate_to_fixpoint()
.into_results_cursor(body)
});
live.seek_before_primary_effect(loc);
if !live.get().contains(place.local) {
lived_stmts.remove(statement_index);
} else if let StatementKind::Assign(box (
_,
Rvalue::Use(Operand::Copy(src_place)),
)) = statement.kind
&& expected_copy_stmt.is_none()
&& expected_src_place == src_place
&& expected_dest_place == *place
{
// There is only one statement that cannot be ignored that can be used as an expected copy statement.
expected_copy_stmt = Some(statement_index);
} else {
return None;
}
}
StatementKind::StorageLive(_)
| StatementKind::StorageDead(_)
| StatementKind::Nop => (),

StatementKind::Retag(_, _)
| StatementKind::Coverage(_)
| StatementKind::Intrinsic(_)
| StatementKind::ConstEvalCounter
| StatementKind::PlaceMention(_)
| StatementKind::FakeRead(_)
| StatementKind::AscribeUserType(_, _) => {
return None;
}
}
}
let expected_copy_stmt = expected_copy_stmt?;
// We can ignore the paired StorageLive and StorageDead.
let mut storage_live_locals: BitSet<Local> = BitSet::new_empty(body.local_decls.len());
for stmt_index in lived_stmts.iter() {
let statement = &bbs[target_bb].statements[stmt_index];
match &statement.kind {
StatementKind::Assign(_) if expected_copy_stmt == stmt_index => {}
StatementKind::StorageLive(local)
if *local != expected_dest_place.local
&& storage_live_locals.insert(*local) => {}
StatementKind::StorageDead(local)
if *local != expected_dest_place.local
&& storage_live_locals.remove(*local) => {}
StatementKind::Nop => {}
_ => return None,
}
}
if !storage_live_locals.is_empty() {
return None;
}
}
}
let statement_index = bbs[switch_bb_idx].statements.len();
let parent_end = Location { block: switch_bb_idx, statement_index };
let mut patch = MirPatch::new(body);
patch.add_assign(
parent_end,
expected_dest_place,
Rvalue::Use(Operand::Copy(expected_src_place)),
);
patch.patch_terminator(switch_bb_idx, first_terminator_kind.clone());
patch.apply(body);
Some(())
}
14 changes: 8 additions & 6 deletions tests/codegen/match-optimizes-away.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//
//@ compile-flags: -O
//@ compile-flags: -O -Cno-prepopulate-passes

#![crate_type = "lib"]

pub enum Three {
Expand All @@ -19,8 +19,9 @@ pub enum Four {
#[no_mangle]
pub fn three_valued(x: Three) -> Three {
// CHECK-LABEL: @three_valued
// CHECK-NEXT: {{^.*:$}}
// CHECK-NEXT: ret i8 %0
// CHECK-SAME: (i8{{.*}} [[X:%x]])
// CHECK-NEXT: start:
// CHECK-NEXT: ret i8 [[X]]
match x {
Three::A => Three::A,
Three::B => Three::B,
Expand All @@ -31,8 +32,9 @@ pub fn three_valued(x: Three) -> Three {
#[no_mangle]
pub fn four_valued(x: Four) -> Four {
// CHECK-LABEL: @four_valued
// CHECK-NEXT: {{^.*:$}}
// CHECK-NEXT: ret i16 %0
// CHECK-SAME: (i16{{.*}} [[X:%x]])
// CHECK-NEXT: start:
// CHECK-NEXT: ret i16 [[X]]
match x {
Four::A => Four::A,
Four::B => Four::B,
Expand Down
7 changes: 1 addition & 6 deletions tests/codegen/try_question_mark_nop.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,17 @@
//@ compile-flags: -O -Z merge-functions=disabled --edition=2021
//@ only-x86_64
// FIXME: Remove the `min-llvm-version`.
//@ min-llvm-version: 19

#![crate_type = "lib"]
#![feature(try_blocks)]

use std::ops::ControlFlow::{self, Break, Continue};
use std::ptr::NonNull;

// FIXME: The `trunc` and `select` instructions can be eliminated.
// CHECK-LABEL: @option_nop_match_32
#[no_mangle]
pub fn option_nop_match_32(x: Option<u32>) -> Option<u32> {
// CHECK: start:
// CHECK-NEXT: [[TRUNC:%.*]] = trunc nuw i32 %0 to i1
// CHECK-NEXT: [[FIRST:%.*]] = select i1 [[TRUNC]], i32 %0
// CHECK-NEXT: insertvalue { i32, i32 } poison, i32 [[FIRST]]
// CHECK-NEXT: insertvalue { i32, i32 }
// CHECK-NEXT: insertvalue { i32, i32 }
// CHECK-NEXT: ret { i32, i32 }
match x {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
let _6: *mut [bool; 0];
scope 6 {
scope 10 (inlined NonNull::<[bool; 0]>::new_unchecked) {
let mut _8: bool;
let _9: ();
let mut _10: *mut ();
let mut _11: *const [bool; 0];
let _8: ();
let mut _9: *mut ();
let mut _10: *const [bool; 0];
scope 11 (inlined core::ub_checks::check_language_ub) {
let mut _11: bool;
scope 12 (inlined core::ub_checks::check_language_ub::runtime) {
}
}
Expand All @@ -44,18 +44,18 @@
StorageLive(_1);
StorageLive(_2);
StorageLive(_3);
StorageLive(_9);
StorageLive(_8);
StorageLive(_4);
StorageLive(_5);
StorageLive(_6);
StorageLive(_7);
_7 = const 1_usize;
_6 = const {0x1 as *mut [bool; 0]};
StorageDead(_7);
StorageLive(_10);
StorageLive(_11);
StorageLive(_8);
_8 = UbChecks();
switchInt(move _8) -> [0: bb4, otherwise: bb2];
_11 = UbChecks();
switchInt(copy _11) -> [0: bb4, otherwise: bb2];
}

bb1: {
Expand All @@ -64,28 +64,28 @@
}

bb2: {
StorageLive(_10);
_10 = const {0x1 as *mut ()};
_9 = NonNull::<T>::new_unchecked::precondition_check(const {0x1 as *mut ()}) -> [return: bb3, unwind unreachable];
StorageLive(_9);
_9 = const {0x1 as *mut ()};
_8 = NonNull::<T>::new_unchecked::precondition_check(const {0x1 as *mut ()}) -> [return: bb3, unwind unreachable];
}

bb3: {
StorageDead(_10);
StorageDead(_9);
goto -> bb4;
}

bb4: {
StorageDead(_8);
_11 = const {0x1 as *const [bool; 0]};
_10 = const {0x1 as *const [bool; 0]};
_5 = const NonNull::<[bool; 0]> {{ pointer: {0x1 as *const [bool; 0]} }};
StorageDead(_11);
StorageDead(_10);
StorageDead(_6);
_4 = const Unique::<[bool; 0]> {{ pointer: NonNull::<[bool; 0]> {{ pointer: {0x1 as *const [bool; 0]} }}, _marker: PhantomData::<[bool; 0]> }};
StorageDead(_5);
_3 = const Unique::<[bool]> {{ pointer: NonNull::<[bool]> {{ pointer: Indirect { alloc_id: ALLOC0, offset: Size(0 bytes) }: *const [bool] }}, _marker: PhantomData::<[bool]> }};
StorageDead(_4);
_2 = const Box::<[bool]>(Unique::<[bool]> {{ pointer: NonNull::<[bool]> {{ pointer: Indirect { alloc_id: ALLOC1, offset: Size(0 bytes) }: *const [bool] }}, _marker: PhantomData::<[bool]> }}, std::alloc::Global);
StorageDead(_9);
StorageDead(_8);
StorageDead(_3);
_1 = const A {{ foo: Box::<[bool]>(Unique::<[bool]> {{ pointer: NonNull::<[bool]> {{ pointer: Indirect { alloc_id: ALLOC2, offset: Size(0 bytes) }: *const [bool] }}, _marker: PhantomData::<[bool]> }}, std::alloc::Global) }};
StorageDead(_2);
Expand Down
Loading

0 comments on commit c324112

Please sign in to comment.