miri/
math.rs

1use rand::Rng as _;
2use rustc_apfloat::Float as _;
3use rustc_apfloat::ieee::IeeeFloat;
4use rustc_middle::ty::{self, FloatTy, ScalarInt};
5
6use crate::*;
7
8/// Disturbes a floating-point result by a relative error in the range (-2^scale, 2^scale).
9///
10/// For a 2^N ULP error, you can use an `err_scale` of `-(F::PRECISION - 1 - N)`.
11/// In other words, a 1 ULP (absolute) error is the same as a `2^-(F::PRECISION-1)` relative error.
12/// (Subtracting 1 compensates for the integer bit.)
13pub(crate) fn apply_random_float_error<F: rustc_apfloat::Float>(
14    ecx: &mut crate::MiriInterpCx<'_>,
15    val: F,
16    err_scale: i32,
17) -> F {
18    if !ecx.machine.float_nondet {
19        return val;
20    }
21
22    let rng = ecx.machine.rng.get_mut();
23    // Generate a random integer in the range [0, 2^PREC).
24    // (When read as binary, the position of the first `1` determines the exponent,
25    // and the remaining bits fill the mantissa. `PREC` is one plus the size of the mantissa,
26    // so this all works out.)
27    let r = F::from_u128(rng.random_range(0..(1 << F::PRECISION))).value;
28    // Multiply this with 2^(scale - PREC). The result is between 0 and
29    // 2^PREC * 2^(scale - PREC) = 2^scale.
30    let err = r.scalbn(err_scale.strict_sub(F::PRECISION.try_into().unwrap()));
31    // give it a random sign
32    let err = if rng.random() { -err } else { err };
33    // multiple the value with (1+err)
34    (val * (F::from_u128(1).value + err).value).value
35}
36
37/// [`apply_random_float_error`] gives instructions to apply a 2^N ULP error.
38/// This function implements these instructions such that applying a 2^N ULP error is less error prone.
39/// So for a 2^N ULP error, you would pass N as the `ulp_exponent` argument.
40pub(crate) fn apply_random_float_error_ulp<F: rustc_apfloat::Float>(
41    ecx: &mut crate::MiriInterpCx<'_>,
42    val: F,
43    ulp_exponent: u32,
44) -> F {
45    let n = i32::try_from(ulp_exponent)
46        .expect("`err_scale_for_ulp`: exponent is too large to create an error scale");
47    // we know this fits
48    let prec = i32::try_from(F::PRECISION).unwrap();
49    let err_scale = -(prec - n - 1);
50    apply_random_float_error(ecx, val, err_scale)
51}
52
53/// Applies a random 16ULP floating point error to `val` and returns the new value.
54/// Will fail if `val` is not a floating point number.
55pub(crate) fn apply_random_float_error_to_imm<'tcx>(
56    ecx: &mut MiriInterpCx<'tcx>,
57    val: ImmTy<'tcx>,
58    ulp_exponent: u32,
59) -> InterpResult<'tcx, ImmTy<'tcx>> {
60    let scalar = val.to_scalar_int()?;
61    let res: ScalarInt = match val.layout.ty.kind() {
62        ty::Float(FloatTy::F16) =>
63            apply_random_float_error_ulp(ecx, scalar.to_f16(), ulp_exponent).into(),
64        ty::Float(FloatTy::F32) =>
65            apply_random_float_error_ulp(ecx, scalar.to_f32(), ulp_exponent).into(),
66        ty::Float(FloatTy::F64) =>
67            apply_random_float_error_ulp(ecx, scalar.to_f64(), ulp_exponent).into(),
68        ty::Float(FloatTy::F128) =>
69            apply_random_float_error_ulp(ecx, scalar.to_f128(), ulp_exponent).into(),
70        _ => bug!("intrinsic called with non-float input type"),
71    };
72
73    interp_ok(ImmTy::from_scalar_int(res, val.layout))
74}
75
76pub(crate) fn sqrt<S: rustc_apfloat::ieee::Semantics>(x: IeeeFloat<S>) -> IeeeFloat<S> {
77    match x.category() {
78        // preserve zero sign
79        rustc_apfloat::Category::Zero => x,
80        // propagate NaN
81        rustc_apfloat::Category::NaN => x,
82        // sqrt of negative number is NaN
83        _ if x.is_negative() => IeeeFloat::NAN,
84        // sqrt(∞) = ∞
85        rustc_apfloat::Category::Infinity => IeeeFloat::INFINITY,
86        rustc_apfloat::Category::Normal => {
87            // Floating point precision, excluding the integer bit
88            let prec = i32::try_from(S::PRECISION).unwrap() - 1;
89
90            // x = 2^(exp - prec) * mant
91            // where mant is an integer with prec+1 bits
92            // mant is a u128, which should be large enough for the largest prec (112 for f128)
93            let mut exp = x.ilogb();
94            let mut mant = x.scalbn(prec - exp).to_u128(128).value;
95
96            if exp % 2 != 0 {
97                // Make exponent even, so it can be divided by 2
98                exp -= 1;
99                mant <<= 1;
100            }
101
102            // Bit-by-bit (base-2 digit-by-digit) sqrt of mant.
103            // mant is treated here as a fixed point number with prec fractional bits.
104            // mant will be shifted left by one bit to have an extra fractional bit, which
105            // will be used to determine the rounding direction.
106
107            // res is the truncated sqrt of mant, where one bit is added at each iteration.
108            let mut res = 0u128;
109            // rem is the remainder with the current res
110            // rem_i = 2^i * ((mant<<1) - res_i^2)
111            // starting with res = 0, rem = mant<<1
112            let mut rem = mant << 1;
113            // s_i = 2*res_i
114            let mut s = 0u128;
115            // d is used to iterate over bits, from high to low (d_i = 2^(-i))
116            let mut d = 1u128 << (prec + 1);
117
118            // For iteration j=i+1, we need to find largest b_j = 0 or 1 such that
119            //  (res_i + b_j * 2^(-j))^2 <= mant<<1
120            // Expanding (a + b)^2 = a^2 + b^2 + 2*a*b:
121            //  res_i^2 + (b_j * 2^(-j))^2 + 2 * res_i * b_j * 2^(-j) <= mant<<1
122            // And rearranging the terms:
123            //  b_j^2 * 2^(-j) + 2 * res_i * b_j <= 2^j * (mant<<1 - res_i^2)
124            //  b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i
125
126            while d != 0 {
127                // Probe b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i with b_j = 1:
128                // t = 2*res_i + 2^(-j)
129                let t = s + d;
130                if rem >= t {
131                    // b_j should be 1, so make res_j = res_i + 2^(-j) and adjust rem
132                    res += d;
133                    s += d + d;
134                    rem -= t;
135                }
136                // Adjust rem for next iteration
137                rem <<= 1;
138                // Shift iterator
139                d >>= 1;
140            }
141
142            // Remove extra fractional bit from result, rounding to nearest.
143            // If the last bit is 0, then the nearest neighbor is definitely the lower one.
144            // If the last bit is 1, it sounds like this may either be a tie (if there's
145            // infinitely many 0s after this 1), or the nearest neighbor is the upper one.
146            // However, since square roots are either exact or irrational, and an exact root
147            // would lead to the last "extra" bit being 0, we can exclude a tie in this case.
148            // We therefore always round up if the last bit is 1. When the last bit is 0,
149            // adding 1 will not do anything since the shift will discard it.
150            res = (res + 1) >> 1;
151
152            // Build resulting value with res as mantissa and exp/2 as exponent
153            IeeeFloat::from_u128(res).value.scalbn(exp / 2 - prec)
154        }
155    }
156}
157
158/// Extend functionality of rustc_apfloat softfloats
159pub trait IeeeExt: rustc_apfloat::Float {
160    #[inline]
161    fn one() -> Self {
162        Self::from_u128(1).value
163    }
164
165    #[inline]
166    fn clamp(self, min: Self, max: Self) -> Self {
167        self.maximum(min).minimum(max)
168    }
169}
170impl<S: rustc_apfloat::ieee::Semantics> IeeeExt for IeeeFloat<S> {}
171
#[cfg(test)]
mod tests {
    use rustc_apfloat::Float as _;
    use rustc_apfloat::ieee::{DoubleS, HalfS, IeeeFloat, QuadS, SingleS};

