miri/shims/x86/
mod.rs

1use rustc_abi::{CanonAbi, FieldIdx, Size};
2use rustc_apfloat::Float;
3use rustc_apfloat::ieee::Single;
4use rustc_middle::ty::Ty;
5use rustc_middle::{mir, ty};
6use rustc_span::Symbol;
7use rustc_target::callconv::FnAbi;
8
9use self::helpers::bool_to_simd_element;
10use crate::*;
11
12mod aesni;
13mod avx;
14mod avx2;
15mod bmi;
16mod gfni;
17mod sha;
18mod sse;
19mod sse2;
20mod sse3;
21mod sse41;
22mod sse42;
23mod ssse3;
24
25impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
26pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
27    fn emulate_x86_intrinsic(
28        &mut self,
29        link_name: Symbol,
30        abi: &FnAbi<'tcx, Ty<'tcx>>,
31        args: &[OpTy<'tcx>],
32        dest: &MPlaceTy<'tcx>,
33    ) -> InterpResult<'tcx, EmulateItemResult> {
34        let this = self.eval_context_mut();
35        // Prefix should have already been checked.
36        let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.").unwrap();
37        match unprefixed_name {
38            // Used to implement the `_addcarry_u{32, 64}` and the `_subborrow_u{32, 64}` functions.
39            // Computes a + b or a - b with input and output carry/borrow. The input carry/borrow is an 8-bit
40            // value, which is interpreted as 1 if it is non-zero. The output carry/borrow is an 8-bit value that will be 0 or 1.
41            // https://www.intel.com/content/www/us/en/docs/cpp-compiler/developer-guide-reference/2021-8/addcarry-u32-addcarry-u64.html
42            // https://www.intel.com/content/www/us/en/docs/cpp-compiler/developer-guide-reference/2021-8/subborrow-u32-subborrow-u64.html
43            "addcarry.32" | "addcarry.64" | "subborrow.32" | "subborrow.64" => {
44                if unprefixed_name.ends_with("64") && this.tcx.sess.target.arch != "x86_64" {
45                    return interp_ok(EmulateItemResult::NotSupported);
46                }
47
48                let [cb_in, a, b] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
49                let op = if unprefixed_name.starts_with("add") {
50                    mir::BinOp::AddWithOverflow
51                } else {
52                    mir::BinOp::SubWithOverflow
53                };
54
55                let (sum, cb_out) = carrying_add(this, cb_in, a, b, op)?;
56                this.write_scalar(cb_out, &this.project_field(dest, FieldIdx::ZERO)?)?;
57                this.write_immediate(*sum, &this.project_field(dest, FieldIdx::ONE)?)?;
58            }
59
60            // Used to implement the `_addcarryx_u{32, 64}` functions. They are semantically identical with the `_addcarry_u{32, 64}` functions,
61            // except for a slightly different type signature and the requirement for the "adx" target feature.
62            // https://www.intel.com/content/www/us/en/docs/cpp-compiler/developer-guide-reference/2021-8/addcarryx-u32-addcarryx-u64.html
63            "addcarryx.u32" | "addcarryx.u64" => {
64                this.expect_target_feature_for_intrinsic(link_name, "adx")?;
65
66                let is_u64 = unprefixed_name.ends_with("64");
67                if is_u64 && this.tcx.sess.target.arch != "x86_64" {
68                    return interp_ok(EmulateItemResult::NotSupported);
69                }
70                let [c_in, a, b, out] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
71                let out = this.deref_pointer_as(
72                    out,
73                    if is_u64 { this.machine.layouts.u64 } else { this.machine.layouts.u32 },
74                )?;
75
76                let (sum, c_out) = carrying_add(this, c_in, a, b, mir::BinOp::AddWithOverflow)?;
77                this.write_scalar(c_out, dest)?;
78                this.write_immediate(*sum, &out)?;
79            }
80
81            // Used to implement the `_mm_pause` function.
82            // The intrinsic is used to hint the processor that the code is in a spin-loop.
83            // It is compiled down to a `pause` instruction. When SSE2 is not available,
84            // the instruction behaves like a no-op, so it is always safe to call the
85            // intrinsic.
86            "sse2.pause" => {
87                let [] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
88                // Only exhibit the spin-loop hint behavior when SSE2 is enabled.
89                if this.tcx.sess.unstable_target_features.contains(&Symbol::intern("sse2")) {
90                    this.yield_active_thread();
91                }
92            }
93
94            "pclmulqdq" | "pclmulqdq.256" | "pclmulqdq.512" => {
95                let mut len = 2; // in units of 64bits
96                this.expect_target_feature_for_intrinsic(link_name, "pclmulqdq")?;
97                if unprefixed_name.ends_with(".256") {
98                    this.expect_target_feature_for_intrinsic(link_name, "vpclmulqdq")?;
99                    len = 4;
100                } else if unprefixed_name.ends_with(".512") {
101                    this.expect_target_feature_for_intrinsic(link_name, "vpclmulqdq")?;
102                    this.expect_target_feature_for_intrinsic(link_name, "avx512f")?;
103                    len = 8;
104                }
105
106                let [left, right, imm] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
107
108                pclmulqdq(this, left, right, imm, dest, len)?;
109            }
110
111            name if name.starts_with("bmi.") => {
112                return bmi::EvalContextExt::emulate_x86_bmi_intrinsic(
113                    this, link_name, abi, args, dest,
114                );
115            }
116            // The GFNI extension does not get its own namespace.
117            // Check for instruction names instead.
118            name if name.starts_with("vgf2p8affine") || name.starts_with("vgf2p8mulb") => {
119                return gfni::EvalContextExt::emulate_x86_gfni_intrinsic(
120                    this, link_name, abi, args, dest,
121                );
122            }
123            name if name.starts_with("sha") => {
124                return sha::EvalContextExt::emulate_x86_sha_intrinsic(
125                    this, link_name, abi, args, dest,
126                );
127            }
128            name if name.starts_with("sse.") => {
129                return sse::EvalContextExt::emulate_x86_sse_intrinsic(
130                    this, link_name, abi, args, dest,
131                );
132            }
133            name if name.starts_with("sse2.") => {
134                return sse2::EvalContextExt::emulate_x86_sse2_intrinsic(
135                    this, link_name, abi, args, dest,
136                );
137            }
138            name if name.starts_with("sse3.") => {
139                return sse3::EvalContextExt::emulate_x86_sse3_intrinsic(
140                    this, link_name, abi, args, dest,
141                );
142            }
143            name if name.starts_with("ssse3.") => {
144                return ssse3::EvalContextExt::emulate_x86_ssse3_intrinsic(
145                    this, link_name, abi, args, dest,
146                );
147            }
148            name if name.starts_with("sse41.") => {
149                return sse41::EvalContextExt::emulate_x86_sse41_intrinsic(
150                    this, link_name, abi, args, dest,
151                );
152            }
153            name if name.starts_with("sse42.") => {
154                return sse42::EvalContextExt::emulate_x86_sse42_intrinsic(
155                    this, link_name, abi, args, dest,
156                );
157            }
158            name if name.starts_with("aesni.") => {
159                return aesni::EvalContextExt::emulate_x86_aesni_intrinsic(
160                    this, link_name, abi, args, dest,
161                );
162            }
163            name if name.starts_with("avx.") => {
164                return avx::EvalContextExt::emulate_x86_avx_intrinsic(
165                    this, link_name, abi, args, dest,
166                );
167            }
168            name if name.starts_with("avx2.") => {
169                return avx2::EvalContextExt::emulate_x86_avx2_intrinsic(
170                    this, link_name, abi, args, dest,
171                );
172            }
173
174            _ => return interp_ok(EmulateItemResult::NotSupported),
175        }
176        interp_ok(EmulateItemResult::NeedsReturn)
177    }
178}
179
/// Binary floating-point operations shared by the SSE/AVX shims.
#[derive(Copy, Clone)]
enum FloatBinOp {
    /// Comparison
    ///
    /// The semantics of this operator is a case distinction: we compare the two operands,
    /// and then we return one of the four booleans `gt`, `lt`, `eq`, `unord` depending on
    /// which class they fall into.
    ///
    /// AVX supports all 16 combinations, SSE only a subset
    ///
    /// <https://www.felixcloutier.com/x86/cmpss>
    /// <https://www.felixcloutier.com/x86/cmpps>
    /// <https://www.felixcloutier.com/x86/cmpsd>
    /// <https://www.felixcloutier.com/x86/cmppd>
    Cmp {
        /// Result when lhs > rhs
        gt: bool,
        /// Result when lhs < rhs
        lt: bool,
        /// Result when lhs == rhs
        eq: bool,
        /// Result when lhs is NaN or rhs is NaN
        unord: bool,
    },
    /// Minimum value (with SSE semantics)
    ///
    /// <https://www.felixcloutier.com/x86/minss>
    /// <https://www.felixcloutier.com/x86/minps>
    /// <https://www.felixcloutier.com/x86/minsd>
    /// <https://www.felixcloutier.com/x86/minpd>
    Min,
    /// Maximum value (with SSE semantics)
    ///
    /// <https://www.felixcloutier.com/x86/maxss>
    /// <https://www.felixcloutier.com/x86/maxps>
    /// <https://www.felixcloutier.com/x86/maxsd>
    /// <https://www.felixcloutier.com/x86/maxpd>
    Max,
}
219
220impl FloatBinOp {
221    /// Convert from the `imm` argument used to specify the comparison
222    /// operation in intrinsics such as `llvm.x86.sse.cmp.ss`.
223    fn cmp_from_imm<'tcx>(
224        ecx: &crate::MiriInterpCx<'tcx>,
225        imm: i8,
226        intrinsic: Symbol,
227    ) -> InterpResult<'tcx, Self> {
228        // Only bits 0..=4 are used, remaining should be zero.
229        if imm & !0b1_1111 != 0 {
230            panic!("invalid `imm` parameter of {intrinsic}: 0x{imm:x}");
231        }
232        // Bit 4 specifies whether the operation is quiet or signaling, which
233        // we do not care in Miri.
234        // Bits 0..=2 specifies the operation.
235        // `gt` indicates the result to be returned when the LHS is strictly
236        // greater than the RHS, and so on.
237        let (gt, lt, eq, mut unord) = match imm & 0b111 {
238            // Equal
239            0x0 => (false, false, true, false),
240            // Less-than
241            0x1 => (false, true, false, false),
242            // Less-or-equal
243            0x2 => (false, true, true, false),
244            // Unordered (either is NaN)
245            0x3 => (false, false, false, true),
246            // Not equal
247            0x4 => (true, true, false, true),
248            // Not less-than
249            0x5 => (true, false, true, true),
250            // Not less-or-equal
251            0x6 => (true, false, false, true),
252            // Ordered (neither is NaN)
253            0x7 => (true, true, true, false),
254            _ => unreachable!(),
255        };
256        // When bit 3 is 1 (only possible in AVX), unord is toggled.
257        if imm & 0b1000 != 0 {
258            ecx.expect_target_feature_for_intrinsic(intrinsic, "avx")?;
259            unord = !unord;
260        }
261        interp_ok(Self::Cmp { gt, lt, eq, unord })
262    }
263}
264
265/// Performs `which` scalar operation on `left` and `right` and returns
266/// the result.
267fn bin_op_float<'tcx, F: rustc_apfloat::Float>(
268    which: FloatBinOp,
269    left: &ImmTy<'tcx>,
270    right: &ImmTy<'tcx>,
271) -> InterpResult<'tcx, Scalar> {
272    match which {
273        FloatBinOp::Cmp { gt, lt, eq, unord } => {
274            let left = left.to_scalar().to_float::<F>()?;
275            let right = right.to_scalar().to_float::<F>()?;
276
277            let res = match left.partial_cmp(&right) {
278                None => unord,
279                Some(std::cmp::Ordering::Less) => lt,
280                Some(std::cmp::Ordering::Equal) => eq,
281                Some(std::cmp::Ordering::Greater) => gt,
282            };
283            interp_ok(bool_to_simd_element(res, Size::from_bits(F::BITS)))
284        }
285        FloatBinOp::Min => {
286            let left_scalar = left.to_scalar();
287            let left = left_scalar.to_float::<F>()?;
288            let right_scalar = right.to_scalar();
289            let right = right_scalar.to_float::<F>()?;
290            // SSE semantics to handle zero and NaN. Note that `x == F::ZERO`
291            // is true when `x` is either +0 or -0.
292            if (left == F::ZERO && right == F::ZERO)
293                || left.is_nan()
294                || right.is_nan()
295                || left >= right
296            {
297                interp_ok(right_scalar)
298            } else {
299                interp_ok(left_scalar)
300            }
301        }
302        FloatBinOp::Max => {
303            let left_scalar = left.to_scalar();
304            let left = left_scalar.to_float::<F>()?;
305            let right_scalar = right.to_scalar();
306            let right = right_scalar.to_float::<F>()?;
307            // SSE semantics to handle zero and NaN. Note that `x == F::ZERO`
308            // is true when `x` is either +0 or -0.
309            if (left == F::ZERO && right == F::ZERO)
310                || left.is_nan()
311                || right.is_nan()
312                || left <= right
313            {
314                interp_ok(right_scalar)
315            } else {
316                interp_ok(left_scalar)
317            }
318        }
319    }
320}
321
322/// Performs `which` operation on the first component of `left` and `right`
323/// and copies the other components from `left`. The result is stored in `dest`.
324fn bin_op_simd_float_first<'tcx, F: rustc_apfloat::Float>(
325    ecx: &mut crate::MiriInterpCx<'tcx>,
326    which: FloatBinOp,
327    left: &OpTy<'tcx>,
328    right: &OpTy<'tcx>,
329    dest: &MPlaceTy<'tcx>,
330) -> InterpResult<'tcx, ()> {
331    let (left, left_len) = ecx.project_to_simd(left)?;
332    let (right, right_len) = ecx.project_to_simd(right)?;
333    let (dest, dest_len) = ecx.project_to_simd(dest)?;
334
335    assert_eq!(dest_len, left_len);
336    assert_eq!(dest_len, right_len);
337
338    let res0 = bin_op_float::<F>(
339        which,
340        &ecx.read_immediate(&ecx.project_index(&left, 0)?)?,
341        &ecx.read_immediate(&ecx.project_index(&right, 0)?)?,
342    )?;
343    ecx.write_scalar(res0, &ecx.project_index(&dest, 0)?)?;
344
345    for i in 1..dest_len {
346        ecx.copy_op(&ecx.project_index(&left, i)?, &ecx.project_index(&dest, i)?)?;
347    }
348
349    interp_ok(())
350}
351
352/// Performs `which` operation on each component of `left` and
353/// `right`, storing the result is stored in `dest`.
354fn bin_op_simd_float_all<'tcx, F: rustc_apfloat::Float>(
355    ecx: &mut crate::MiriInterpCx<'tcx>,
356    which: FloatBinOp,
357    left: &OpTy<'tcx>,
358    right: &OpTy<'tcx>,
359    dest: &MPlaceTy<'tcx>,
360) -> InterpResult<'tcx, ()> {
361    let (left, left_len) = ecx.project_to_simd(left)?;
362    let (right, right_len) = ecx.project_to_simd(right)?;
363    let (dest, dest_len) = ecx.project_to_simd(dest)?;
364
365    assert_eq!(dest_len, left_len);
366    assert_eq!(dest_len, right_len);
367
368    for i in 0..dest_len {
369        let left = ecx.read_immediate(&ecx.project_index(&left, i)?)?;
370        let right = ecx.read_immediate(&ecx.project_index(&right, i)?)?;
371        let dest = ecx.project_index(&dest, i)?;
372
373        let res = bin_op_float::<F>(which, &left, &right)?;
374        ecx.write_scalar(res, &dest)?;
375    }
376
377    interp_ok(())
378}
379
/// Approximate unary `f32` operations.
#[derive(Clone, Copy)]
enum FloatUnaryOp {
    /// Approximate reciprocal, `1/x`.
    ///
    /// <https://www.felixcloutier.com/x86/rcpss>
    /// <https://www.felixcloutier.com/x86/rcpps>
    Rcp,
    /// Approximate reciprocal square root, `1/sqrt(x)`.
    ///
    /// <https://www.felixcloutier.com/x86/rsqrtss>
    /// <https://www.felixcloutier.com/x86/rsqrtps>
    Rsqrt,
}
393
394/// Performs `which` scalar operation on `op` and returns the result.
395fn unary_op_f32<'tcx>(
396    ecx: &mut crate::MiriInterpCx<'tcx>,
397    which: FloatUnaryOp,
398    op: &ImmTy<'tcx>,
399) -> InterpResult<'tcx, Scalar> {
400    match which {
401        FloatUnaryOp::Rcp => {
402            let op = op.to_scalar().to_f32()?;
403            let div = (Single::from_u128(1).value / op).value;
404            // Apply a relative error with a magnitude on the order of 2^-12 to simulate the
405            // inaccuracy of RCP.
406            let res = math::apply_random_float_error(ecx, div, -12);
407            interp_ok(Scalar::from_f32(res))
408        }
409        FloatUnaryOp::Rsqrt => {
410            let op = op.to_scalar().to_f32()?;
411            let rsqrt = (Single::from_u128(1).value / math::sqrt(op)).value;
412            // Apply a relative error with a magnitude on the order of 2^-12 to simulate the
413            // inaccuracy of RSQRT.
414            let res = math::apply_random_float_error(ecx, rsqrt, -12);
415            interp_ok(Scalar::from_f32(res))
416        }
417    }
418}
419
420/// Performs `which` operation on the first component of `op` and copies
421/// the other components. The result is stored in `dest`.
422fn unary_op_ss<'tcx>(
423    ecx: &mut crate::MiriInterpCx<'tcx>,
424    which: FloatUnaryOp,
425    op: &OpTy<'tcx>,
426    dest: &MPlaceTy<'tcx>,
427) -> InterpResult<'tcx, ()> {
428    let (op, op_len) = ecx.project_to_simd(op)?;
429    let (dest, dest_len) = ecx.project_to_simd(dest)?;
430
431    assert_eq!(dest_len, op_len);
432
433    let res0 = unary_op_f32(ecx, which, &ecx.read_immediate(&ecx.project_index(&op, 0)?)?)?;
434    ecx.write_scalar(res0, &ecx.project_index(&dest, 0)?)?;
435
436    for i in 1..dest_len {
437        ecx.copy_op(&ecx.project_index(&op, i)?, &ecx.project_index(&dest, i)?)?;
438    }
439
440    interp_ok(())
441}
442
443/// Performs `which` operation on each component of `op`, storing the
444/// result is stored in `dest`.
445fn unary_op_ps<'tcx>(
446    ecx: &mut crate::MiriInterpCx<'tcx>,
447    which: FloatUnaryOp,
448    op: &OpTy<'tcx>,
449    dest: &MPlaceTy<'tcx>,
450) -> InterpResult<'tcx, ()> {
451    let (op, op_len) = ecx.project_to_simd(op)?;
452    let (dest, dest_len) = ecx.project_to_simd(dest)?;
453
454    assert_eq!(dest_len, op_len);
455
456    for i in 0..dest_len {
457        let op = ecx.read_immediate(&ecx.project_index(&op, i)?)?;
458        let dest = ecx.project_index(&dest, i)?;
459
460        let res = unary_op_f32(ecx, which, &op)?;
461        ecx.write_scalar(res, &dest)?;
462    }
463
464    interp_ok(())
465}
466
/// The kinds of bit shift implemented by the SIMD shift helpers.
enum ShiftOp {
    /// Left shift; zeros are shifted in (logical and arithmetic left shifts coincide).
    Left,
    /// Logical right shift; zeros are shifted in.
    RightLogic,
    /// Arithmetic right shift; copies of the sign bit are shifted in.
    RightArith,
}
475
476/// Shifts each element of `left` by a scalar amount. The shift amount
477/// is determined by the lowest 64 bits of `right` (which is a 128-bit vector).
478///
479/// For logic shifts, when right is larger than BITS - 1, zero is produced.
480/// For arithmetic right-shifts, when right is larger than BITS - 1, the sign
481/// bit is copied to all bits.
482fn shift_simd_by_scalar<'tcx>(
483    ecx: &mut crate::MiriInterpCx<'tcx>,
484    left: &OpTy<'tcx>,
485    right: &OpTy<'tcx>,
486    which: ShiftOp,
487    dest: &MPlaceTy<'tcx>,
488) -> InterpResult<'tcx, ()> {
489    let (left, left_len) = ecx.project_to_simd(left)?;
490    let (dest, dest_len) = ecx.project_to_simd(dest)?;
491
492    assert_eq!(dest_len, left_len);
493    // `right` may have a different length, and we only care about its
494    // lowest 64bit anyway.
495
496    // Get the 64-bit shift operand and convert it to the type expected
497    // by checked_{shl,shr} (u32).
498    // It is ok to saturate the value to u32::MAX because any value
499    // above BITS - 1 will produce the same result.
500    let shift = u32::try_from(extract_first_u64(ecx, right)?).unwrap_or(u32::MAX);
501
502    for i in 0..dest_len {
503        let left = ecx.read_scalar(&ecx.project_index(&left, i)?)?;
504        let dest = ecx.project_index(&dest, i)?;
505
506        let res = match which {
507            ShiftOp::Left => {
508                let left = left.to_uint(dest.layout.size)?;
509                let res = left.checked_shl(shift).unwrap_or(0);
510                // `truncate` is needed as left-shift can make the absolute value larger.
511                Scalar::from_uint(dest.layout.size.truncate(res), dest.layout.size)
512            }
513            ShiftOp::RightLogic => {
514                let left = left.to_uint(dest.layout.size)?;
515                let res = left.checked_shr(shift).unwrap_or(0);
516                // No `truncate` needed as right-shift can only make the absolute value smaller.
517                Scalar::from_uint(res, dest.layout.size)
518            }
519            ShiftOp::RightArith => {
520                let left = left.to_int(dest.layout.size)?;
521                // On overflow, copy the sign bit to the remaining bits
522                let res = left.checked_shr(shift).unwrap_or(left >> 127);
523                // No `truncate` needed as right-shift can only make the absolute value smaller.
524                Scalar::from_int(res, dest.layout.size)
525            }
526        };
527        ecx.write_scalar(res, &dest)?;
528    }
529
530    interp_ok(())
531}
532
533/// Shifts each element of `left` by the corresponding element of `right`.
534///
535/// For logic shifts, when right is larger than BITS - 1, zero is produced.
536/// For arithmetic right-shifts, when right is larger than BITS - 1, the sign
537/// bit is copied to all bits.
538fn shift_simd_by_simd<'tcx>(
539    ecx: &mut crate::MiriInterpCx<'tcx>,
540    left: &OpTy<'tcx>,
541    right: &OpTy<'tcx>,
542    which: ShiftOp,
543    dest: &MPlaceTy<'tcx>,
544) -> InterpResult<'tcx, ()> {
545    let (left, left_len) = ecx.project_to_simd(left)?;
546    let (right, right_len) = ecx.project_to_simd(right)?;
547    let (dest, dest_len) = ecx.project_to_simd(dest)?;
548
549    assert_eq!(dest_len, left_len);
550    assert_eq!(dest_len, right_len);
551
552    for i in 0..dest_len {
553        let left = ecx.read_scalar(&ecx.project_index(&left, i)?)?;
554        let right = ecx.read_scalar(&ecx.project_index(&right, i)?)?;
555        let dest = ecx.project_index(&dest, i)?;
556
557        // It is ok to saturate the value to u32::MAX because any value
558        // above BITS - 1 will produce the same result.
559        let shift = u32::try_from(right.to_uint(dest.layout.size)?).unwrap_or(u32::MAX);
560
561        let res = match which {
562            ShiftOp::Left => {
563                let left = left.to_uint(dest.layout.size)?;
564                let res = left.checked_shl(shift).unwrap_or(0);
565                // `truncate` is needed as left-shift can make the absolute value larger.
566                Scalar::from_uint(dest.layout.size.truncate(res), dest.layout.size)
567            }
568            ShiftOp::RightLogic => {
569                let left = left.to_uint(dest.layout.size)?;
570                let res = left.checked_shr(shift).unwrap_or(0);
571                // No `truncate` needed as right-shift can only make the absolute value smaller.
572                Scalar::from_uint(res, dest.layout.size)
573            }
574            ShiftOp::RightArith => {
575                let left = left.to_int(dest.layout.size)?;
576                // On overflow, copy the sign bit to the remaining bits
577                let res = left.checked_shr(shift).unwrap_or(left >> 127);
578                // No `truncate` needed as right-shift can only make the absolute value smaller.
579                Scalar::from_int(res, dest.layout.size)
580            }
581        };
582        ecx.write_scalar(res, &dest)?;
583    }
584
585    interp_ok(())
586}
587
588/// Takes a 128-bit vector, transmutes it to `[u64; 2]` and extracts
589/// the first value.
590fn extract_first_u64<'tcx>(
591    ecx: &crate::MiriInterpCx<'tcx>,
592    op: &OpTy<'tcx>,
593) -> InterpResult<'tcx, u64> {
594    // Transmute vector to `[u64; 2]`
595    let array_layout = ecx.layout_of(Ty::new_array(ecx.tcx.tcx, ecx.tcx.types.u64, 2))?;
596    let op = op.transmute(array_layout, ecx)?;
597
598    // Get the first u64 from the array
599    ecx.read_scalar(&ecx.project_index(&op, 0)?)?.to_u64()
600}
601
602// Rounds the first element of `right` according to `rounding`
603// and copies the remaining elements from `left`.
604fn round_first<'tcx, F: rustc_apfloat::Float>(
605    ecx: &mut crate::MiriInterpCx<'tcx>,
606    left: &OpTy<'tcx>,
607    right: &OpTy<'tcx>,
608    rounding: &OpTy<'tcx>,
609    dest: &MPlaceTy<'tcx>,
610) -> InterpResult<'tcx, ()> {
611    let (left, left_len) = ecx.project_to_simd(left)?;
612    let (right, right_len) = ecx.project_to_simd(right)?;
613    let (dest, dest_len) = ecx.project_to_simd(dest)?;
614
615    assert_eq!(dest_len, left_len);
616    assert_eq!(dest_len, right_len);
617
618    let rounding = rounding_from_imm(ecx.read_scalar(rounding)?.to_i32()?)?;
619
620    let op0: F = ecx.read_scalar(&ecx.project_index(&right, 0)?)?.to_float()?;
621    let res = op0.round_to_integral(rounding).value;
622    ecx.write_scalar(
623        Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)),
624        &ecx.project_index(&dest, 0)?,
625    )?;
626
627    for i in 1..dest_len {
628        ecx.copy_op(&ecx.project_index(&left, i)?, &ecx.project_index(&dest, i)?)?;
629    }
630
631    interp_ok(())
632}
633
634// Rounds all elements of `op` according to `rounding`.
635fn round_all<'tcx, F: rustc_apfloat::Float>(
636    ecx: &mut crate::MiriInterpCx<'tcx>,
637    op: &OpTy<'tcx>,
638    rounding: &OpTy<'tcx>,
639    dest: &MPlaceTy<'tcx>,
640) -> InterpResult<'tcx, ()> {
641    let (op, op_len) = ecx.project_to_simd(op)?;
642    let (dest, dest_len) = ecx.project_to_simd(dest)?;
643
644    assert_eq!(dest_len, op_len);
645
646    let rounding = rounding_from_imm(ecx.read_scalar(rounding)?.to_i32()?)?;
647
648    for i in 0..dest_len {
649        let op: F = ecx.read_scalar(&ecx.project_index(&op, i)?)?.to_float()?;
650        let res = op.round_to_integral(rounding).value;
651        ecx.write_scalar(
652            Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)),
653            &ecx.project_index(&dest, i)?,
654        )?;
655    }
656
657    interp_ok(())
658}
659
660/// Gets equivalent `rustc_apfloat::Round` from rounding mode immediate of
661/// `round.{ss,sd,ps,pd}` intrinsics.
662fn rounding_from_imm<'tcx>(rounding: i32) -> InterpResult<'tcx, rustc_apfloat::Round> {
663    // The fourth bit of `rounding` only affects the SSE status
664    // register, which cannot be accessed from Miri (or from Rust,
665    // for that matter), so we can ignore it.
666    match rounding & !0b1000 {
667        // When the third bit is 0, the rounding mode is determined by the
668        // first two bits.
669        0b000 => interp_ok(rustc_apfloat::Round::NearestTiesToEven),
670        0b001 => interp_ok(rustc_apfloat::Round::TowardNegative),
671        0b010 => interp_ok(rustc_apfloat::Round::TowardPositive),
672        0b011 => interp_ok(rustc_apfloat::Round::TowardZero),
673        // When the third bit is 1, the rounding mode is determined by the
674        // SSE status register. Since we do not support modifying it from
675        // Miri (or Rust), we assume it to be at its default mode (round-to-nearest).
676        0b100..=0b111 => interp_ok(rustc_apfloat::Round::NearestTiesToEven),
677        rounding => panic!("invalid rounding mode 0x{rounding:02x}"),
678    }
679}
680
681/// Converts each element of `op` from floating point to signed integer.
682///
683/// When the input value is NaN or out of range, fall back to minimum value.
684///
685/// If `op` has more elements than `dest`, extra elements are ignored. If `op`
686/// has less elements than `dest`, the rest is filled with zeros.
687fn convert_float_to_int<'tcx>(
688    ecx: &mut crate::MiriInterpCx<'tcx>,
689    op: &OpTy<'tcx>,
690    rnd: rustc_apfloat::Round,
691    dest: &MPlaceTy<'tcx>,
692) -> InterpResult<'tcx, ()> {
693    let (op, op_len) = ecx.project_to_simd(op)?;
694    let (dest, dest_len) = ecx.project_to_simd(dest)?;
695
696    // Output must be *signed* integers.
697    assert!(matches!(dest.layout.field(ecx, 0).ty.kind(), ty::Int(_)));
698
699    for i in 0..op_len.min(dest_len) {
700        let op = ecx.read_immediate(&ecx.project_index(&op, i)?)?;
701        let dest = ecx.project_index(&dest, i)?;
702
703        let res = ecx.float_to_int_checked(&op, dest.layout, rnd)?.unwrap_or_else(|| {
704            // Fallback to minimum according to SSE/AVX semantics.
705            ImmTy::from_int(dest.layout.size.signed_int_min(), dest.layout)
706        });
707        ecx.write_immediate(*res, &dest)?;
708    }
709    // Fill remainder with zeros
710    for i in op_len..dest_len {
711        let dest = ecx.project_index(&dest, i)?;
712        ecx.write_scalar(Scalar::from_int(0, dest.layout.size), &dest)?;
713    }
714
715    interp_ok(())
716}
717
718/// Calculates absolute value of integers in `op` and stores the result in `dest`.
719///
720/// In case of overflow (when the operand is the minimum value), the operation
721/// will wrap around.
722fn int_abs<'tcx>(
723    ecx: &mut crate::MiriInterpCx<'tcx>,
724    op: &OpTy<'tcx>,
725    dest: &MPlaceTy<'tcx>,
726) -> InterpResult<'tcx, ()> {
727    let (op, op_len) = ecx.project_to_simd(op)?;
728    let (dest, dest_len) = ecx.project_to_simd(dest)?;
729
730    assert_eq!(op_len, dest_len);
731
732    let zero = ImmTy::from_int(0, op.layout.field(ecx, 0));
733
734    for i in 0..dest_len {
735        let op = ecx.read_immediate(&ecx.project_index(&op, i)?)?;
736        let dest = ecx.project_index(&dest, i)?;
737
738        let lt_zero = ecx.binary_op(mir::BinOp::Lt, &op, &zero)?;
739        let res =
740            if lt_zero.to_scalar().to_bool()? { ecx.unary_op(mir::UnOp::Neg, &op)? } else { op };
741
742        ecx.write_immediate(*res, &dest)?;
743    }
744
745    interp_ok(())
746}
747
748/// Splits `op` (which must be a SIMD vector) into 128-bit chunks.
749///
750/// Returns a tuple where:
751/// * The first element is the number of 128-bit chunks (let's call it `N`).
752/// * The second element is the number of elements per chunk (let's call it `M`).
753/// * The third element is the `op` vector split into chunks, i.e, it's
754///   type is `[[T; M]; N]` where `T` is the element type of `op`.
755fn split_simd_to_128bit_chunks<'tcx, P: Projectable<'tcx, Provenance>>(
756    ecx: &mut crate::MiriInterpCx<'tcx>,
757    op: &P,
758) -> InterpResult<'tcx, (u64, u64, P)> {
759    let simd_layout = op.layout();
760    let (simd_len, element_ty) = simd_layout.ty.simd_size_and_type(ecx.tcx.tcx);
761
762    assert_eq!(simd_layout.size.bits() % 128, 0);
763    let num_chunks = simd_layout.size.bits() / 128;
764    let items_per_chunk = simd_len.strict_div(num_chunks);
765
766    // Transmute to `[[T; items_per_chunk]; num_chunks]`
767    let chunked_layout = ecx
768        .layout_of(Ty::new_array(
769            ecx.tcx.tcx,
770            Ty::new_array(ecx.tcx.tcx, element_ty, items_per_chunk),
771            num_chunks,
772        ))
773        .unwrap();
774    let chunked_op = op.transmute(chunked_layout, ecx)?;
775
776    interp_ok((num_chunks, items_per_chunk, chunked_op))
777}
778
779/// Horizontally performs `which` operation on adjacent values of
780/// `left` and `right` SIMD vectors and stores the result in `dest`.
781/// "Horizontal" means that the i-th output element is calculated
782/// from the elements 2*i and 2*i+1 of the concatenation of `left` and
783/// `right`.
784///
785/// Each 128-bit chunk is treated independently (i.e., the value for
786/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
787/// 128-bit chunks of `left` and `right`).
788fn horizontal_bin_op<'tcx>(
789    ecx: &mut crate::MiriInterpCx<'tcx>,
790    which: mir::BinOp,
791    saturating: bool,
792    left: &OpTy<'tcx>,
793    right: &OpTy<'tcx>,
794    dest: &MPlaceTy<'tcx>,
795) -> InterpResult<'tcx, ()> {
796    assert_eq!(left.layout, dest.layout);
797    assert_eq!(right.layout, dest.layout);
798
799    let (num_chunks, items_per_chunk, left) = split_simd_to_128bit_chunks(ecx, left)?;
800    let (_, _, right) = split_simd_to_128bit_chunks(ecx, right)?;
801    let (_, _, dest) = split_simd_to_128bit_chunks(ecx, dest)?;
802
803    let middle = items_per_chunk / 2;
804    for i in 0..num_chunks {
805        let left = ecx.project_index(&left, i)?;
806        let right = ecx.project_index(&right, i)?;
807        let dest = ecx.project_index(&dest, i)?;
808
809        for j in 0..items_per_chunk {
810            // `j` is the index in `dest`
811            // `k` is the index of the 2-item chunk in `src`
812            let (k, src) = if j < middle { (j, &left) } else { (j.strict_sub(middle), &right) };
813            // `base_i` is the index of the first item of the 2-item chunk in `src`
814            let base_i = k.strict_mul(2);
815            let lhs = ecx.read_immediate(&ecx.project_index(src, base_i)?)?;
816            let rhs = ecx.read_immediate(&ecx.project_index(src, base_i.strict_add(1))?)?;
817
818            let res = if saturating {
819                Immediate::from(ecx.saturating_arith(which, &lhs, &rhs)?)
820            } else {
821                *ecx.binary_op(which, &lhs, &rhs)?
822            };
823
824            ecx.write_immediate(res, &ecx.project_index(&dest, j)?)?;
825        }
826    }
827
828    interp_ok(())
829}
830
831/// Conditionally multiplies the packed floating-point elements in
832/// `left` and `right` using the high 4 bits in `imm`, sums the calculated
833/// products (up to 4), and conditionally stores the sum in `dest` using
834/// the low 4 bits of `imm`.
835///
836/// Each 128-bit chunk is treated independently (i.e., the value for
837/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
838/// 128-bit blocks of `left` and `right`).
839fn conditional_dot_product<'tcx>(
840    ecx: &mut crate::MiriInterpCx<'tcx>,
841    left: &OpTy<'tcx>,
842    right: &OpTy<'tcx>,
843    imm: &OpTy<'tcx>,
844    dest: &MPlaceTy<'tcx>,
845) -> InterpResult<'tcx, ()> {
846    assert_eq!(left.layout, dest.layout);
847    assert_eq!(right.layout, dest.layout);
848
849    let (num_chunks, items_per_chunk, left) = split_simd_to_128bit_chunks(ecx, left)?;
850    let (_, _, right) = split_simd_to_128bit_chunks(ecx, right)?;
851    let (_, _, dest) = split_simd_to_128bit_chunks(ecx, dest)?;
852
853    let element_layout = left.layout.field(ecx, 0).field(ecx, 0);
854    assert!(items_per_chunk <= 4);
855
856    // `imm` is a `u8` for SSE4.1 or an `i32` for AVX :/
857    let imm = ecx.read_scalar(imm)?.to_uint(imm.layout.size)?;
858
859    for i in 0..num_chunks {
860        let left = ecx.project_index(&left, i)?;
861        let right = ecx.project_index(&right, i)?;
862        let dest = ecx.project_index(&dest, i)?;
863
864        // Calculate dot product
865        // Elements are floating point numbers, but we can use `from_int`
866        // for the initial value because the representation of 0.0 is all zero bits.
867        let mut sum = ImmTy::from_int(0u8, element_layout);
868        for j in 0..items_per_chunk {
869            if imm & (1 << j.strict_add(4)) != 0 {
870                let left = ecx.read_immediate(&ecx.project_index(&left, j)?)?;
871                let right = ecx.read_immediate(&ecx.project_index(&right, j)?)?;
872
873                let mul = ecx.binary_op(mir::BinOp::Mul, &left, &right)?;
874                sum = ecx.binary_op(mir::BinOp::Add, &sum, &mul)?;
875            }
876        }
877
878        // Write to destination (conditioned to imm)
879        for j in 0..items_per_chunk {
880            let dest = ecx.project_index(&dest, j)?;
881
882            if imm & (1 << j) != 0 {
883                ecx.write_immediate(*sum, &dest)?;
884            } else {
885                ecx.write_scalar(Scalar::from_int(0u8, element_layout.size), &dest)?;
886            }
887        }
888    }
889
890    interp_ok(())
891}
892
893/// Calculates two booleans.
894///
895/// The first is true when all the bits of `op & mask` are zero.
896/// The second is true when `(op & mask) == mask`
897fn test_bits_masked<'tcx>(
898    ecx: &crate::MiriInterpCx<'tcx>,
899    op: &OpTy<'tcx>,
900    mask: &OpTy<'tcx>,
901) -> InterpResult<'tcx, (bool, bool)> {
902    assert_eq!(op.layout, mask.layout);
903
904    let (op, op_len) = ecx.project_to_simd(op)?;
905    let (mask, mask_len) = ecx.project_to_simd(mask)?;
906
907    assert_eq!(op_len, mask_len);
908
909    let mut all_zero = true;
910    let mut masked_set = true;
911    for i in 0..op_len {
912        let op = ecx.project_index(&op, i)?;
913        let mask = ecx.project_index(&mask, i)?;
914
915        let op = ecx.read_scalar(&op)?.to_uint(op.layout.size)?;
916        let mask = ecx.read_scalar(&mask)?.to_uint(mask.layout.size)?;
917        all_zero &= (op & mask) == 0;
918        masked_set &= (op & mask) == mask;
919    }
920
921    interp_ok((all_zero, masked_set))
922}
923
924/// Calculates two booleans.
925///
926/// The first is true when the highest bit of each element of `op & mask` is zero.
927/// The second is true when the highest bit of each element of `!op & mask` is zero.
928fn test_high_bits_masked<'tcx>(
929    ecx: &crate::MiriInterpCx<'tcx>,
930    op: &OpTy<'tcx>,
931    mask: &OpTy<'tcx>,
932) -> InterpResult<'tcx, (bool, bool)> {
933    assert_eq!(op.layout, mask.layout);
934
935    let (op, op_len) = ecx.project_to_simd(op)?;
936    let (mask, mask_len) = ecx.project_to_simd(mask)?;
937
938    assert_eq!(op_len, mask_len);
939
940    let high_bit_offset = op.layout.field(ecx, 0).size.bits().strict_sub(1);
941
942    let mut direct = true;
943    let mut negated = true;
944    for i in 0..op_len {
945        let op = ecx.project_index(&op, i)?;
946        let mask = ecx.project_index(&mask, i)?;
947
948        let op = ecx.read_scalar(&op)?.to_uint(op.layout.size)?;
949        let mask = ecx.read_scalar(&mask)?.to_uint(mask.layout.size)?;
950        direct &= (op & mask) >> high_bit_offset == 0;
951        negated &= (!op & mask) >> high_bit_offset == 0;
952    }
953
954    interp_ok((direct, negated))
955}
956
957/// Conditionally loads from `ptr` according the high bit of each
958/// element of `mask`. `ptr` does not need to be aligned.
959fn mask_load<'tcx>(
960    ecx: &mut crate::MiriInterpCx<'tcx>,
961    ptr: &OpTy<'tcx>,
962    mask: &OpTy<'tcx>,
963    dest: &MPlaceTy<'tcx>,
964) -> InterpResult<'tcx, ()> {
965    let (mask, mask_len) = ecx.project_to_simd(mask)?;
966    let (dest, dest_len) = ecx.project_to_simd(dest)?;
967
968    assert_eq!(dest_len, mask_len);
969
970    let mask_item_size = mask.layout.field(ecx, 0).size;
971    let high_bit_offset = mask_item_size.bits().strict_sub(1);
972
973    let ptr = ecx.read_pointer(ptr)?;
974    for i in 0..dest_len {
975        let mask = ecx.project_index(&mask, i)?;
976        let dest = ecx.project_index(&dest, i)?;
977
978        if ecx.read_scalar(&mask)?.to_uint(mask_item_size)? >> high_bit_offset != 0 {
979            let ptr = ptr.wrapping_offset(dest.layout.size * i, &ecx.tcx);
980            // Unaligned copy, which is what we want.
981            ecx.mem_copy(ptr, dest.ptr(), dest.layout.size, /*nonoverlapping*/ true)?;
982        } else {
983            ecx.write_scalar(Scalar::from_int(0, dest.layout.size), &dest)?;
984        }
985    }
986
987    interp_ok(())
988}
989
990/// Conditionally stores into `ptr` according the high bit of each
991/// element of `mask`. `ptr` does not need to be aligned.
992fn mask_store<'tcx>(
993    ecx: &mut crate::MiriInterpCx<'tcx>,
994    ptr: &OpTy<'tcx>,
995    mask: &OpTy<'tcx>,
996    value: &OpTy<'tcx>,
997) -> InterpResult<'tcx, ()> {
998    let (mask, mask_len) = ecx.project_to_simd(mask)?;
999    let (value, value_len) = ecx.project_to_simd(value)?;
1000
1001    assert_eq!(value_len, mask_len);
1002
1003    let mask_item_size = mask.layout.field(ecx, 0).size;
1004    let high_bit_offset = mask_item_size.bits().strict_sub(1);
1005
1006    let ptr = ecx.read_pointer(ptr)?;
1007    for i in 0..value_len {
1008        let mask = ecx.project_index(&mask, i)?;
1009        let value = ecx.project_index(&value, i)?;
1010
1011        if ecx.read_scalar(&mask)?.to_uint(mask_item_size)? >> high_bit_offset != 0 {
1012            // *Non-inbounds* pointer arithmetic to compute the destination.
1013            // (That's why we can't use a place projection.)
1014            let ptr = ptr.wrapping_offset(value.layout.size * i, &ecx.tcx);
1015            // Deref the pointer *unaligned*, and do the copy.
1016            let dest = ecx.ptr_to_mplace_unaligned(ptr, value.layout);
1017            ecx.copy_op(&value, &dest)?;
1018        }
1019    }
1020
1021    interp_ok(())
1022}
1023
1024/// Compute the sum of absolute differences of quadruplets of unsigned
1025/// 8-bit integers in `left` and `right`, and store the 16-bit results
1026/// in `right`. Quadruplets are selected from `left` and `right` with
1027/// offsets specified in `imm`.
1028///
1029/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_maddubs_epi16>
1030/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_mpsadbw_epu8>
1031///
1032/// Each 128-bit chunk is treated independently (i.e., the value for
1033/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
1034/// 128-bit chunks of `left` and `right`).
1035fn mpsadbw<'tcx>(
1036    ecx: &mut crate::MiriInterpCx<'tcx>,
1037    left: &OpTy<'tcx>,
1038    right: &OpTy<'tcx>,
1039    imm: &OpTy<'tcx>,
1040    dest: &MPlaceTy<'tcx>,
1041) -> InterpResult<'tcx, ()> {
1042    assert_eq!(left.layout, right.layout);
1043    assert_eq!(left.layout.size, dest.layout.size);
1044
1045    let (num_chunks, op_items_per_chunk, left) = split_simd_to_128bit_chunks(ecx, left)?;
1046    let (_, _, right) = split_simd_to_128bit_chunks(ecx, right)?;
1047    let (_, dest_items_per_chunk, dest) = split_simd_to_128bit_chunks(ecx, dest)?;
1048
1049    assert_eq!(op_items_per_chunk, dest_items_per_chunk.strict_mul(2));
1050
1051    let imm = ecx.read_scalar(imm)?.to_uint(imm.layout.size)?;
1052    // Bit 2 of `imm` specifies the offset for indices of `left`.
1053    // The offset is 0 when the bit is 0 or 4 when the bit is 1.
1054    let left_offset = u64::try_from((imm >> 2) & 1).unwrap().strict_mul(4);
1055    // Bits 0..=1 of `imm` specify the offset for indices of
1056    // `right` in blocks of 4 elements.
1057    let right_offset = u64::try_from(imm & 0b11).unwrap().strict_mul(4);
1058
1059    for i in 0..num_chunks {
1060        let left = ecx.project_index(&left, i)?;
1061        let right = ecx.project_index(&right, i)?;
1062        let dest = ecx.project_index(&dest, i)?;
1063
1064        for j in 0..dest_items_per_chunk {
1065            let left_offset = left_offset.strict_add(j);
1066            let mut res: u16 = 0;
1067            for k in 0..4 {
1068                let left = ecx
1069                    .read_scalar(&ecx.project_index(&left, left_offset.strict_add(k))?)?
1070                    .to_u8()?;
1071                let right = ecx
1072                    .read_scalar(&ecx.project_index(&right, right_offset.strict_add(k))?)?
1073                    .to_u8()?;
1074                res = res.strict_add(left.abs_diff(right).into());
1075            }
1076            ecx.write_scalar(Scalar::from_u16(res), &ecx.project_index(&dest, j)?)?;
1077        }
1078    }
1079
1080    interp_ok(())
1081}
1082
1083/// Multiplies packed 16-bit signed integer values, truncates the 32-bit
1084/// product to the 18 most significant bits by right-shifting, and then
1085/// divides the 18-bit value by 2 (rounding to nearest) by first adding
1086/// 1 and then taking the bits `1..=16`.
1087///
1088/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_mulhrs_epi16>
1089/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_mulhrs_epi16>
1090fn pmulhrsw<'tcx>(
1091    ecx: &mut crate::MiriInterpCx<'tcx>,
1092    left: &OpTy<'tcx>,
1093    right: &OpTy<'tcx>,
1094    dest: &MPlaceTy<'tcx>,
1095) -> InterpResult<'tcx, ()> {
1096    let (left, left_len) = ecx.project_to_simd(left)?;
1097    let (right, right_len) = ecx.project_to_simd(right)?;
1098    let (dest, dest_len) = ecx.project_to_simd(dest)?;
1099
1100    assert_eq!(dest_len, left_len);
1101    assert_eq!(dest_len, right_len);
1102
1103    for i in 0..dest_len {
1104        let left = ecx.read_scalar(&ecx.project_index(&left, i)?)?.to_i16()?;
1105        let right = ecx.read_scalar(&ecx.project_index(&right, i)?)?.to_i16()?;
1106        let dest = ecx.project_index(&dest, i)?;
1107
1108        let res = (i32::from(left).strict_mul(right.into()) >> 14).strict_add(1) >> 1;
1109
1110        // The result of this operation can overflow a signed 16-bit integer.
1111        // When `left` and `right` are -0x8000, the result is 0x8000.
1112        #[expect(clippy::as_conversions)]
1113        let res = res as i16;
1114
1115        ecx.write_scalar(Scalar::from_i16(res), &dest)?;
1116    }
1117
1118    interp_ok(())
1119}
1120
1121/// Perform a carry-less multiplication of two 64-bit integers, selected from `left` and `right` according to `imm8`,
1122/// and store the results in `dst`.
1123///
1124/// `left` and `right` are both vectors of type `len` x i64. Only bits 0 and 4 of `imm8` matter;
1125/// they select the element of `left` and `right`, respectively.
1126///
1127/// `len` is the SIMD vector length (in counts of `i64` values). It is expected to be one of
1128/// `2`, `4`, or `8`.
1129///
1130/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_clmulepi64_si128>
1131fn pclmulqdq<'tcx>(
1132    ecx: &mut MiriInterpCx<'tcx>,
1133    left: &OpTy<'tcx>,
1134    right: &OpTy<'tcx>,
1135    imm8: &OpTy<'tcx>,
1136    dest: &MPlaceTy<'tcx>,
1137    len: u64,
1138) -> InterpResult<'tcx, ()> {
1139    assert_eq!(left.layout, right.layout);
1140    assert_eq!(left.layout.size, dest.layout.size);
1141    assert!([2u64, 4, 8].contains(&len));
1142
1143    // Transmute the input into arrays of `[u64; len]`.
1144    // Transmute the output into an array of `[u128, len / 2]`.
1145
1146    let src_layout = ecx.layout_of(Ty::new_array(ecx.tcx.tcx, ecx.tcx.types.u64, len))?;
1147    let dest_layout = ecx.layout_of(Ty::new_array(ecx.tcx.tcx, ecx.tcx.types.u128, len / 2))?;
1148
1149    let left = left.transmute(src_layout, ecx)?;
1150    let right = right.transmute(src_layout, ecx)?;
1151    let dest = dest.transmute(dest_layout, ecx)?;
1152
1153    let imm8 = ecx.read_scalar(imm8)?.to_u8()?;
1154
1155    for i in 0..(len / 2) {
1156        let lo = i.strict_mul(2);
1157        let hi = i.strict_mul(2).strict_add(1);
1158
1159        // select the 64-bit integer from left that the user specified (low or high)
1160        let index = if (imm8 & 0x01) == 0 { lo } else { hi };
1161        let left = ecx.read_scalar(&ecx.project_index(&left, index)?)?.to_u64()?;
1162
1163        // select the 64-bit integer from right that the user specified (low or high)
1164        let index = if (imm8 & 0x10) == 0 { lo } else { hi };
1165        let right = ecx.read_scalar(&ecx.project_index(&right, index)?)?.to_u64()?;
1166
1167        // Perform carry-less multiplication.
1168        //
1169        // This operation is like long multiplication, but ignores all carries.
1170        // That idea corresponds to the xor operator, which is used in the implementation.
1171        //
1172        // Wikipedia has an example https://en.wikipedia.org/wiki/Carry-less_product#Example
1173        let mut result: u128 = 0;
1174
1175        for i in 0..64 {
1176            // if the i-th bit in right is set
1177            if (right & (1 << i)) != 0 {
1178                // xor result with `left` shifted to the left by i positions
1179                result ^= u128::from(left) << i;
1180            }
1181        }
1182
1183        let dest = ecx.project_index(&dest, i)?;
1184        ecx.write_scalar(Scalar::from_u128(result), &dest)?;
1185    }
1186
1187    interp_ok(())
1188}
1189
1190/// Packs two N-bit integer vectors to a single N/2-bit integers.
1191///
1192/// The conversion from N-bit to N/2-bit should be provided by `f`.
1193///
1194/// Each 128-bit chunk is treated independently (i.e., the value for
1195/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
1196/// 128-bit chunks of `left` and `right`).
1197fn pack_generic<'tcx>(
1198    ecx: &mut crate::MiriInterpCx<'tcx>,
1199    left: &OpTy<'tcx>,
1200    right: &OpTy<'tcx>,
1201    dest: &MPlaceTy<'tcx>,
1202    f: impl Fn(Scalar) -> InterpResult<'tcx, Scalar>,
1203) -> InterpResult<'tcx, ()> {
1204    assert_eq!(left.layout, right.layout);
1205    assert_eq!(left.layout.size, dest.layout.size);
1206
1207    let (num_chunks, op_items_per_chunk, left) = split_simd_to_128bit_chunks(ecx, left)?;
1208    let (_, _, right) = split_simd_to_128bit_chunks(ecx, right)?;
1209    let (_, dest_items_per_chunk, dest) = split_simd_to_128bit_chunks(ecx, dest)?;
1210
1211    assert_eq!(dest_items_per_chunk, op_items_per_chunk.strict_mul(2));
1212
1213    for i in 0..num_chunks {
1214        let left = ecx.project_index(&left, i)?;
1215        let right = ecx.project_index(&right, i)?;
1216        let dest = ecx.project_index(&dest, i)?;
1217
1218        for j in 0..op_items_per_chunk {
1219            let left = ecx.read_scalar(&ecx.project_index(&left, j)?)?;
1220            let right = ecx.read_scalar(&ecx.project_index(&right, j)?)?;
1221            let left_dest = ecx.project_index(&dest, j)?;
1222            let right_dest = ecx.project_index(&dest, j.strict_add(op_items_per_chunk))?;
1223
1224            let left_res = f(left)?;
1225            let right_res = f(right)?;
1226
1227            ecx.write_scalar(left_res, &left_dest)?;
1228            ecx.write_scalar(right_res, &right_dest)?;
1229        }
1230    }
1231
1232    interp_ok(())
1233}
1234
1235/// Converts two 16-bit integer vectors to a single 8-bit integer
1236/// vector with signed saturation.
1237///
1238/// Each 128-bit chunk is treated independently (i.e., the value for
1239/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
1240/// 128-bit chunks of `left` and `right`).
1241fn packsswb<'tcx>(
1242    ecx: &mut crate::MiriInterpCx<'tcx>,
1243    left: &OpTy<'tcx>,
1244    right: &OpTy<'tcx>,
1245    dest: &MPlaceTy<'tcx>,
1246) -> InterpResult<'tcx, ()> {
1247    pack_generic(ecx, left, right, dest, |op| {
1248        let op = op.to_i16()?;
1249        let res = i8::try_from(op).unwrap_or(if op < 0 { i8::MIN } else { i8::MAX });
1250        interp_ok(Scalar::from_i8(res))
1251    })
1252}
1253
1254/// Converts two 16-bit signed integer vectors to a single 8-bit
1255/// unsigned integer vector with saturation.
1256///
1257/// Each 128-bit chunk is treated independently (i.e., the value for
1258/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
1259/// 128-bit chunks of `left` and `right`).
1260fn packuswb<'tcx>(
1261    ecx: &mut crate::MiriInterpCx<'tcx>,
1262    left: &OpTy<'tcx>,
1263    right: &OpTy<'tcx>,
1264    dest: &MPlaceTy<'tcx>,
1265) -> InterpResult<'tcx, ()> {
1266    pack_generic(ecx, left, right, dest, |op| {
1267        let op = op.to_i16()?;
1268        let res = u8::try_from(op).unwrap_or(if op < 0 { 0 } else { u8::MAX });
1269        interp_ok(Scalar::from_u8(res))
1270    })
1271}
1272
1273/// Converts two 32-bit integer vectors to a single 16-bit integer
1274/// vector with signed saturation.
1275///
1276/// Each 128-bit chunk is treated independently (i.e., the value for
1277/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
1278/// 128-bit chunks of `left` and `right`).
1279fn packssdw<'tcx>(
1280    ecx: &mut crate::MiriInterpCx<'tcx>,
1281    left: &OpTy<'tcx>,
1282    right: &OpTy<'tcx>,
1283    dest: &MPlaceTy<'tcx>,
1284) -> InterpResult<'tcx, ()> {
1285    pack_generic(ecx, left, right, dest, |op| {
1286        let op = op.to_i32()?;
1287        let res = i16::try_from(op).unwrap_or(if op < 0 { i16::MIN } else { i16::MAX });
1288        interp_ok(Scalar::from_i16(res))
1289    })
1290}
1291
1292/// Converts two 32-bit integer vectors to a single 16-bit integer
1293/// vector with unsigned saturation.
1294///
1295/// Each 128-bit chunk is treated independently (i.e., the value for
1296/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
1297/// 128-bit chunks of `left` and `right`).
1298fn packusdw<'tcx>(
1299    ecx: &mut crate::MiriInterpCx<'tcx>,
1300    left: &OpTy<'tcx>,
1301    right: &OpTy<'tcx>,
1302    dest: &MPlaceTy<'tcx>,
1303) -> InterpResult<'tcx, ()> {
1304    pack_generic(ecx, left, right, dest, |op| {
1305        let op = op.to_i32()?;
1306        let res = u16::try_from(op).unwrap_or(if op < 0 { 0 } else { u16::MAX });
1307        interp_ok(Scalar::from_u16(res))
1308    })
1309}
1310
1311/// Negates elements from `left` when the corresponding element in
1312/// `right` is negative. If an element from `right` is zero, zero
1313/// is written to the corresponding output element.
1314/// In other words, multiplies `left` with `right.signum()`.
1315fn psign<'tcx>(
1316    ecx: &mut crate::MiriInterpCx<'tcx>,
1317    left: &OpTy<'tcx>,
1318    right: &OpTy<'tcx>,
1319    dest: &MPlaceTy<'tcx>,
1320) -> InterpResult<'tcx, ()> {
1321    let (left, left_len) = ecx.project_to_simd(left)?;
1322    let (right, right_len) = ecx.project_to_simd(right)?;
1323    let (dest, dest_len) = ecx.project_to_simd(dest)?;
1324
1325    assert_eq!(dest_len, left_len);
1326    assert_eq!(dest_len, right_len);
1327
1328    for i in 0..dest_len {
1329        let dest = ecx.project_index(&dest, i)?;
1330        let left = ecx.read_immediate(&ecx.project_index(&left, i)?)?;
1331        let right = ecx.read_scalar(&ecx.project_index(&right, i)?)?.to_int(dest.layout.size)?;
1332
1333        let res =
1334            ecx.binary_op(mir::BinOp::Mul, &left, &ImmTy::from_int(right.signum(), dest.layout))?;
1335
1336        ecx.write_immediate(*res, &dest)?;
1337    }
1338
1339    interp_ok(())
1340}
1341
1342/// Calcultates either `a + b + cb_in` or `a - b - cb_in` depending on the value
1343/// of `op` and returns both the sum and the overflow bit. `op` is expected to be
1344/// either one of `mir::BinOp::AddWithOverflow` and `mir::BinOp::SubWithOverflow`.
1345fn carrying_add<'tcx>(
1346    ecx: &mut crate::MiriInterpCx<'tcx>,
1347    cb_in: &OpTy<'tcx>,
1348    a: &OpTy<'tcx>,
1349    b: &OpTy<'tcx>,
1350    op: mir::BinOp,
1351) -> InterpResult<'tcx, (ImmTy<'tcx>, Scalar)> {
1352    assert!(op == mir::BinOp::AddWithOverflow || op == mir::BinOp::SubWithOverflow);
1353
1354    let cb_in = ecx.read_scalar(cb_in)?.to_u8()? != 0;
1355    let a = ecx.read_immediate(a)?;
1356    let b = ecx.read_immediate(b)?;
1357
1358    let (sum, overflow1) = ecx.binary_op(op, &a, &b)?.to_pair(ecx);
1359    let (sum, overflow2) =
1360        ecx.binary_op(op, &sum, &ImmTy::from_uint(cb_in, a.layout))?.to_pair(ecx);
1361    let cb_out = overflow1.to_scalar().to_bool()? | overflow2.to_scalar().to_bool()?;
1362
1363    interp_ok((sum, Scalar::from_u8(cb_out.into())))
1364}