Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dest prop: Support removing writes when this unblocks optimizations #105813

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 96 additions & 48 deletions compiler/rustc_mir_transform/src/dest_prop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,23 +207,27 @@ impl<'tcx> MirPass<'tcx> for DestinationPropagation {

// This is the set of merges we will apply this round. It is a subset of the candidates.
let mut merges = FxHashMap::default();
let mut remove_writes = FxHashMap::default();

for (src, candidates) in candidates.c.iter() {
if merged_locals.contains(*src) {
for (src, candidates) in candidates.c.drain() {
if merged_locals.contains(src) {
continue;
}
let Some(dest) =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let Some(dest) =
let Some((dest, writes_to_remove)) =

candidates.iter().find(|dest| !merged_locals.contains(**dest)) else {
candidates.into_iter().find(|(dest, _)| !merged_locals.contains(*dest)) else {
continue;
};
if !tcx.consider_optimizing(|| {
format!("{} round {}", tcx.def_path_str(def_id), round_count)
}) {
break;
}
merges.insert(*src, *dest);
merged_locals.insert(*src);
merged_locals.insert(*dest);
merged_locals.insert(src);
merged_locals.insert(dest.0);
merges.insert(src, dest.0);
if !dest.1.is_empty() {
remove_writes.insert(dest.0, dest.1);
}
}
trace!(merging = ?merges);

Expand All @@ -232,7 +236,7 @@ impl<'tcx> MirPass<'tcx> for DestinationPropagation {
}
round_count += 1;

apply_merges(body, tcx, &merges, &merged_locals);
apply_merges(body, tcx, &merges, &remove_writes, &merged_locals);
}

trace!(round_count);
Expand All @@ -245,7 +249,7 @@ impl<'tcx> MirPass<'tcx> for DestinationPropagation {
/// frequently. Everything with a `&'alloc` lifetime points into here.
#[derive(Default)]
struct Allocations {
candidates: FxHashMap<Local, Vec<Local>>,
candidates: FxHashMap<Local, Vec<(Local, Vec<Location>)>>,
candidates_reverse: FxHashMap<Local, Vec<Local>>,
write_info: WriteInfo,
// PERF: Do this for `MaybeLiveLocals` allocations too.
Expand All @@ -267,7 +271,11 @@ struct Candidates<'alloc> {
///
/// We will still report that we would like to merge `_1` and `_2` in an attempt to allow us to
/// remove that assignment.
c: &'alloc mut FxHashMap<Local, Vec<Local>>,
///
/// Each candidate pair is associated with a `Vec<Location>`. If the candidate pair is accepted,
/// all writes to either local at these locations must be removed. The writes will always be
/// removable.
c: &'alloc mut FxHashMap<Local, Vec<(Local, Vec<Location>)>>,
/// A reverse index of the `c` set; if the `c` set contains `a => Place { local: b, proj }`,
/// then this contains `b => a`.
// PERF: Possibly these should be `SmallVec`s?
Expand All @@ -283,18 +291,29 @@ fn apply_merges<'tcx>(
body: &mut Body<'tcx>,
tcx: TyCtxt<'tcx>,
merges: &FxHashMap<Local, Local>,
remove_writes: &FxHashMap<Local, Vec<Location>>,
merged_locals: &BitSet<Local>,
) {
let mut merger = Merger { tcx, merges, merged_locals };
let mut merger = Merger { tcx, merges, remove_writes, merged_locals };
merger.visit_body_preserves_cfg(body);
}

struct Merger<'a, 'tcx> {
tcx: TyCtxt<'tcx>,
merges: &'a FxHashMap<Local, Local>,
remove_writes: &'a FxHashMap<Local, Vec<Location>>,
merged_locals: &'a BitSet<Local>,
}

impl<'a, 'tcx> Merger<'a, 'tcx> {
fn should_remove_write_at(&self, local: Local, location: Location) -> bool {
let Some(to_remove) = self.remove_writes.get(&local) else {
return false;
};
to_remove.contains(&location)
}
}

impl<'a, 'tcx> MutVisitor<'tcx> for Merger<'a, 'tcx> {
fn tcx(&self) -> TyCtxt<'tcx> {
self.tcx
Expand Down Expand Up @@ -332,10 +351,27 @@ impl<'a, 'tcx> MutVisitor<'tcx> for Merger<'a, 'tcx> {
_ => {}
}
}
StatementKind::Deinit(place) => {
if self.should_remove_write_at(place.local, location) {
statement.make_nop();
}
}

_ => {}
}
}