    use super::sqrt;

    #[test]
    fn test_sqrt() {
        /// Asserts that `sqrt(x)` equals `expected` (both parsed in format `S`).
        #[track_caller]
        fn test<S: rustc_apfloat::ieee::Semantics>(x: &str, expected: &str) {
            let x: IeeeFloat<S> = x.parse().unwrap();
            let expected: IeeeFloat<S> = expected.parse().unwrap();
            let result = sqrt(x);
            assert_eq!(result, expected);
        }

        /// Perfect squares whose roots are exactly representable in every format.
        fn exact_tests<S: rustc_apfloat::ieee::Semantics>() {
            test::<S>("0", "0");
            test::<S>("1", "1");
            test::<S>("1.5625", "1.25");
            test::<S>("2.25", "1.5");
            test::<S>("4", "2");
            test::<S>("5.0625", "2.25");
            test::<S>("9", "3");
            test::<S>("16", "4");
            test::<S>("25", "5");
            test::<S>("36", "6");
            test::<S>("49", "7");
            test::<S>("64", "8");
            test::<S>("81", "9");
            test::<S>("100", "10");

            test::<S>("0.5625", "0.75");
            test::<S>("0.25", "0.5");
            test::<S>("0.0625", "0.25");
            test::<S>("0.00390625", "0.0625");
        }

        exact_tests::<HalfS>();
        exact_tests::<SingleS>();
        exact_tests::<DoubleS>();
        exact_tests::<QuadS>();

        // Inexact roots must be rounded to nearest.
        test::<HalfS>("2", "1.4140625");
        test::<SingleS>("2", "1.4142135");
        test::<DoubleS>("2", "1.4142135623730951");
        test::<QuadS>("2", "1.41421356237309504880168872420969807857");

        test::<SingleS>("1.1", "1.0488088");
        test::<DoubleS>("1.1", "1.0488088481701516");

        test::<SingleS>("2.2", "1.4832398");
        test::<DoubleS>("2.2", "1.4832396974191326");

        // Subnormal inputs.
        test::<SingleS>("1.22101e-40", "1.10499205e-20");
        test::<DoubleS>("1.22101e-310", "1.1049932126488395e-155");

        // Largest finite inputs.
        test::<SingleS>("3.4028235e38", "1.8446743e19");
        test::<DoubleS>("1.7976931348623157e308", "1.3407807929942596e154");
    }

