rustc_codegen_llvm/builder/autodiff.rs
1use std::ptr;
2
3use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
4use rustc_codegen_ssa::ModuleCodegen;
5use rustc_codegen_ssa::common::TypeKind;
6use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
7use rustc_errors::FatalError;
8use rustc_middle::bug;
9use tracing::{debug, trace};
10
11use crate::back::write::llvm_err;
12use crate::builder::{SBuilder, UNNAMED};
13use crate::context::SimpleCx;
14use crate::declare::declare_simple_fn;
15use crate::errors::{AutoDiffWithoutEnable, LlvmError};
16use crate::llvm::AttributePlace::Function;
17use crate::llvm::{Metadata, True};
18use crate::value::Value;
19use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
20
21fn get_params(fnc: &Value) -> Vec<&Value> {
22 let param_num = llvm::LLVMCountParams(fnc) as usize;
23 let mut fnc_args: Vec<&Value> = vec![];
24 fnc_args.reserve(param_num);
25 unsafe {
26 llvm::LLVMGetParams(fnc, fnc_args.as_mut_ptr());
27 fnc_args.set_len(param_num);
28 }
29 fnc_args
30}
31
32fn has_sret(fnc: &Value) -> bool {
33 let num_args = llvm::LLVMCountParams(fnc) as usize;
34 if num_args == 0 {
35 false
36 } else {
37 unsafe { llvm::LLVMRustHasAttributeAtIndex(fnc, 0, llvm::AttributeKind::StructRet) }
38 }
39}
40
41// When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
42// original inputs, as well as metadata and the additional shadow arguments.
43// This function matches the arguments from the outer function to the inner enzyme call.
44//
45// This function also considers that Rust level arguments not always match the llvm-ir level
46// arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
47// llvm-ir level. The number of activities matches the number of Rust level arguments, so we
48// need to match those.
49// FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
50// using iterators and peek()?
51fn match_args_from_caller_to_enzyme<'ll>(
52 cx: &SimpleCx<'ll>,
53 builder: &SBuilder<'ll, 'll>,
54 width: u32,
55 args: &mut Vec<&'ll llvm::Value>,
56 inputs: &[DiffActivity],
57 outer_args: &[&'ll llvm::Value],
58 has_sret: bool,
59) {
60 debug!("matching autodiff arguments");
61 // We now handle the issue that Rust level arguments not always match the llvm-ir level
62 // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
63 // llvm-ir level. The number of activities matches the number of Rust level arguments, so we
64 // need to match those.
65 // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
66 // using iterators and peek()?
67 let mut outer_pos: usize = 0;
68 let mut activity_pos = 0;
69
70 if has_sret {
71 // Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
72 // inner function will still return something. We increase our outer_pos by one,
73 // and once we're done with all other args we will take the return of the inner call and
74 // update the sret pointer with it
75 outer_pos = 1;
76 }
77
78 let enzyme_const = cx.create_metadata(b"enzyme_const");
79 let enzyme_out = cx.create_metadata(b"enzyme_out");
80 let enzyme_dup = cx.create_metadata(b"enzyme_dup");
81 let enzyme_dupv = cx.create_metadata(b"enzyme_dupv");
82 let enzyme_dupnoneed = cx.create_metadata(b"enzyme_dupnoneed");
83 let enzyme_dupnoneedv = cx.create_metadata(b"enzyme_dupnoneedv");
84
85 while activity_pos < inputs.len() {
86 let diff_activity = inputs[activity_pos as usize];
87 // Duplicated arguments received a shadow argument, into which enzyme will write the
88 // gradient.
89 let (activity, duplicated): (&Metadata, bool) = match diff_activity {
90 DiffActivity::None => panic!("not a valid input activity"),
91 DiffActivity::Const => (enzyme_const, false),
92 DiffActivity::Active => (enzyme_out, false),
93 DiffActivity::ActiveOnly => (enzyme_out, false),
94 DiffActivity::Dual => (enzyme_dup, true),
95 DiffActivity::Dualv => (enzyme_dupv, true),
96 DiffActivity::DualOnly => (enzyme_dupnoneed, true),
97 DiffActivity::DualvOnly => (enzyme_dupnoneedv, true),
98 DiffActivity::Duplicated => (enzyme_dup, true),
99 DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true),
100 DiffActivity::FakeActivitySize(_) => (enzyme_const, false),
101 };
102 let outer_arg = outer_args[outer_pos];
103 args.push(cx.get_metadata_value(activity));
104 if matches!(diff_activity, DiffActivity::Dualv) {
105 let next_outer_arg = outer_args[outer_pos + 1];
106 let elem_bytes_size: u64 = match inputs[activity_pos + 1] {
107 DiffActivity::FakeActivitySize(Some(s)) => s.into(),
108 _ => bug!("incorrect Dualv handling recognized."),
109 };
110 // stride: sizeof(T) * n_elems.
111 // n_elems is the next integer.
112 // Now we multiply `4 * next_outer_arg` to get the stride.
113 let mul = unsafe {
114 llvm::LLVMBuildMul(
115 builder.llbuilder,
116 cx.get_const_int(cx.type_i64(), elem_bytes_size),
117 next_outer_arg,
118 UNNAMED,
119 )
120 };
121 args.push(mul);
122 }
123 args.push(outer_arg);
124 if duplicated {
125 // We know that duplicated args by construction have a following argument,
126 // so this can not be out of bounds.
127 let next_outer_arg = outer_args[outer_pos + 1];
128 let next_outer_ty = cx.val_ty(next_outer_arg);
129 // FIXME(ZuseZ4): We should add support for Vec here too, but it's less urgent since
130 // vectors behind references (&Vec<T>) are already supported. Users can not pass a
131 // Vec by value for reverse mode, so this would only help forward mode autodiff.
132 let slice = {
133 if activity_pos + 1 >= inputs.len() {
134 // If there is no arg following our ptr, it also can't be a slice,
135 // since that would lead to a ptr, int pair.
136 false
137 } else {
138 let next_activity = inputs[activity_pos + 1];
139 // We analyze the MIR types and add this dummy activity if we visit a slice.
140 matches!(next_activity, DiffActivity::FakeActivitySize(_))
141 }
142 };
143 if slice {
144 // A duplicated slice will have the following two outer_fn arguments:
145 // (..., ptr1, int1, ptr2, int2, ...). We add the following llvm-ir to our __enzyme call:
146 // (..., metadata! enzyme_dup, ptr, ptr, int1, ...).
147 // FIXME(ZuseZ4): We will upstream a safety check later which asserts that
148 // int2 >= int1, which means the shadow vector is large enough to store the gradient.
149 assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Integer);
150
151 let iterations =
152 if matches!(diff_activity, DiffActivity::Dualv) { 1 } else { width as usize };
153
154 for i in 0..iterations {
155 let next_outer_arg2 = outer_args[outer_pos + 2 * (i + 1)];
156 let next_outer_ty2 = cx.val_ty(next_outer_arg2);
157 assert_eq!(cx.type_kind(next_outer_ty2), TypeKind::Pointer);
158 let next_outer_arg3 = outer_args[outer_pos + 2 * (i + 1) + 1];
159 let next_outer_ty3 = cx.val_ty(next_outer_arg3);
160 assert_eq!(cx.type_kind(next_outer_ty3), TypeKind::Integer);
161 args.push(next_outer_arg2);
162 }
163 args.push(cx.get_metadata_value(enzyme_const));
164 args.push(next_outer_arg);
165 outer_pos += 2 + 2 * iterations;
166 activity_pos += 2;
167 } else {
168 // A duplicated pointer will have the following two outer_fn arguments:
169 // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
170 // (..., metadata! enzyme_dup, ptr, ptr, ...).
171 if matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly)
172 {
173 assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Pointer);
174 }
175 // In the case of Dual we don't have assumptions, e.g. f32 would be valid.
176 args.push(next_outer_arg);
177 outer_pos += 2;
178 activity_pos += 1;
179
180 // Now, if width > 1, we need to account for that
181 for _ in 1..width {
182 let next_outer_arg = outer_args[outer_pos];
183 args.push(next_outer_arg);
184 outer_pos += 1;
185 }
186 }
187 } else {
188 // We do not differentiate with resprect to this argument.
189 // We already added the metadata and argument above, so just increase the counters.
190 outer_pos += 1;
191 activity_pos += 1;
192 }
193 }
194}
195
196// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
197// arguments. We do however need to declare them with their correct return type.
198// We already figured the correct return type out in our frontend, when generating the outer_fn,
199// so we can now just go ahead and use that. This is not always trivial, e.g. because sret.
200// Beyond sret, this article describes our challenges nicely:
201// <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
202// I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
203fn compute_enzyme_fn_ty<'ll>(
204 cx: &SimpleCx<'ll>,
205 attrs: &AutoDiffAttrs,
206 fn_to_diff: &'ll Value,
207 outer_fn: &'ll Value,
208) -> &'ll llvm::Type {
209 let fn_ty = cx.get_type_of_global(outer_fn);
210 let mut ret_ty = cx.get_return_type(fn_ty);
211
212 let has_sret = has_sret(outer_fn);
213
214 if has_sret {
215 // Now we don't just forward the return type, so we have to figure it out based on the
216 // primal return type, in combination with the autodiff settings.
217 let fn_ty = cx.get_type_of_global(fn_to_diff);
218 let inner_ret_ty = cx.get_return_type(fn_ty);
219
220 let void_ty = unsafe { llvm::LLVMVoidTypeInContext(cx.llcx) };
221 if inner_ret_ty == void_ty {
222 // This indicates that even the inner function has an sret.
223 // Right now I only look for an sret in the outer function.
224 // This *probably* needs some extra handling, but I never ran
225 // into such a case. So I'll wait for user reports to have a test case.
226 bug!("sret in inner function");
227 }
228
229 if attrs.width == 1 {
230 // Enzyme returns a struct of style:
231 // `{ original_ret(if requested), float, float, ... }`
232 let mut struct_elements = vec![];
233 if attrs.has_primal_ret() {
234 struct_elements.push(inner_ret_ty);
235 }
236 // Next, we push the list of active floats, since they will be lowered to `enzyme_out`,
237 // and therefore part of the return struct.
238 let param_tys = cx.func_params_types(fn_ty);
239 for (act, param_ty) in attrs.input_activity.iter().zip(param_tys) {
240 if matches!(act, DiffActivity::Active) {
241 // Now find the float type at position i based on the fn_ty,
242 // to know what (f16/f32/f64/...) to add to the struct.
243 struct_elements.push(param_ty);
244 }
245 }
246 ret_ty = cx.type_struct(&struct_elements, false);
247 } else {
248 // First we check if we also have to deal with the primal return.
249 match attrs.mode {
250 DiffMode::Forward => match attrs.ret_activity {
251 DiffActivity::Dual => {
252 let arr_ty =
253 unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64 + 1) };
254 ret_ty = arr_ty;
255 }
256 DiffActivity::DualOnly => {
257 let arr_ty =
258 unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64) };
259 ret_ty = arr_ty;
260 }
261 DiffActivity::Const => {
262 todo!("Not sure, do we need to do something here?");
263 }
264 _ => {
265 bug!("unreachable");
266 }
267 },
268 DiffMode::Reverse => {
269 todo!("Handle sret for reverse mode");
270 }
271 _ => {
272 bug!("unreachable");
273 }
274 }
275 }
276 }
277
278 // LLVM can figure out the input types on it's own, so we take a shortcut here.
279 unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) }
280}
281
282/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
283/// function with expected naming and calling conventions[^1] which will be
284/// discovered by the enzyme LLVM pass and its body populated with the differentiated
285/// `fn_to_diff`. `outer_fn` is then modified to have a call to the generated
286/// function and handle the differences between the Rust calling convention and
287/// Enzyme.
288/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
289// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
290// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
291fn generate_enzyme_call<'ll>(
292 cx: &SimpleCx<'ll>,
293 fn_to_diff: &'ll Value,
294 outer_fn: &'ll Value,
295 attrs: AutoDiffAttrs,
296) {
297 // We have to pick the name depending on whether we want forward or reverse mode autodiff.
298 let mut ad_name: String = match attrs.mode {
299 DiffMode::Forward => "__enzyme_fwddiff",
300 DiffMode::Reverse => "__enzyme_autodiff",
301 _ => panic!("logic bug in autodiff, unrecognized mode"),
302 }
303 .to_string();
304
305 // add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple
306 // functions. Unwrap will only panic, if LLVM gave us an invalid string.
307 let name = llvm::get_value_name(outer_fn);
308 let outer_fn_name = std::str::from_utf8(&name).unwrap();
309 ad_name.push_str(outer_fn_name);
310
311 // Let us assume the user wrote the following function square:
312 //
313 // ```llvm
314 // define double @square(double %x) {
315 // entry:
316 // %0 = fmul double %x, %x
317 // ret double %0
318 // }
319 // ```
320 //
321 // The user now applies autodiff to the function square, in which case fn_to_diff will be `square`.
322 // Our macro generates the following placeholder code (slightly simplified):
323 //
324 // ```llvm
325 // define double @dsquare(double %x) {
326 // ; placeholder code
327 // return 0.0;
328 // }
329 // ```
330 //
331 // so our `outer_fn` will be `dsquare`. The unsafe code section below now removes the placeholder
332 // code and inserts an autodiff call. We also add a declaration for the __enzyme_autodiff call.
333 // Again, the arguments to all functions are slightly simplified.
334 // ```llvm
335 // declare double @__enzyme_autodiff_square(...)
336 //
337 // define double @dsquare(double %x) {
338 // entry:
339 // %0 = tail call double (...) @__enzyme_autodiff_square(double (double)* nonnull @square, double %x)
340 // ret double %0
341 // }
342 // ```
343 unsafe {
344 let enzyme_ty = compute_enzyme_fn_ty(cx, &attrs, fn_to_diff, outer_fn);
345
346 // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
347 // think a bit more about what should go here.
348 let cc = llvm::LLVMGetFunctionCallConv(outer_fn);
349 let ad_fn = declare_simple_fn(
350 cx,
351 &ad_name,
352 llvm::CallConv::try_from(cc).expect("invalid callconv"),
353 llvm::UnnamedAddr::No,
354 llvm::Visibility::Default,
355 enzyme_ty,
356 );
357
358 // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to
359 // do it's work.
360 let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx);
361 attributes::apply_to_llfn(ad_fn, Function, &[attr]);
362
363 // We add a made-up attribute just such that we can recognize it after AD to update
364 // (no)-inline attributes. We'll then also remove this attribute.
365 let enzyme_marker_attr = llvm::CreateAttrString(cx.llcx, "enzyme_marker");
366 attributes::apply_to_llfn(outer_fn, Function, &[enzyme_marker_attr]);
367
368 // first, remove all calls from fnc
369 let entry = llvm::LLVMGetFirstBasicBlock(outer_fn);
370 let br = llvm::LLVMRustGetTerminator(entry);
371 llvm::LLVMRustEraseInstFromParent(br);
372
373 let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap();
374 let mut builder = SBuilder::build(cx, entry);
375
376 let num_args = llvm::LLVMCountParams(&fn_to_diff);
377 let mut args = Vec::with_capacity(num_args as usize + 1);
378 args.push(fn_to_diff);
379
380 let enzyme_primal_ret = cx.create_metadata(b"enzyme_primal_return");
381 if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
382 args.push(cx.get_metadata_value(enzyme_primal_ret));
383 }
384 if attrs.width > 1 {
385 let enzyme_width = cx.create_metadata(b"enzyme_width");
386 args.push(cx.get_metadata_value(enzyme_width));
387 args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64));
388 }
389
390 let has_sret = has_sret(outer_fn);
391 let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
392 match_args_from_caller_to_enzyme(
393 &cx,
394 &builder,
395 attrs.width,
396 &mut args,
397 &attrs.input_activity,
398 &outer_args,
399 has_sret,
400 );
401
402 let call = builder.call(enzyme_ty, ad_fn, &args, None);
403
404 // This part is a bit iffy. LLVM requires that a call to an inlineable function has some
405 // metadata attached to it, but we just created this code oota. Given that the
406 // differentiated function already has partly confusing metadata, and given that this
407 // affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the
408 // dummy code which we inserted at a higher level.
409 // FIXME(ZuseZ4): Work with Enzyme core devs to clarify what debug metadata issues we have,
410 // and how to best improve it for enzyme core and rust-enzyme.
411 let md_ty = cx.get_md_kind_id("dbg");
412 if llvm::LLVMRustHasMetadata(last_inst, md_ty) {
413 let md = llvm::LLVMRustDIGetInstMetadata(last_inst)
414 .expect("failed to get instruction metadata");
415 let md_todiff = cx.get_metadata_value(md);
416 llvm::LLVMSetMetadata(call, md_ty, md_todiff);
417 } else {
418 // We don't panic, since depending on whether we are in debug or release mode, we might
419 // have no debug info to copy, which would then be ok.
420 trace!("no dbg info");
421 }
422
423 // Now that we copied the metadata, get rid of dummy code.
424 llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst);
425
426 if cx.val_ty(call) == cx.type_void() || has_sret {
427 if has_sret {
428 // This is what we already have in our outer_fn (shortened):
429 // define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) {
430 // %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>)
431 // <Here we are, we want to add the following two lines>
432 // store [4 x double] %7, ptr %0, align 8
433 // ret void
434 // }
435
436 // now store the result of the enzyme call into the sret pointer.
437 let sret_ptr = outer_args[0];
438 let call_ty = cx.val_ty(call);
439 if attrs.width == 1 {
440 assert_eq!(cx.type_kind(call_ty), TypeKind::Struct);
441 } else {
442 assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
443 }
444 llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
445 }
446 builder.ret_void();
447 } else {
448 builder.ret(call);
449 }
450
451 // Let's crash in case that we messed something up above and generated invalid IR.
452 llvm::LLVMRustVerifyFunction(
453 outer_fn,
454 llvm::LLVMRustVerifierFailureAction::LLVMAbortProcessAction,
455 );
456 }
457}
458
459pub(crate) fn differentiate<'ll>(
460 module: &'ll ModuleCodegen<ModuleLlvm>,
461 cgcx: &CodegenContext<LlvmCodegenBackend>,
462 diff_items: Vec<AutoDiffItem>,
463) -> Result<(), FatalError> {
464 for item in &diff_items {
465 trace!("{}", item);
466 }
467
468 let diag_handler = cgcx.create_dcx();
469
470 let cx = SimpleCx::new(module.module_llvm.llmod(), module.module_llvm.llcx, cgcx.pointer_size);
471
472 // First of all, did the user try to use autodiff without using the -Zautodiff=Enable flag?
473 if !diff_items.is_empty()
474 && !cgcx.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable)
475 {
476 return Err(diag_handler.handle().emit_almost_fatal(AutoDiffWithoutEnable));
477 }
478
479 // Here we replace the placeholder code with the actual autodiff code, which calls Enzyme.
480 for item in diff_items.iter() {
481 let name = item.source.clone();
482 let fn_def: Option<&llvm::Value> = cx.get_function(&name);
483 let Some(fn_def) = fn_def else {
484 return Err(llvm_err(
485 diag_handler.handle(),
486 LlvmError::PrepareAutoDiff {
487 src: item.source.clone(),
488 target: item.target.clone(),
489 error: "could not find source function".to_owned(),
490 },
491 ));
492 };
493 debug!(?item.target);
494 let fn_target: Option<&llvm::Value> = cx.get_function(&item.target);
495 let Some(fn_target) = fn_target else {
496 return Err(llvm_err(
497 diag_handler.handle(),
498 LlvmError::PrepareAutoDiff {
499 src: item.source.clone(),
500 target: item.target.clone(),
501 error: "could not find target function".to_owned(),
502 },
503 ));
504 };
505
506 generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone());
507 }
508
509 // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
510
511 trace!("done with differentiate()");
512
513 Ok(())
514}