fn visit_operand(&mut self, op: &mut Operand<'tcx>, location: Location) {
self.super_operand(op, location);
match op {
Operand::Move(place) => {
if self.should_remove_write_at(place.local, location) {
*op = Operand::Copy(*place);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is only valid if `place` is sized, as unsized values can't be copied. I don't think `should_remove_write_at` can return true for unsized places (though maybe there is a case where an unsized place can be re-initialized after a move?), but it may be a good idea to check anyway.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, this is a theoretical problem in future versions of dest prop I think. God, I really wish we could just treat unsized values as completely unsupported. I'll add a check.

}
}
_ => (),
}
}
}

//////////////////////////////////////////////////////////
Expand All @@ -356,30 +392,35 @@ struct FilterInformation<'a, 'body, 'alloc, 'tcx> {
// through these methods, and not directly.
impl<'alloc> Candidates<'alloc> {
/// Just `Vec::retain`, but the condition is inverted and we add debugging output
fn vec_filter_candidates(
fn vec_modify_candidates(
src: Local,
v: &mut Vec<Local>,
mut f: impl FnMut(Local) -> CandidateFilter,
v: &mut Vec<(Local, Vec<Location>)>,
mut f: impl FnMut(Local) -> CandidateModification,
at: Location,
) {
v.retain(|dest| {
let remove = f(*dest);
if remove == CandidateFilter::Remove {
v.retain_mut(|(dest, remove_writes)| match f(*dest) {
CandidateModification::Remove => {
trace!("eliminating {:?} => {:?} due to conflict at {:?}", src, dest, at);
false
}
CandidateModification::RemoveWrite => {
trace!("marking write for {:?} => {:?} as needing removing at {:?}", src, dest, at);
remove_writes.push(at);
true
}
remove == CandidateFilter::Keep
CandidateModification::Keep => true,
});
}

/// `vec_modify_candidates` but for an `Entry`
fn entry_filter_candidates(
mut entry: OccupiedEntry<'_, Local, Vec<Local>>,
mut entry: OccupiedEntry<'_, Local, Vec<(Local, Vec<Location>)>>,
p: Local,
f: impl FnMut(Local) -> CandidateFilter,
f: impl FnMut(Local) -> CandidateModification,
at: Location,
) {
let candidates = entry.get_mut();
Self::vec_filter_candidates(p, candidates, f, at);
Self::vec_modify_candidates(p, candidates, f, at);
if candidates.len() == 0 {
entry.remove();
}
Expand All @@ -389,7 +430,7 @@ impl<'alloc> Candidates<'alloc> {
fn filter_candidates_by(
&mut self,
p: Local,
mut f: impl FnMut(Local) -> CandidateFilter,
mut f: impl FnMut(Local) -> CandidateModification,
at: Location,
) {
// Cover the cases where `p` appears as a `src`
Expand All @@ -403,7 +444,8 @@ impl<'alloc> Candidates<'alloc> {
// We use `retain` here to remove the elements from the reverse set if we've removed the
// matching candidate in the forward set.
srcs.retain(|src| {
if f(*src) == CandidateFilter::Keep {
let modification = f(*src);
if modification == CandidateModification::Keep {
return true;
}
let Entry::Occupied(entry) = self.c.entry(*src) else {
Expand All @@ -413,18 +455,20 @@ impl<'alloc> Candidates<'alloc> {
entry,
*src,
|dest| {
if dest == p { CandidateFilter::Remove } else { CandidateFilter::Keep }
if dest == p { modification } else { CandidateModification::Keep }
},
at,
);
false
// Remove the src from the reverse set if we removed the candidate pair
modification == CandidateModification::RemoveWrite
});
}
}

#[derive(Copy, Clone, PartialEq, Eq)]
enum CandidateFilter {
enum CandidateModification {
Keep,
RemoveWrite,
Remove,
}

Expand Down Expand Up @@ -483,31 +527,36 @@ impl<'a, 'body, 'alloc, 'tcx> FilterInformation<'a, 'body, 'alloc, 'tcx> {

fn apply_conflicts(&mut self) {
let writes = &self.write_info.writes;
for p in writes {
for &(p, is_removable) in writes {
let modification = if is_removable {
CandidateModification::RemoveWrite
} else {
CandidateModification::Remove
};
let other_skip = self.write_info.skip_pair.and_then(|(a, b)| {
if a == *p {
if a == p {
Some(b)
} else if b == *p {
} else if b == p {
Some(a)
} else {
None
}
});
self.candidates.filter_candidates_by(
*p,
p,
|q| {
if Some(q) == other_skip {
return CandidateFilter::Keep;
return CandidateModification::Keep;
}
// It is possible that a local may be live for less than the
// duration of a statement. This happens in the case of function
// calls or inline asm. Because of this, we also mark locals as
// conflicting when both of them are written to in the same
// statement.
if self.live.contains(q) || writes.contains(&q) {
CandidateFilter::Remove
if self.live.contains(q) || writes.iter().any(|&(x, _)| x == q) {
modification
} else {
CandidateFilter::Keep
CandidateModification::Keep
}
},
self.at,
Expand All @@ -519,7 +568,9 @@ impl<'a, 'body, 'alloc, 'tcx> FilterInformation<'a, 'body, 'alloc, 'tcx> {
/// Describes where a statement/terminator writes to
#[derive(Default, Debug)]
struct WriteInfo {
writes: Vec<Local>,
/// Which locals are written to. The `bool` is true if the write is "removable", i.e. if it comes
/// from an `Operand::Move` or `Deinit`.
writes: Vec<(Local, bool)>,
/// If this pair of locals is a candidate pair, completely skip processing it during this
/// statement. All other candidates are unaffected.
skip_pair: Option<(Local, Local)>,
Expand Down Expand Up @@ -563,10 +614,11 @@ impl WriteInfo {
| Rvalue::CopyForDeref(_) => (),
}
}
StatementKind::Deinit(p) => {
self.writes.push((p.local, true));
}
// Retags are technically also reads, but reporting them as a write suffices
StatementKind::SetDiscriminant { place, .. }
| StatementKind::Deinit(place)
| StatementKind::Retag(_, place) => {
StatementKind::SetDiscriminant { place, .. } | StatementKind::Retag(_, place) => {
self.add_place(**place);
}
StatementKind::Intrinsic(_)
Expand Down Expand Up @@ -652,16 +704,12 @@ impl WriteInfo {
}

fn add_place<'tcx>(&mut self, place: Place<'tcx>) {
self.writes.push(place.local);
self.writes.push((place.local, false));
}

fn add_operand<'tcx>(&mut self, op: &Operand<'tcx>) {
match op {
// FIXME(JakobDegen): In a previous version, the `Move` case was incorrectly treated as
// being a read only. This was unsound, however we cannot add a regression test because
// it is not possible to set this off with current MIR. Once we have that ability, a
// regression test should be added.
Operand::Move(p) => self.add_place(*p),
Operand::Move(p) => self.writes.push((p.local, true)),
Operand::Copy(_) | Operand::Constant(_) => (),
}
}
Expand Down Expand Up @@ -716,7 +764,7 @@ fn places_to_candidate_pair<'tcx>(
fn find_candidates<'alloc, 'tcx>(
body: &Body<'tcx>,
borrowed: &BitSet<Local>,
candidates: &'alloc mut FxHashMap<Local, Vec<Local>>,
candidates: &'alloc mut FxHashMap<Local, Vec<(Local, Vec<Location>)>>,
candidates_reverse: &'alloc mut FxHashMap<Local, Vec<Local>>,
) -> Candidates<'alloc> {
candidates.clear();
Expand All @@ -730,16 +778,16 @@ fn find_candidates<'alloc, 'tcx>(
}
// Generate the reverse map
for (src, cands) in candidates.iter() {
for dest in cands.iter().copied() {
candidates_reverse.entry(dest).or_default().push(*src);
for (dest, _) in cands.iter() {
candidates_reverse.entry(*dest).or_default().push(*src);
}
}
Candidates { c: candidates, reverse: candidates_reverse }
}

struct FindAssignments<'a, 'alloc, 'tcx> {
body: &'a Body<'tcx>,
candidates: &'alloc mut FxHashMap<Local, Vec<Local>>,
candidates: &'alloc mut FxHashMap<Local, Vec<(Local, Vec<Location>)>>,
borrowed: &'a BitSet<Local>,
}

Expand All @@ -766,7 +814,7 @@ impl<'tcx> Visitor<'tcx> for FindAssignments<'_, '_, 'tcx> {
}

// We may insert duplicates here, but that's fine
self.candidates.entry(src).or_default().push(dest);
self.candidates.entry(src).or_default().push((dest, Vec::new()));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
- // MIR for `move_simple` before DestinationPropagation
+ // MIR for `move_simple` after DestinationPropagation

fn move_simple(_1: i32) -> () {
debug x => _1; // in scope 0 at $DIR/move.rs:+0:16: +0:17
let mut _0: (); // return place in scope 0 at $DIR/move.rs:+0:24: +0:24
let _2: (); // in scope 0 at $DIR/move.rs:+1:5: +1:19
let mut _3: i32; // in scope 0 at $DIR/move.rs:+1:14: +1:15
let mut _4: i32; // in scope 0 at $DIR/move.rs:+1:17: +1:18

bb0: {
StorageLive(_2); // scope 0 at $DIR/move.rs:+1:5: +1:19
- StorageLive(_3); // scope 0 at $DIR/move.rs:+1:14: +1:15
- _3 = _1; // scope 0 at $DIR/move.rs:+1:14: +1:15
- StorageLive(_4); // scope 0 at $DIR/move.rs:+1:17: +1:18
- _4 = _1; // scope 0 at $DIR/move.rs:+1:17: +1:18
- _2 = use_both(move _3, move _4) -> bb1; // scope 0 at $DIR/move.rs:+1:5: +1:19
+ nop; // scope 0 at $DIR/move.rs:+1:14: +1:15
+ nop; // scope 0 at $DIR/move.rs:+1:14: +1:15
+ nop; // scope 0 at $DIR/move.rs:+1:17: +1:18
+ nop; // scope 0 at $DIR/move.rs:+1:17: +1:18
+ _2 = use_both(_1, _1) -> bb1; // scope 0 at $DIR/move.rs:+1:5: +1:19
// mir::Constant
// + span: $DIR/move.rs:8:5: 8:13
// + literal: Const { ty: fn(i32, i32) {use_both}, val: Value(<ZST>) }
}

bb1: {
- StorageDead(_4); // scope 0 at $DIR/move.rs:+1:18: +1:19
- StorageDead(_3); // scope 0 at $DIR/move.rs:+1:18: +1:19
+ nop; // scope 0 at $DIR/move.rs:+1:18: +1:19
+ nop; // scope 0 at $DIR/move.rs:+1:18: +1:19
StorageDead(_2); // scope 0 at $DIR/move.rs:+1:19: +1:20
_0 = const (); // scope 0 at $DIR/move.rs:+0:24: +2:2
return; // scope 0 at $DIR/move.rs:+2:2: +2:2
}
}

13 changes: 13 additions & 0 deletions src/test/mir-opt/dest-prop/move.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// unit-test: DestinationPropagation

#[inline(never)]
fn use_both(_: i32, _: i32) {}

// EMIT_MIR move.move_simple.DestinationPropagation.diff
fn move_simple(x: i32) {
use_both(x, x);
}

fn main() {
move_simple(1);
}
Loading