1use rustc_abi::CanonAbi;
2use rustc_middle::ty::Ty;
3use rustc_span::Symbol;
4use rustc_target::callconv::FnAbi;
5
6use crate::*;
7
8impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
9pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
10 fn emulate_x86_gfni_intrinsic(
11 &mut self,
12 link_name: Symbol,
13 abi: &FnAbi<'tcx, Ty<'tcx>>,
14 args: &[OpTy<'tcx>],
15 dest: &MPlaceTy<'tcx>,
16 ) -> InterpResult<'tcx, EmulateItemResult> {
17 let this = self.eval_context_mut();
18
19 let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.").unwrap();
21
22 this.expect_target_feature_for_intrinsic(link_name, "gfni")?;
23 if unprefixed_name.ends_with(".256") {
24 this.expect_target_feature_for_intrinsic(link_name, "avx")?;
25 } else if unprefixed_name.ends_with(".512") {
26 this.expect_target_feature_for_intrinsic(link_name, "avx512f")?;
27 }
28
29 match unprefixed_name {
30 "vgf2p8affineqb.128" | "vgf2p8affineqb.256" | "vgf2p8affineqb.512" => {
34 let [left, right, imm8] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
35 affine_transform(this, left, right, imm8, dest, false)?;
36 }
37 "vgf2p8affineinvqb.128" | "vgf2p8affineinvqb.256" | "vgf2p8affineinvqb.512" => {
41 let [left, right, imm8] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
42 affine_transform(this, left, right, imm8, dest, true)?;
43 }
44 "vgf2p8mulb.128" | "vgf2p8mulb.256" | "vgf2p8mulb.512" => {
50 let [left, right] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
51 let (left, left_len) = this.project_to_simd(left)?;
52 let (right, right_len) = this.project_to_simd(right)?;
53 let (dest, dest_len) = this.project_to_simd(dest)?;
54
55 assert_eq!(left_len, right_len);
56 assert_eq!(dest_len, right_len);
57
58 for i in 0..dest_len {
59 let left = this.read_scalar(&this.project_index(&left, i)?)?.to_u8()?;
60 let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u8()?;
61 let dest = this.project_index(&dest, i)?;
62 this.write_scalar(Scalar::from_u8(gf2p8_mul(left, right)), &dest)?;
63 }
64 }
65 _ => return interp_ok(EmulateItemResult::NotSupported),
66 }
67 interp_ok(EmulateItemResult::NeedsReturn)
68 }
69}
70
71fn affine_transform<'tcx>(
76 ecx: &mut MiriInterpCx<'tcx>,
77 left: &OpTy<'tcx>,
78 right: &OpTy<'tcx>,
79 imm8: &OpTy<'tcx>,
80 dest: &MPlaceTy<'tcx>,
81 inverse: bool,
82) -> InterpResult<'tcx, ()> {
83 let (left, left_len) = ecx.project_to_simd(left)?;
84 let (right, right_len) = ecx.project_to_simd(right)?;
85 let (dest, dest_len) = ecx.project_to_simd(dest)?;
86
87 assert_eq!(dest_len, right_len);
88 assert_eq!(dest_len, left_len);
89
90 let imm8 = ecx.read_scalar(imm8)?.to_u8()?;
91
92 for i in (0..dest_len).step_by(8) {
95 let mut matrix = [0u8; 8];
97 for j in 0..8 {
98 matrix[usize::try_from(j).unwrap()] =
99 ecx.read_scalar(&ecx.project_index(&right, i.wrapping_add(j))?)?.to_u8()?;
100 }
101
102 for j in 0..8 {
104 let index = i.wrapping_add(j);
105 let left = ecx.read_scalar(&ecx.project_index(&left, index)?)?.to_u8()?;
106 let left = if inverse { TABLE[usize::from(left)] } else { left };
107
108 let mut res = 0;
109
110 for bit in 0u8..8 {
112 let mut b = matrix[usize::from(bit)] & left;
113
114 b = (b & 0b1111) ^ (b >> 4);
116 b = (b & 0b11) ^ (b >> 2);
117 b = (b & 0b1) ^ (b >> 1);
118
119 res |= b << 7u8.wrapping_sub(bit);
120 }
121
122 res ^= imm8;
124
125 let dest = ecx.project_index(&dest, index)?;
126 ecx.write_scalar(Scalar::from_u8(res), &dest)?;
127 }
128 }
129
130 interp_ok(())
131}
132
133static TABLE: [u8; 256] = {
138 let mut array = [0; 256];
139
140 let mut i = 1;
141 while i < 256 {
142 #[expect(clippy::as_conversions)] let mut x = i as u8;
144 let mut y = gf2p8_mul(x, x);
145 x = y;
146 let mut j = 2;
147 while j < 8 {
148 x = gf2p8_mul(x, x);
149 y = gf2p8_mul(x, y);
150 j += 1;
151 }
152 array[i] = y;
153 i += 1;
154 }
155
156 array
157};
158
/// Multiplies `left` and `right` in GF(2^8), using the reduction polynomial
/// x^8 + x^4 + x^3 + x + 1 (0x11b).
///
/// Uses shift-and-add: `right` is consumed one bit at a time. Each step doubles `left`
/// (multiplies by x) and reduces it immediately whenever an x^8 term would appear, so all
/// intermediate values fit in a `u8`.
const fn gf2p8_mul(left: u8, right: u8) -> u8 {
    // The reduction polynomial without its x^8 term. The bit shifted out of the top of
    // the `u8` stands in for that term.
    const REDUCTION: u8 = 0x1b;

    let mut multiplicand = left;
    let mut multiplier = right;
    let mut product = 0u8;

    while multiplier != 0 {
        if (multiplier & 1) != 0 {
            product ^= multiplicand;
        }
        let overflows = (multiplicand & 0x80) != 0;
        multiplicand <<= 1;
        if overflows {
            multiplicand ^= REDUCTION;
        }
        multiplier >>= 1;
    }

    product
}