    #[test]
    fn test_sqrt_special_values() {
        /// Checks the non-numeric branches of `sqrt` for format `S`. `assert_eq!` is not used
        /// because IEEE equality cannot distinguish ±0 and never matches NaN.
        fn check<S: rustc_apfloat::ieee::Semantics>() {
            let zero: IeeeFloat<S> = "0".parse().unwrap();

            // sqrt(+0) = +0 and sqrt(-0) = -0.
            let pos_zero = sqrt(zero);
            assert!(pos_zero.is_zero() && !pos_zero.is_negative());
            let neg_zero = sqrt(-zero);
            assert!(neg_zero.is_zero() && neg_zero.is_negative());

            // NaN propagates.
            assert!(sqrt(IeeeFloat::<S>::NAN).is_nan());

            // sqrt(+inf) = +inf
            let inf = sqrt(IeeeFloat::<S>::INFINITY);
            assert!(inf.is_infinite() && !inf.is_negative());

            // Negative non-zero inputs, including -inf, have no real root.
            assert!(sqrt(-IeeeFloat::<S>::INFINITY).is_nan());
            let minus_one: IeeeFloat<S> = "-1".parse().unwrap();
            assert!(sqrt(minus_one).is_nan());
            let minus_quarter: IeeeFloat<S> = "-0.25".parse().unwrap();
            assert!(sqrt(minus_quarter).is_nan());
        }

        check::<HalfS>();
        check::<SingleS>();
        check::<DoubleS>();
        check::<QuadS>();
    }
}