From be2401b8bfc824026b477f11b876b5611cb204c0 Mon Sep 17 00:00:00 2001
From: Jakob Degen <jakob.e.degen@gmail.com>
Date: Mon, 26 Sep 2022 18:43:35 -0700
Subject: [PATCH] Split phase change from `MirPass`

---
 .../src/transform/promote_consts.rs           |  4 -
 compiler/rustc_middle/src/mir/mod.rs          | 39 ++++++--
 compiler/rustc_mir_transform/src/lib.rs       | 21 ++---
 compiler/rustc_mir_transform/src/marker.rs    | 20 -----
 .../rustc_mir_transform/src/pass_manager.rs   | 89 ++++++++++++-------
 compiler/rustc_mir_transform/src/shim.rs      |  4 +-
 6 files changed, 101 insertions(+), 76 deletions(-)
 delete mode 100644 compiler/rustc_mir_transform/src/marker.rs

diff --git a/compiler/rustc_const_eval/src/transform/promote_consts.rs b/compiler/rustc_const_eval/src/transform/promote_consts.rs
index 4b219300739c0..f3ae16da43bd1 100644
--- a/compiler/rustc_const_eval/src/transform/promote_consts.rs
+++ b/compiler/rustc_const_eval/src/transform/promote_consts.rs
@@ -41,10 +41,6 @@ pub struct PromoteTemps<'tcx> {
 }
 
 impl<'tcx> MirPass<'tcx> for PromoteTemps<'tcx> {
-    fn phase_change(&self) -> Option<MirPhase> {
-        Some(MirPhase::Analysis(AnalysisPhase::Initial))
-    }
-
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         // There's not really any point in promoting errorful MIR.
         //
diff --git a/compiler/rustc_middle/src/mir/mod.rs b/compiler/rustc_middle/src/mir/mod.rs
index e0e823e2090ef..3b80af184be47 100644
--- a/compiler/rustc_middle/src/mir/mod.rs
+++ b/compiler/rustc_middle/src/mir/mod.rs
@@ -116,11 +116,6 @@ pub trait MirPass<'tcx> {
 
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>);
 
-    /// If this pass causes the MIR to enter a new phase, return that phase.
-    fn phase_change(&self) -> Option<MirPhase> {
-        None
-    }
-
     fn is_mir_dump_enabled(&self) -> bool {
         true
     }
@@ -145,6 +140,35 @@ impl MirPhase {
     }
 }
 
+impl Display for MirPhase {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        match self {
+            MirPhase::Built => write!(f, "built"),
+            MirPhase::Analysis(p) => write!(f, "analysis-{}", p),
+            MirPhase::Runtime(p) => write!(f, "runtime-{}", p),
+        }
+    }
+}
+
+impl Display for AnalysisPhase {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        match self {
+            AnalysisPhase::Initial => write!(f, "initial"),
+            AnalysisPhase::PostCleanup => write!(f, "post_cleanup"),
+        }
+    }
+}
+
+impl Display for RuntimePhase {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        match self {
+            RuntimePhase::Initial => write!(f, "initial"),
+            RuntimePhase::PostCleanup => write!(f, "post_cleanup"),
+            RuntimePhase::Optimized => write!(f, "optimized"),
+        }
+    }
+}
+
 /// Where a specific `mir::Body` comes from.
 #[derive(Copy, Clone, Debug, PartialEq, Eq)]
 #[derive(HashStable, TyEncodable, TyDecodable, TypeFoldable, TypeVisitable)]
@@ -207,6 +231,9 @@ pub struct Body<'tcx> {
     /// us to see the difference and forego optimization on the inlined promoted items.
     pub phase: MirPhase,
 
+    /// How many passses we have executed since starting the current phase. Used for debug output.
+    pub pass_count: usize,
+
     pub source: MirSource<'tcx>,
 
     /// A list of source scopes; these are referenced by statements
