Unroll the generic (non-bytewise) slice equality comparison and add a benchmark.

library/core/benches/slice.rs
  * Adds `slice_cmp_generic`: benchmarks `==` on two 128-element slices of a
    local `#[derive(PartialEq)]` struct `Foo(u32, u32)`, with both inputs passed
    through `black_box`.

library/core/src/slice/cmp.rs
  * The default `SlicePartialEq::equal` keeps its length check, then dispatches:
    fewer than 8 elements go to `eq_small`, everything else to `eq_unroll`.
  * `eq_small` is the previous implementation: a short-circuiting
    `iter().zip().all()` over the element pairs.
  * `eq_unroll` splits both slices into 4-element chunks plus a residual. It
    compares the first chunk and then the residual with `eq_small` (both can
    bail out early), then compares the remaining chunks with a non-short-
    circuiting `&` across the four elements of each chunk.

Review notes
  * `eq_unroll` calls `take_first().unwrap()` on both chunk slices, so it panics
    if either input has fewer than 4 elements, and its `zip` silently assumes
    equal lengths. The only caller shown, `equal`, guarantees equal lengths and
    at least 8 elements before calling it.
  * Compared with the removed `zip().all()`, the order in which elements are
    compared changes (first chunk, residual, then the middle chunks), and inside
    a chunk all four comparisons are evaluated even after a mismatch.

diff --git a/library/core/benches/slice.rs b/library/core/benches/slice.rs
index 3bfb35e684ea1..1ec51653d92ef 100644
--- a/library/core/benches/slice.rs
+++ b/library/core/benches/slice.rs
@@ -171,3 +171,17 @@ fn fold_to_last(b: &mut Bencher) {
     let slice: &[i32] = &[0; 1024];
     b.iter(|| black_box(slice).iter().fold(None, |_, r| Some(NonNull::from(r))));
 }
+
+#[bench]
+fn slice_cmp_generic(b: &mut Bencher) {
+    #[derive(PartialEq, Clone, Copy)]
+    struct Foo(u32, u32);
+
+    let left = [Foo(128, 128); 128];
+    let right = [Foo(128, 128); 128];
+
+    b.iter(|| {
+        let (left, right) = (black_box(&left), black_box(&right));
+        left.as_slice() == right.as_slice()
+    });
+}
diff --git a/library/core/src/slice/cmp.rs b/library/core/src/slice/cmp.rs
index 075347b80d031..3231e4313a8a3 100644
--- a/library/core/src/slice/cmp.rs
+++ b/library/core/src/slice/cmp.rs
@@ -55,15 +55,56 @@ impl<A, B> SlicePartialEq<B> for [A]
 where
     A: PartialEq<B>,
 {
+    #[inline]
     default fn equal(&self, other: &[B]) -> bool {
         if self.len() != other.len() {
             return false;
         }
 
-        self.iter().zip(other.iter()).all(|(x, y)| x == y)
+        // at least 8 items for unrolling to make sense (4 peeled + 4+ unrolled)
+        if self.len() < 8 {
+            return eq_small(self, other);
+        }
+
+        eq_unroll(self, other)
     }
 }
 
+#[inline]
+fn eq_small<A, B>(a: &[A], b: &[B]) -> bool
+where
+    A: PartialEq<B>,
+{
+    a.iter().zip(b).all(|(a, b)| a == b)
+}
+
+fn eq_unroll<A, B>(a: &[A], b: &[B]) -> bool
+where
+    A: PartialEq<B>,
+{
+    let (mut chunks_a, residual_a) = a.as_chunks::<4>();
+    let (mut chunks_b, residual_b) = b.as_chunks::<4>();
+    let peeled_a = chunks_a.take_first().unwrap();
+    let peeled_b = chunks_b.take_first().unwrap();
+
+    // peel the first chunk and do a short-circuiting comparison to bail early on mismatches
+    // in case comparisons are expensive
+    let mut result = eq_small(peeled_a, peeled_b);
+
+    // then check the residual, another chance to bail early
+    result = result && eq_small(residual_a, residual_b);
+
+    // iter.all short-circuits which means the backend can't unroll the loop due to early exits.
+    // So we unroll it manually.
+    result = result
+        && chunks_a
+            .iter()
+            .zip(chunks_b)
+            .all(|(a, b)| (a[0] == b[0]) & (a[1] == b[1]) & (a[2] == b[2]) & (a[3] == b[3]));
+
+    result
+}
+
 // When each element can be compared byte-wise, we can compare all the bytes
 // from the whole size in one call to the intrinsics.
 impl<A, B> SlicePartialEq<B> for [A]