miri/shims/x86/sse2.rs
1use rustc_abi::CanonAbi;
2use rustc_apfloat::ieee::Double;
3use rustc_middle::ty::Ty;
4use rustc_span::Symbol;
5use rustc_target::callconv::FnAbi;
6
7use super::{
8 FloatBinOp, ShiftOp, bin_op_simd_float_all, bin_op_simd_float_first, convert_float_to_int,
9 packssdw, packsswb, packuswb, shift_simd_by_scalar,
10};
11use crate::*;
12
13impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
14pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
15 fn emulate_x86_sse2_intrinsic(
16 &mut self,
17 link_name: Symbol,
18 abi: &FnAbi<'tcx, Ty<'tcx>>,
19 args: &[OpTy<'tcx>],
20 dest: &MPlaceTy<'tcx>,
21 ) -> InterpResult<'tcx, EmulateItemResult> {
22 let this = self.eval_context_mut();
23 this.expect_target_feature_for_intrinsic(link_name, "sse2")?;
24 // Prefix should have already been checked.
25 let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.sse2.").unwrap();
26
27 // These intrinsics operate on 128-bit (f32x4, f64x2, i8x16, i16x8, i32x4, i64x2) SIMD
28 // vectors unless stated otherwise.
29 // Many intrinsic names are sufixed with "ps" (packed single), "ss" (scalar signle),
30 // "pd" (packed double) or "sd" (scalar double), where single means single precision
31 // floating point (f32) and double means double precision floating point (f64). "ps"
32 // and "pd" means thet the operation is performed on each element of the vector, while
33 // "ss" and "sd" means that the operation is performed only on the first element, copying
34 // the remaining elements from the input vector (for binary operations, from the left-hand
35 // side).
36 // Intrinsincs sufixed with "epiX" or "epuX" operate with X-bit signed or unsigned
37 // vectors.
38 match unprefixed_name {
39 // Used to implement the _mm_madd_epi16 function.
40 // Multiplies packed signed 16-bit integers in `left` and `right`, producing
41 // intermediate signed 32-bit integers. Horizontally add adjacent pairs of
42 // intermediate 32-bit integers, and pack the results in `dest`.
43 "pmadd.wd" => {
44 let [left, right] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
45
46 let (left, left_len) = this.project_to_simd(left)?;
47 let (right, right_len) = this.project_to_simd(right)?;
48 let (dest, dest_len) = this.project_to_simd(dest)?;
49
50 assert_eq!(left_len, right_len);
51 assert_eq!(dest_len.strict_mul(2), left_len);
52
53 for i in 0..dest_len {
54 let j1 = i.strict_mul(2);
55 let left1 = this.read_scalar(&this.project_index(&left, j1)?)?.to_i16()?;
56 let right1 = this.read_scalar(&this.project_index(&right, j1)?)?.to_i16()?;
57
58 let j2 = j1.strict_add(1);
59 let left2 = this.read_scalar(&this.project_index(&left, j2)?)?.to_i16()?;
60 let right2 = this.read_scalar(&this.project_index(&right, j2)?)?.to_i16()?;
61
62 let dest = this.project_index(&dest, i)?;
63
64 // Multiplications are i16*i16->i32, which will not overflow.
65 let mul1 = i32::from(left1).strict_mul(right1.into());
66 let mul2 = i32::from(left2).strict_mul(right2.into());
67 // However, this addition can overflow in the most extreme case
68 // (-0x8000)*(-0x8000)+(-0x8000)*(-0x8000) = 0x80000000
69 let res = mul1.wrapping_add(mul2);
70
71 this.write_scalar(Scalar::from_i32(res), &dest)?;
72 }
73 }
74 // Used to implement the _mm_sad_epu8 function.
75 // Computes the absolute differences of packed unsigned 8-bit integers in `a`
76 // and `b`, then horizontally sum each consecutive 8 differences to produce
77 // two unsigned 16-bit integers, and pack these unsigned 16-bit integers in
78 // the low 16 bits of 64-bit elements returned.
79 //
80 // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_sad_epu8
81 "psad.bw" => {
82 let [left, right] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
83
84 let (left, left_len) = this.project_to_simd(left)?;
85 let (right, right_len) = this.project_to_simd(right)?;
86 let (dest, dest_len) = this.project_to_simd(dest)?;
87
88 // left and right are u8x16, dest is u64x2
89 assert_eq!(left_len, right_len);
90 assert_eq!(left_len, 16);
91 assert_eq!(dest_len, 2);
92
93 for i in 0..dest_len {
94 let dest = this.project_index(&dest, i)?;
95
96 let mut res: u16 = 0;
97 let n = left_len.strict_div(dest_len);
98 for j in 0..n {
99 let op_i = j.strict_add(i.strict_mul(n));
100 let left = this.read_scalar(&this.project_index(&left, op_i)?)?.to_u8()?;
101 let right =
102 this.read_scalar(&this.project_index(&right, op_i)?)?.to_u8()?;
103
104 res = res.strict_add(left.abs_diff(right).into());
105 }
106
107 this.write_scalar(Scalar::from_u64(res.into()), &dest)?;
108 }
109 }
110 // Used to implement the _mm_{sll,srl,sra}_epi{16,32,64} functions
111 // (except _mm_sra_epi64, which is not available in SSE2).
112 // Shifts N-bit packed integers in left by the amount in right.
113 // Both operands are 128-bit vectors. However, right is interpreted as
114 // a single 64-bit integer (remaining bits are ignored).
115 // For logic shifts, when right is larger than N - 1, zero is produced.
116 // For arithmetic shifts, when right is larger than N - 1, the sign bit
117 // is copied to remaining bits.
118 "psll.w" | "psrl.w" | "psra.w" | "psll.d" | "psrl.d" | "psra.d" | "psll.q"
119 | "psrl.q" => {
120 let [left, right] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
121
122 let which = match unprefixed_name {
123 "psll.w" | "psll.d" | "psll.q" => ShiftOp::Left,
124 "psrl.w" | "psrl.d" | "psrl.q" => ShiftOp::RightLogic,
125 "psra.w" | "psra.d" => ShiftOp::RightArith,
126 _ => unreachable!(),
127 };
128
129 shift_simd_by_scalar(this, left, right, which, dest)?;
130 }
131 // Used to implement the _mm_cvtps_epi32, _mm_cvttps_epi32, _mm_cvtpd_epi32
132 // and _mm_cvttpd_epi32 functions.
133 // Converts packed f32/f64 to packed i32.
134 "cvtps2dq" | "cvttps2dq" | "cvtpd2dq" | "cvttpd2dq" => {
135 let [op] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
136
137 let (op_len, _) = op.layout.ty.simd_size_and_type(*this.tcx);
138 let (dest_len, _) = dest.layout.ty.simd_size_and_type(*this.tcx);
139 match unprefixed_name {
140 "cvtps2dq" | "cvttps2dq" => {
141 // f32x4 to i32x4 conversion
142 assert_eq!(op_len, 4);
143 assert_eq!(dest_len, op_len);
144 }
145 "cvtpd2dq" | "cvttpd2dq" => {
146 // f64x2 to i32x4 conversion
147 // the last two values are filled with zeros
148 assert_eq!(op_len, 2);
149 assert_eq!(dest_len, 4);
150 }
151 _ => unreachable!(),
152 }
153
154 let rnd = match unprefixed_name {
155 // "current SSE rounding mode", assume nearest
156 // https://www.felixcloutier.com/x86/cvtps2dq
157 // https://www.felixcloutier.com/x86/cvtpd2dq
158 "cvtps2dq" | "cvtpd2dq" => rustc_apfloat::Round::NearestTiesToEven,
159 // always truncate
160 // https://www.felixcloutier.com/x86/cvttps2dq
161 // https://www.felixcloutier.com/x86/cvttpd2dq
162 "cvttps2dq" | "cvttpd2dq" => rustc_apfloat::Round::TowardZero,
163 _ => unreachable!(),
164 };
165
166 convert_float_to_int(this, op, rnd, dest)?;
167 }
168 // Used to implement the _mm_packs_epi16 function.
169 // Converts two 16-bit integer vectors to a single 8-bit integer
170 // vector with signed saturation.
171 "packsswb.128" => {
172 let [left, right] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
173
174 packsswb(this, left, right, dest)?;
175 }
176 // Used to implement the _mm_packus_epi16 function.
177 // Converts two 16-bit signed integer vectors to a single 8-bit
178 // unsigned integer vector with saturation.
179 "packuswb.128" => {
180 let [left, right] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
181
182 packuswb(this, left, right, dest)?;
183 }
184 // Used to implement the _mm_packs_epi32 function.
185 // Converts two 32-bit integer vectors to a single 16-bit integer
186 // vector with signed saturation.
187 "packssdw.128" => {
188 let [left, right] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
189
190 packssdw(this, left, right, dest)?;
191 }
192 // Used to implement _mm_min_sd and _mm_max_sd functions.
193 // Note that the semantics are a bit different from Rust simd_min
194 // and simd_max intrinsics regarding handling of NaN and -0.0: Rust
195 // matches the IEEE min/max operations, while x86 has different
196 // semantics.
197 "min.sd" | "max.sd" => {
198 let [left, right] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
199
200 let which = match unprefixed_name {
201 "min.sd" => FloatBinOp::Min,
202 "max.sd" => FloatBinOp::Max,
203 _ => unreachable!(),
204 };
205
206 bin_op_simd_float_first::<Double>(this, which, left, right, dest)?;
207 }
208 // Used to implement _mm_min_pd and _mm_max_pd functions.
209 // Note that the semantics are a bit different from Rust simd_min
210 // and simd_max intrinsics regarding handling of NaN and -0.0: Rust
211 // matches the IEEE min/max operations, while x86 has different
212 // semantics.
213 "min.pd" | "max.pd" => {
214 let [left, right] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
215
216 let which = match unprefixed_name {
217 "min.pd" => FloatBinOp::Min,
218 "max.pd" => FloatBinOp::Max,
219 _ => unreachable!(),
220 };
221
222 bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
223 }
224 // Used to implement the _mm_cmp*_sd functions.
225 // Performs a comparison operation on the first component of `left`
226 // and `right`, returning 0 if false or `u64::MAX` if true. The remaining
227 // components are copied from `left`.
228 // _mm_cmp_sd is actually an AVX function where the operation is specified
229 // by a const parameter.
230 // _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_sd are SSE2 functions
231 // with hard-coded operations.
232 "cmp.sd" => {
233 let [left, right, imm] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
234
235 let which =
236 FloatBinOp::cmp_from_imm(this, this.read_scalar(imm)?.to_i8()?, link_name)?;
237
238 bin_op_simd_float_first::<Double>(this, which, left, right, dest)?;
239 }
240 // Used to implement the _mm_cmp*_pd functions.
241 // Performs a comparison operation on each component of `left`
242 // and `right`. For each component, returns 0 if false or `u64::MAX`
243 // if true.
244 // _mm_cmp_pd is actually an AVX function where the operation is specified
245 // by a const parameter.
246 // _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_pd are SSE2 functions
247 // with hard-coded operations.
248 "cmp.pd" => {
249 let [left, right, imm] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
250
251 let which =
252 FloatBinOp::cmp_from_imm(this, this.read_scalar(imm)?.to_i8()?, link_name)?;
253
254 bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
255 }
256 // Used to implement _mm_{,u}comi{eq,lt,le,gt,ge,neq}_sd functions.
257 // Compares the first component of `left` and `right` and returns
258 // a scalar value (0 or 1).
259 "comieq.sd" | "comilt.sd" | "comile.sd" | "comigt.sd" | "comige.sd" | "comineq.sd"
260 | "ucomieq.sd" | "ucomilt.sd" | "ucomile.sd" | "ucomigt.sd" | "ucomige.sd"
261 | "ucomineq.sd" => {
262 let [left, right] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
263
264 let (left, left_len) = this.project_to_simd(left)?;
265 let (right, right_len) = this.project_to_simd(right)?;
266
267 assert_eq!(left_len, right_len);
268
269 let left = this.read_scalar(&this.project_index(&left, 0)?)?.to_f64()?;
270 let right = this.read_scalar(&this.project_index(&right, 0)?)?.to_f64()?;
271 // The difference between the com* and ucom* variants is signaling
272 // of exceptions when either argument is a quiet NaN. We do not
273 // support accessing the SSE status register from miri (or from Rust,
274 // for that matter), so we treat both variants equally.
275 let res = match unprefixed_name {
276 "comieq.sd" | "ucomieq.sd" => left == right,
277 "comilt.sd" | "ucomilt.sd" => left < right,
278 "comile.sd" | "ucomile.sd" => left <= right,
279 "comigt.sd" | "ucomigt.sd" => left > right,
280 "comige.sd" | "ucomige.sd" => left >= right,
281 "comineq.sd" | "ucomineq.sd" => left != right,
282 _ => unreachable!(),
283 };
284 this.write_scalar(Scalar::from_i32(i32::from(res)), dest)?;
285 }
286 // Use to implement the _mm_cvtsd_si32, _mm_cvttsd_si32,
287 // _mm_cvtsd_si64 and _mm_cvttsd_si64 functions.
288 // Converts the first component of `op` from f64 to i32/i64.
289 "cvtsd2si" | "cvttsd2si" | "cvtsd2si64" | "cvttsd2si64" => {
290 let [op] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
291 let (op, _) = this.project_to_simd(op)?;
292
293 let op = this.read_immediate(&this.project_index(&op, 0)?)?;
294
295 let rnd = match unprefixed_name {
296 // "current SSE rounding mode", assume nearest
297 // https://www.felixcloutier.com/x86/cvtsd2si
298 "cvtsd2si" | "cvtsd2si64" => rustc_apfloat::Round::NearestTiesToEven,
299 // always truncate
300 // https://www.felixcloutier.com/x86/cvttsd2si
301 "cvttsd2si" | "cvttsd2si64" => rustc_apfloat::Round::TowardZero,
302 _ => unreachable!(),
303 };
304
305 let res = this.float_to_int_checked(&op, dest.layout, rnd)?.unwrap_or_else(|| {
306 // Fallback to minimum according to SSE semantics.
307 ImmTy::from_int(dest.layout.size.signed_int_min(), dest.layout)
308 });
309
310 this.write_immediate(*res, dest)?;
311 }
312 // Used to implement the _mm_cvtsd_ss and _mm_cvtss_sd functions.
313 // Converts the first f64/f32 from `right` to f32/f64 and copies
314 // the remaining elements from `left`
315 "cvtsd2ss" | "cvtss2sd" => {
316 let [left, right] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
317
318 let (left, left_len) = this.project_to_simd(left)?;
319 let (right, _) = this.project_to_simd(right)?;
320 let (dest, dest_len) = this.project_to_simd(dest)?;
321
322 assert_eq!(dest_len, left_len);
323
324 // Convert first element of `right`
325 let right0 = this.read_immediate(&this.project_index(&right, 0)?)?;
326 let dest0 = this.project_index(&dest, 0)?;
327 // `float_to_float_or_int` here will convert from f64 to f32 (cvtsd2ss) or
328 // from f32 to f64 (cvtss2sd).
329 let res0 = this.float_to_float_or_int(&right0, dest0.layout)?;
330 this.write_immediate(*res0, &dest0)?;
331
332 // Copy remaining from `left`
333 for i in 1..dest_len {
334 this.copy_op(&this.project_index(&left, i)?, &this.project_index(&dest, i)?)?;
335 }
336 }
337 _ => return interp_ok(EmulateItemResult::NotSupported),
338 }
339 interp_ok(EmulateItemResult::NeedsReturn)
340 }
341}