@@ -292,6 +319,7 @@ impl<'tcx> Body<'tcx> {
 
         let mut body = Body {
             phase: MirPhase::Built,
+            pass_count: 1,
             source,
             basic_blocks: BasicBlocks::new(basic_blocks),
             source_scopes,
@@ -325,6 +353,7 @@ impl<'tcx> Body<'tcx> {
     pub fn new_cfg_only(basic_blocks: IndexVec<BasicBlock, BasicBlockData<'tcx>>) -> Self {
         let mut body = Body {
             phase: MirPhase::Built,
+            pass_count: 1,
             source: MirSource::item(CRATE_DEF_ID.to_def_id()),
             basic_blocks: BasicBlocks::new(basic_blocks),
             source_scopes: IndexVec::new(),
diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs
index 5c411fa565784..5be2232547bd6 100644
--- a/compiler/rustc_mir_transform/src/lib.rs
+++ b/compiler/rustc_mir_transform/src/lib.rs
@@ -71,7 +71,6 @@ mod inline;
 mod instcombine;
 mod lower_intrinsics;
 mod lower_slice_len;
-mod marker;
 mod match_branches;
 mod multiple_return_terminators;
 mod normalize_array_len;
@@ -303,6 +302,7 @@ fn mir_const<'tcx>(
             &simplify::SimplifyCfg::new("initial"),
             &rustc_peek::SanityCheck, // Just a lint
         ],
+        None,
     );
     tcx.alloc_steal_mir(body)
 }
@@ -342,6 +342,7 @@ fn mir_promoted<'tcx>(
             &simplify::SimplifyCfg::new("promote-consts"),
             &coverage::InstrumentCoverage,
         ],
+        Some(MirPhase::Analysis(AnalysisPhase::Initial)),
     );
 
     let promoted = promote_pass.promoted_fragments.into_inner();
@@ -409,10 +410,8 @@ fn inner_mir_for_ctfe(tcx: TyCtxt<'_>, def: ty::WithOptConstParam<LocalDefId>) -
             pm::run_passes(
                 tcx,
                 &mut body,
-                &[
-                    &const_prop::ConstProp,
-                    &marker::PhaseChange(MirPhase::Runtime(RuntimePhase::Optimized)),
-                ],
+                &[&const_prop::ConstProp],
+                Some(MirPhase::Runtime(RuntimePhase::Optimized)),
             );
         }
     }
@@ -474,6 +473,7 @@ fn run_analysis_to_runtime_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>
                 &remove_uninit_drops::RemoveUninitDrops,
                 &simplify::SimplifyCfg::new("remove-false-edges"),
             ],
+            None,
         );
         check_consts::post_drop_elaboration::check_live_drops(tcx, &body); // FIXME: make this a MIR lint
     }
@@ -498,10 +498,9 @@ fn run_analysis_cleanup_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         &cleanup_post_borrowck::CleanupNonCodegenStatements,
         &simplify::SimplifyCfg::new("early-opt"),
         &deref_separator::Derefer,
-        &marker::PhaseChange(MirPhase::Analysis(AnalysisPhase::PostCleanup)),
     ];
 
-    pm::run_passes(tcx, body, passes);
+    pm::run_passes(tcx, body, passes, Some(MirPhase::Analysis(AnalysisPhase::PostCleanup)));
 }
 
 /// Returns the sequence of passes that lowers analysis to runtime MIR.
@@ -526,9 +525,8 @@ fn run_runtime_lowering_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         // CTFE support for aggregates.
         &deaggregator::Deaggregator,
         &Lint(const_prop_lint::ConstProp),
-        &marker::PhaseChange(MirPhase::Runtime(RuntimePhase::Initial)),
     ];
-    pm::run_passes_no_validate(tcx, body, passes);
+    pm::run_passes_no_validate(tcx, body, passes, Some(MirPhase::Runtime(RuntimePhase::Initial)));
 }
 
 /// Returns the sequence of passes that do the initial cleanup of runtime MIR.
@@ -537,10 +535,9 @@ fn run_runtime_cleanup_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         &elaborate_box_derefs::ElaborateBoxDerefs,
         &lower_intrinsics::LowerIntrinsics,
         &simplify::SimplifyCfg::new("elaborate-drops"),
-        &marker::PhaseChange(MirPhase::Runtime(RuntimePhase::PostCleanup)),
     ];
 
-    pm::run_passes(tcx, body, passes);
+    pm::run_passes(tcx, body, passes, Some(MirPhase::Runtime(RuntimePhase::PostCleanup)));
 }
 
 fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
