From d98270b07a04178b019d8317d146eb3dd1afface Mon Sep 17 00:00:00 2001 From: Ralf Jung Date: Fri, 31 Jan 2025 11:53:32 +0100 Subject: [PATCH] miri: make float min/max non-deterministic --- .../src/interpret/intrinsics.rs | 16 +++++++++-- .../rustc_const_eval/src/interpret/machine.rs | 6 ++++ src/tools/miri/src/machine.rs | 13 ++++++--- src/tools/miri/src/operator.rs | 7 +++++ src/tools/miri/tests/pass/float.rs | 28 +++++++++++++++++++ 5 files changed, 64 insertions(+), 6 deletions(-) diff --git a/compiler/rustc_const_eval/src/interpret/intrinsics.rs b/compiler/rustc_const_eval/src/interpret/intrinsics.rs index 0664a882c1d50..9f5f2533e085b 100644 --- a/compiler/rustc_const_eval/src/interpret/intrinsics.rs +++ b/compiler/rustc_const_eval/src/interpret/intrinsics.rs @@ -747,7 +747,13 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { { let a: F = self.read_scalar(&args[0])?.to_float()?; let b: F = self.read_scalar(&args[1])?.to_float()?; - let res = self.adjust_nan(a.min(b), &[a, b]); + let res = if a == b { + // They are definitely not NaN (those are never equal), but they could be `+0` and `-0`. + // Let the machine decide which one to return. + M::equal_float_min_max(self, a, b) + } else { + self.adjust_nan(a.min(b), &[a, b]) + }; self.write_scalar(res, dest)?; interp_ok(()) } @@ -762,7 +768,13 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { { let a: F = self.read_scalar(&args[0])?.to_float()?; let b: F = self.read_scalar(&args[1])?.to_float()?; - let res = self.adjust_nan(a.max(b), &[a, b]); + let res = if a == b { + // They are definitely not NaN (those are never equal), but they could be `+0` and `-0`. + // Let the machine decide which one to return. + M::equal_float_min_max(self, a, b) + } else { + self.adjust_nan(a.max(b), &[a, b]) + }; self.write_scalar(res, dest)?; interp_ok(()) } diff --git a/compiler/rustc_const_eval/src/interpret/machine.rs b/compiler/rustc_const_eval/src/interpret/machine.rs index 36e5a2ff750ae..8f6b15b8df012 100644 --- a/compiler/rustc_const_eval/src/interpret/machine.rs +++ b/compiler/rustc_const_eval/src/interpret/machine.rs @@ -278,6 +278,12 @@ pub trait Machine<'tcx>: Sized { F2::NAN } + /// Determines the result of `min`/`max` on floats when the arguments are equal. + fn equal_float_min_max(_ecx: &InterpCx<'tcx, Self>, a: F, _b: F) -> F { + // By default, we pick the left argument. + a + } + /// Called before a basic block terminator is executed. #[inline] fn before_terminator(_ecx: &mut InterpCx<'tcx, Self>) -> InterpResult<'tcx> { diff --git a/src/tools/miri/src/machine.rs b/src/tools/miri/src/machine.rs index 845ba484326f0..3727b5f4cae4a 100644 --- a/src/tools/miri/src/machine.rs +++ b/src/tools/miri/src/machine.rs @@ -11,6 +11,7 @@ use std::{fmt, process}; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use rustc_abi::{Align, ExternAbi, Size}; +use rustc_apfloat::{Float, FloatConvert}; use rustc_attr_parsing::InlineAttr; use rustc_data_structures::fx::{FxHashMap, FxHashSet}; #[allow(unused)] @@ -1129,20 +1130,24 @@ impl<'tcx> Machine<'tcx> for MiriMachine<'tcx> { } #[inline(always)] - fn generate_nan< - F1: rustc_apfloat::Float + rustc_apfloat::FloatConvert, - F2: rustc_apfloat::Float, - >( + fn generate_nan, F2: Float>( ecx: &InterpCx<'tcx, Self>, inputs: &[F1], ) -> F2 { ecx.generate_nan(inputs) } + #[inline(always)] + fn equal_float_min_max(ecx: &MiriInterpCx<'tcx>, a: F, b: F) -> F { + ecx.equal_float_min_max(a, b) + } + + #[inline(always)] fn ub_checks(ecx: &InterpCx<'tcx, Self>) -> InterpResult<'tcx, bool> { interp_ok(ecx.tcx.sess.ub_checks()) } + #[inline(always)] fn thread_local_static_pointer( ecx: &mut MiriInterpCx<'tcx>, def_id: DefId, diff --git a/src/tools/miri/src/operator.rs b/src/tools/miri/src/operator.rs index 0017a3991b53b..43c628d66d590 100644 --- a/src/tools/miri/src/operator.rs +++ b/src/tools/miri/src/operator.rs @@ -115,4 +115,11 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { nan } } + + fn equal_float_min_max(&self, a: F, b: F) -> F { + let this = self.eval_context_ref(); + // Return one side non-deterministically. + let mut rand = this.machine.rng.borrow_mut(); + if rand.gen() { a } else { b } + } } diff --git a/src/tools/miri/tests/pass/float.rs b/src/tools/miri/tests/pass/float.rs index 4de315e358975..2f4f64b1aa800 100644 --- a/src/tools/miri/tests/pass/float.rs +++ b/src/tools/miri/tests/pass/float.rs @@ -31,6 +31,7 @@ fn main() { test_fast(); test_algebraic(); test_fmuladd(); + test_min_max_nondet(); } trait Float: Copy + PartialEq + Debug { @@ -1211,3 +1212,30 @@ fn test_fmuladd() { test_operations_f32(0.1, 0.2, 0.3); test_operations_f64(1.1, 1.2, 1.3); } + +/// `min` and `max` on equal arguments are non-deterministic. +fn test_min_max_nondet() { + /// Ensure that if we call the closure often enough, we see both `true` and `false.` + #[track_caller] + fn ensure_both(f: impl Fn() -> bool) { + let rounds = 16; + let first = f(); + for _ in 1..rounds { + if f() != first { + // We saw two different values! + return; + } + } + // We saw the same thing N times. + panic!("expected non-determinism, got {rounds} times the same result: {first:?}"); + } + + ensure_both(|| f16::min(0.0, -0.0).is_sign_positive()); + ensure_both(|| f16::max(0.0, -0.0).is_sign_positive()); + ensure_both(|| f32::min(0.0, -0.0).is_sign_positive()); + ensure_both(|| f32::max(0.0, -0.0).is_sign_positive()); + ensure_both(|| f64::min(0.0, -0.0).is_sign_positive()); + ensure_both(|| f64::max(0.0, -0.0).is_sign_positive()); + ensure_both(|| f128::min(0.0, -0.0).is_sign_positive()); + ensure_both(|| f128::max(0.0, -0.0).is_sign_positive()); +}