@@ -591,10 +588,10 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
             &deduplicate_blocks::DeduplicateBlocks,
             // Some cleanup necessary at least for LLVM and potentially other codegen backends.
             &add_call_guards::CriticalCallEdges,
-            &marker::PhaseChange(MirPhase::Runtime(RuntimePhase::Optimized)),
             // Dump the end result for testing and debugging purposes.
             &dump_mir::Marker("PreCodegen"),
         ],
+        Some(MirPhase::Runtime(RuntimePhase::Optimized)),
     );
 }
 
diff --git a/compiler/rustc_mir_transform/src/marker.rs b/compiler/rustc_mir_transform/src/marker.rs
deleted file mode 100644
index 06819fc1d37d4..0000000000000
--- a/compiler/rustc_mir_transform/src/marker.rs
+++ /dev/null
@@ -1,20 +0,0 @@
-use std::borrow::Cow;
-
-use crate::MirPass;
-use rustc_middle::mir::{Body, MirPhase};
-use rustc_middle::ty::TyCtxt;
-
-/// Changes the MIR phase without changing the MIR itself.
-pub struct PhaseChange(pub MirPhase);
-
-impl<'tcx> MirPass<'tcx> for PhaseChange {
-    fn phase_change(&self) -> Option<MirPhase> {
-        Some(self.0)
-    }
-
-    fn name(&self) -> Cow<'_, str> {
-        Cow::from(format!("PhaseChange-{:?}", self.0))
-    }
-
-    fn run_pass(&self, _: TyCtxt<'tcx>, _body: &mut Body<'tcx>) {}
-}
diff --git a/compiler/rustc_mir_transform/src/pass_manager.rs b/compiler/rustc_mir_transform/src/pass_manager.rs
index 67dae71468f90..230c6a7cb4b00 100644
--- a/compiler/rustc_mir_transform/src/pass_manager.rs
+++ b/compiler/rustc_mir_transform/src/pass_manager.rs
@@ -66,10 +66,6 @@ where
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         self.1.run_pass(tcx, body)
     }
-
-    fn phase_change(&self) -> Option<MirPhase> {
-        self.1.phase_change()
-    }
 }
 
 /// Run the sequence of passes without validating the MIR after each pass. The MIR is still
@@ -78,23 +74,28 @@ pub fn run_passes_no_validate<'tcx>(
     tcx: TyCtxt<'tcx>,
     body: &mut Body<'tcx>,
     passes: &[&dyn MirPass<'tcx>],
+    phase_change: Option<MirPhase>,
 ) {
-    run_passes_inner(tcx, body, passes, false);
+    run_passes_inner(tcx, body, passes, phase_change, false);
 }
 
-pub fn run_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, passes: &[&dyn MirPass<'tcx>]) {
-    run_passes_inner(tcx, body, passes, true);
+/// The optional `phase_change` is applied after executing all the passes, if present
+pub fn run_passes<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    body: &mut Body<'tcx>,
+    passes: &[&dyn MirPass<'tcx>],
+    phase_change: Option<MirPhase>,
+) {
+    run_passes_inner(tcx, body, passes, phase_change, true);
 }
 
 fn run_passes_inner<'tcx>(
     tcx: TyCtxt<'tcx>,
     body: &mut Body<'tcx>,
     passes: &[&dyn MirPass<'tcx>],
+    phase_change: Option<MirPhase>,
     validate_each: bool,
 ) {
-    let start_phase = body.phase;
-    let mut cnt = 0;
-
     let validate = validate_each & tcx.sess.opts.unstable_opts.validate_mir;
     let overridden_passes = &tcx.sess.opts.unstable_opts.mir_enable_passes;
     trace!(?overridden_passes);
@@ -102,7 +103,6 @@ fn run_passes_inner<'tcx>(
     for pass in passes {
         let name = pass.name();
 
-        // Gather information about what we should be doing for this pass
         let overridden =
             overridden_passes.iter().rev().find(|(s, _)| s == &*name).map(|(_name, polarity)| {
                 trace!(
@@ -112,32 +112,44 @@ fn run_passes_inner<'tcx>(
                 );
                 *polarity
             });
-        let is_enabled = overridden.unwrap_or_else(|| pass.is_enabled(&tcx.sess));
-        let new_phase = pass.phase_change();
-        let dump_enabled = (is_enabled && pass.is_mir_dump_enabled()) || new_phase.is_some();
-        let validate = (validate && is_enabled)
-            || new_phase == Some(MirPhase::Runtime(RuntimePhase::Optimized));
+        if !overridden.unwrap_or_else(|| pass.is_enabled(&tcx.sess)) {
+            continue;
+        }
+
+        let dump_enabled = pass.is_mir_dump_enabled();
 
         if dump_enabled {
-            dump_mir(tcx, body, start_phase, &name, cnt, false);
-        }
-        if is_enabled {
-            pass.run_pass(tcx, body);
+            dump_mir_for_pass(tcx, body, &name, false);
         }
-        if dump_enabled {
-            dump_mir(tcx, body, start_phase, &name, cnt, true);
-            cnt += 1;
+        if validate {
+            validate_body(tcx, body, format!("before pass {}", name));
         }
-        if let Some(new_phase) = pass.phase_change() {
-            if body.phase >= new_phase {
-                panic!("Invalid MIR phase transition from {:?} to {:?}", body.phase, new_phase);
-            }
 
-            body.phase = new_phase;
+        pass.run_pass(tcx, body);
+
+        if dump_enabled {
+            dump_mir_for_pass(tcx, body, &name, true);
         }
         if validate {
             validate_body(tcx, body, format!("after pass {}", name));
         }
+
+        body.pass_count += 1;
+    }
+
+    if let Some(new_phase) = phase_change {
+        if body.phase >= new_phase {
+            panic!("Invalid MIR phase transition from {:?} to {:?}", body.phase, new_phase);
+        }
+
+        body.phase = new_phase;
+
+        dump_mir_for_phase_change(tcx, body);
+        if validate || new_phase == MirPhase::Runtime(RuntimePhase::Optimized) {
+            validate_body(tcx, body, format!("after phase change to {}", new_phase));
+        }
+
+        body.pass_count = 1;
     }
 }
 
@@ -145,22 +157,33 @@ pub fn validate_body<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, when: Strin
     validate::Validator { when, mir_phase: body.phase }.run_pass(tcx, body);
 }
 
-pub fn dump_mir<'tcx>(
+pub fn dump_mir_for_pass<'tcx>(
     tcx: TyCtxt<'tcx>,
     body: &Body<'tcx>,
-    phase: MirPhase,
     pass_name: &str,
-    cnt: usize,
     is_after: bool,
 ) {
-    let phase_index = phase.phase_index();
+    let phase_index = body.phase.phase_index();
 
     mir::dump_mir(
         tcx,
-        Some(&format_args!("{:03}-{:03}", phase_index, cnt)),
+        Some(&format_args!("{:03}-{:03}", phase_index, body.pass_count)),
         pass_name,
         if is_after { &"after" } else { &"before" },
         body,
         |_, _| Ok(()),
     );
 }
+
+pub fn dump_mir_for_phase_change<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) {
+    let phase_index = body.phase.phase_index();
+
+    mir::dump_mir(
+        tcx,
+        Some(&format_args!("{:03}-000", phase_index)),
+        &format!("{}", body.phase),
+        &"after",
+        body,
+        |_, _| Ok(()),
+    )
+}
diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs
index 6ca58ee458c56..c19380ef89cc5 100644
--- a/compiler/rustc_mir_transform/src/shim.rs
+++ b/compiler/rustc_mir_transform/src/shim.rs
@@ -17,7 +17,7 @@ use std::iter;
 
 use crate::util::expand_aggregate;
 use crate::{
-    abort_unwinding_calls, add_call_guards, add_moves_for_packed_drops, deref_separator, marker,
+    abort_unwinding_calls, add_call_guards, add_moves_for_packed_drops, deref_separator,
     pass_manager as pm, remove_noop_landing_pads, simplify,
 };
 use rustc_middle::mir::patch::MirPatch;
@@ -97,8 +97,8 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'
             &simplify::SimplifyCfg::new("make_shim"),
             &add_call_guards::CriticalCallEdges,
             &abort_unwinding_calls::AbortUnwindingCalls,
-            &marker::PhaseChange(MirPhase::Runtime(RuntimePhase::Optimized)),
         ],
+        Some(MirPhase::Runtime(RuntimePhase::Optimized)),
     );
 
     debug!("make_shim({:?}) = {:?}", instance, result);