1mod llvm_enzyme {
7 use std::str::FromStr;
8 use std::string::String;
9
10 use rustc_ast::expand::autodiff_attrs::{
11 AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
12 valid_ty_for_activity,
13 };
14 use rustc_ast::ptr::P;
15 use rustc_ast::token::{Lit, LitKind, Token, TokenKind};
16 use rustc_ast::tokenstream::*;
17 use rustc_ast::visit::AssocCtxt::*;
18 use rustc_ast::{
19 self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
20 MetaItemInner, PatKind, QSelf, TyKind, Visibility,
21 };
22 use rustc_expand::base::{Annotatable, ExtCtxt};
23 use rustc_span::{Ident, Span, Symbol, kw, sym};
24 use thin_vec::{ThinVec, thin_vec};
25 use tracing::{debug, trace};
26
27 use crate::errors;
28
29 pub(crate) fn outer_normal_attr(
30 kind: &P<rustc_ast::NormalAttr>,
31 id: rustc_ast::AttrId,
32 span: Span,
33 ) -> rustc_ast::Attribute {
34 let style = rustc_ast::AttrStyle::Outer;
35 let kind = rustc_ast::AttrKind::Normal(kind.clone());
36 rustc_ast::Attribute { kind, id, style, span }
37 }
38
39 fn has_ret(ty: &FnRetTy) -> bool {
42 match ty {
43 FnRetTy::Ty(ty) => !ty.kind.is_unit(),
44 FnRetTy::Default(_) => false,
45 }
46 }
47 fn first_ident(x: &MetaItemInner) -> rustc_span::Ident {
48 if let Some(l) = x.lit() {
49 match l.kind {
50 ast::LitKind::Int(val, _) => {
51 return rustc_span::Ident::from_str(val.get().to_string().as_str());
53 }
54 _ => {}
55 }
56 }
57
58 let segments = &x.meta_item().unwrap().path.segments;
59 assert!(segments.len() == 1);
60 segments[0].ident
61 }
62
63 fn name(x: &MetaItemInner) -> String {
64 first_ident(x).name.to_string()
65 }
66
67 fn width(x: &MetaItemInner) -> Option<u128> {
68 let lit = x.lit()?;
69 match lit.kind {
70 ast::LitKind::Int(x, _) => Some(x.get()),
71 _ => return None,
72 }
73 }
74
75 fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> {
77 match &iitem.kind {
78 ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
79 Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
80 }
81 _ => None,
82 }
83 }
84
85 pub(crate) fn from_ast(
86 ecx: &mut ExtCtxt<'_>,
87 meta_item: &ThinVec<MetaItemInner>,
88 has_ret: bool,
89 mode: DiffMode,
90 ) -> AutoDiffAttrs {
91 let dcx = ecx.sess.dcx();
92
93 let mut first_activity = 1;
96
97 let width = if let [_, x, ..] = &meta_item[..]
98 && let Some(x) = width(x)
99 {
100 first_activity = 2;
101 match x.try_into() {
102 Ok(x) => x,
103 Err(_) => {
104 dcx.emit_err(errors::AutoDiffInvalidWidth {
105 span: meta_item[1].span(),
106 width: x,
107 });
108 return AutoDiffAttrs::error();
109 }
110 }
111 } else {
112 1
113 };
114
115 let mut activities: Vec<DiffActivity> = vec![];
116 let mut errors = false;
117 for x in &meta_item[first_activity..] {
118 let activity_str = name(&x);
119 let res = DiffActivity::from_str(&activity_str);
120 match res {
121 Ok(x) => activities.push(x),
122 Err(_) => {
123 dcx.emit_err(errors::AutoDiffUnknownActivity {
124 span: x.span(),
125 act: activity_str,
126 });
127 errors = true;
128 }
129 };
130 }
131 if errors {
132 return AutoDiffAttrs::error();
133 }
134
135 let (ret_activity, input_activity) = if has_ret {
138 let Some((last, rest)) = activities.split_last() else {
139 unreachable!(
140 "should not be reachable because we counted the number of activities previously"
141 );
142 };
143 (last, rest)
144 } else {
145 (&DiffActivity::None, activities.as_slice())
146 };
147
148 AutoDiffAttrs {
149 mode,
150 width,
151 ret_activity: *ret_activity,
152 input_activity: input_activity.to_vec(),
153 }
154 }
155
156 fn meta_item_inner_to_ts(t: &MetaItemInner, ts: &mut Vec<TokenTree>) {
157 let comma: Token = Token::new(TokenKind::Comma, Span::default());
158 let val = first_ident(t);
159 let t = Token::from_ast_ident(val);
160 ts.push(TokenTree::Token(t, Spacing::Joint));
161 ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
162 }
163
164 pub(crate) fn expand_forward(
165 ecx: &mut ExtCtxt<'_>,
166 expand_span: Span,
167 meta_item: &ast::MetaItem,
168 item: Annotatable,
169 ) -> Vec<Annotatable> {
170 expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward)
171 }
172
173 pub(crate) fn expand_reverse(
174 ecx: &mut ExtCtxt<'_>,
175 expand_span: Span,
176 meta_item: &ast::MetaItem,
177 item: Annotatable,
178 ) -> Vec<Annotatable> {
179 expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse)
180 }
181
    /// Shared implementation behind [`expand_forward`] and [`expand_reverse`].
    ///
    /// Takes the annotated function (`item`) and the attribute arguments
    /// (`meta_item`) and returns two annotatables: the original item with
    /// `rustc_autodiff` and `inline(never)` attributes attached, and a newly
    /// generated function named after the first attribute argument. The new
    /// function carries a `rustc_autodiff(...)` attribute whose arguments are
    /// `Mode, width, activities...` followed by the return activity.
    ///
    /// On any error a diagnostic is emitted (here or in a helper) and the
    /// original `item` is returned unchanged as the only element.
    ///
    /// Supported items are free functions, functions in statement position,
    /// and associated functions inside `impl` blocks.
    pub(crate) fn expand_with_mode(
        ecx: &mut ExtCtxt<'_>,
        expand_span: Span,
        meta_item: &ast::MetaItem,
        mut item: Annotatable,
        mode: DiffMode,
    ) -> Vec<Annotatable> {
        // Without the `llvm_enzyme` cfg this macro only reports that support is missing.
        if cfg!(not(llvm_enzyme)) {
            ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
            return vec![item];
        }
        let dcx = ecx.sess.dcx();

        // Extract visibility, signature, name and generics of the primal function
        // from whichever kind of annotatable we were handed.
        let Some((vis, sig, primal, generics)) = (match &item {
            Annotatable::Item(iitem) => extract_item_info(iitem),
            Annotatable::Stmt(stmt) => match &stmt.kind {
                ast::StmtKind::Item(iitem) => extract_item_info(iitem),
                _ => None,
            },
            Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind {
                ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
                    Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
                }
                _ => None,
            },
            _ => None,
        }) else {
            dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
            return vec![item];
        };

        // The attribute must be written in list form, e.g. `#[attr(a, b, c)]`.
        let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
            ast::MetaItemKind::List(ref vec) => vec.clone(),
            _ => {
                dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
                return vec![item];
            }
        };

        let has_ret = has_ret(&sig.decl.output);
        let sig_span = ecx.with_call_site_ctxt(sig.span);

        // Token stream that becomes the argument list of the generated
        // function's `rustc_autodiff` attribute.
        let mut ts: Vec<TokenTree> = vec![];
        // At least one argument (used below as the generated function's name) is required.
        if meta_item_vec.len() < 1 {
            dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
            return vec![item];
        }

        let mode_symbol = match mode {
            DiffMode::Forward => sym::Forward,
            DiffMode::Reverse => sym::Reverse,
            _ => unreachable!("Unsupported mode: {:?}", mode),
        };

        // First entries of the token stream: `<Mode>,`
        let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default());
        ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint));
        ts.insert(
            1,
            TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone),
        );

        // Next entry is the width: taken from the second attribute argument when
        // that is an integer literal, otherwise the literal `1`. `start_position`
        // is the index of the first activity argument.
        let start_position;
        let kind: LitKind = LitKind::Integer;
        let symbol;
        if meta_item_vec.len() >= 2
            && let Some(width) = width(&meta_item_vec[1])
        {
            start_position = 2;
            symbol = Symbol::intern(&width.to_string());
        } else {
            start_position = 1;
            symbol = sym::integer(1);
        }

        let l: Lit = Lit { kind, symbol, suffix: None };
        let t = Token::new(TokenKind::Literal(l), Span::default());
        let comma = Token::new(TokenKind::Comma, Span::default());
        ts.push(TokenTree::Token(t, Spacing::Joint));
        ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));

        // Then every activity argument, each followed by a comma.
        for t in meta_item_vec.clone()[start_position..].iter() {
            meta_item_inner_to_ts(t, &mut ts);
        }

        // A function without a return value gets an explicit `None` return activity.
        if !has_ret {
            let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
            ts.push(TokenTree::Token(t, Spacing::Joint));
            ts.push(TokenTree::Token(comma, Spacing::Alone));
        }
        // Drop the trailing comma.
        ts.pop();
        let ts: TokenStream = TokenStream::from_iter(ts);

        // Parse the same arguments into structured form; bail out if that
        // reported an error.
        let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret, mode);
        if !x.is_active() {
            return vec![item];
        }
        let span = ecx.with_def_site_ctxt(expand_span);

        // Number of inputs marked `Active` or `ActiveOnly`.
        let n_active: u32 = x
            .input_activity
            .iter()
            .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
            .count() as u32;
        let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
        let d_body = gen_enzyme_body(
            ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
            &generics,
        );

        // The generated function, named after the first attribute argument.
        let asdf = Box::new(ast::Fn {
            defaultness: ast::Defaultness::Final,
            sig: d_sig,
            ident: first_ident(&meta_item_vec[0]),
            generics,
            contract: None,
            body: Some(d_body),
            define_opaque: None,
        });
        // `rustc_autodiff` attribute without arguments; this form is attached to
        // the original item. Arguments are filled in further down for the
        // generated function.
        let mut rustc_ad_attr =
            P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));

        // Build `#[inline(never)]`.
        let ts2: Vec<TokenTree> = vec![TokenTree::Token(
            Token::new(TokenKind::Ident(sym::never, false.into()), span),
            Spacing::Joint,
        )];
        let never_arg = ast::DelimArgs {
            dspan: DelimSpan::from_single(span),
            delim: ast::token::Delimiter::Parenthesis,
            tokens: TokenStream::from_iter(ts2),
        };
        let inline_item = ast::AttrItem {
            unsafety: ast::Safety::Default,
            path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)),
            args: ast::AttrArgs::Delimited(never_arg),
            tokens: None,
        };
        let inline_never_attr = P(ast::NormalAttr { item: inline_item, tokens: None });
        let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
        let attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
        let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
        let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);

        /// Two attributes count as the same when both are normal attributes
        /// whose paths have identical segment identifiers; arguments are ignored.
        fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
            match (attr, item) {
                (ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => {
                    let a = &a.item.path;
                    let b = &b.item.path;
                    a.segments.len() == b.segments.len()
                        && a.segments.iter().zip(b.segments.iter()).all(|(a, b)| a.ident == b.ident)
                }
                _ => false,
            }
        }

        // Attach both attributes to the original item, skipping any that are
        // already present, and keep a copy of the updated item.
        let orig_annotatable: Annotatable = match item {
            Annotatable::Item(ref mut iitem) => {
                if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
                    iitem.attrs.push(attr);
                }
                if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
                    iitem.attrs.push(inline_never.clone());
                }
                Annotatable::Item(iitem.clone())
            }
            Annotatable::AssocItem(ref mut assoc_item, i @ Impl { .. }) => {
                if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
                    assoc_item.attrs.push(attr);
                }
                if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
                    assoc_item.attrs.push(inline_never.clone());
                }
                Annotatable::AssocItem(assoc_item.clone(), i)
            }
            Annotatable::Stmt(ref mut stmt) => {
                match stmt.kind {
                    ast::StmtKind::Item(ref mut iitem) => {
                        if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
                            iitem.attrs.push(attr);
                        }
                        if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind))
                        {
                            iitem.attrs.push(inline_never.clone());
                        }
                    }
                    _ => unreachable!("stmt kind checked previously"),
                };

                Annotatable::Stmt(stmt.clone())
            }
            _ => {
                unreachable!("annotatable kind checked previously")
            }
        };
        // Now give the `rustc_autodiff` attribute its argument list for use on
        // the generated function.
        rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {
            dspan: DelimSpan::dummy(),
            delim: rustc_ast::token::Delimiter::Parenthesis,
            tokens: ts,
        });

        // NOTE(review): `new_id` here is the id that was generated for the
        // `inline(never)` attribute above and is reused for `d_attr`; confirm
        // that sharing the id is intended.
        let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
        // Wrap the generated function in the same kind of annotatable as the original.
        let d_annotatable = match &item {
            Annotatable::AssocItem(_, _) => {
                let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
                let d_fn = P(ast::AssocItem {
                    attrs: thin_vec![d_attr, inline_never],
                    id: ast::DUMMY_NODE_ID,
                    span,
                    vis,
                    kind: assoc_item,
                    tokens: None,
                });
                Annotatable::AssocItem(d_fn, Impl { of_trait: false })
            }
            Annotatable::Item(_) => {
                let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
                d_fn.vis = vis;

                Annotatable::Item(d_fn)
            }
            Annotatable::Stmt(_) => {
                let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
                d_fn.vis = vis;

                Annotatable::Stmt(P(ast::Stmt {
                    id: ast::DUMMY_NODE_ID,
                    kind: ast::StmtKind::Item(d_fn),
                    span,
                }))
            }
            _ => {
                unreachable!("item kind checked previously")
            }
        };

        return vec![orig_annotatable, d_annotatable];
    }
469
470 fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
473 let mut ty = ty.clone();
474 match ty.kind {
475 TyKind::Ptr(ref mut mut_ty) => {
476 mut_ty.mutbl = ast::Mutability::Mut;
477 }
478 TyKind::Ref(_, ref mut mut_ty) => {
479 mut_ty.mutbl = ast::Mutability::Mut;
480 }
481 _ => {
482 panic!("unsupported type: {:?}", ty);
483 }
484 }
485 ty
486 }
487
488 fn init_body_helper(
500 ecx: &ExtCtxt<'_>,
501 span: Span,
502 primal: Ident,
503 new_names: &[String],
504 sig_span: Span,
505 new_decl_span: Span,
506 idents: &[Ident],
507 errored: bool,
508 generics: &Generics,
509 ) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) {
510 let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
511 let noop = ast::InlineAsm {
512 asm_macro: ast::AsmMacro::Asm,
513 template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())],
514 template_strs: Box::new([]),
515 operands: vec![],
516 clobber_abis: vec![],
517 options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM,
518 line_spans: vec![],
519 };
520 let noop_expr = ecx.expr_asm(span, P(noop));
521 let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated);
522 let unsf_block = ast::Block {
523 stmts: thin_vec![ecx.stmt_semi(noop_expr)],
524 id: ast::DUMMY_NODE_ID,
525 tokens: None,
526 rules: unsf,
527 span,
528 };
529 let unsf_expr = ecx.expr_block(P(unsf_block));
530 let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
531 let primal_call = gen_primal_call(ecx, span, primal, idents, generics);
532 let black_box_primal_call = ecx.expr_call(
533 new_decl_span,
534 blackbox_call_expr.clone(),
535 thin_vec![primal_call.clone()],
536 );
537 let tup_args = new_names
538 .iter()
539 .map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg))))
540 .collect();
541
542 let black_box_remaining_args = ecx.expr_call(
543 sig_span,
544 blackbox_call_expr.clone(),
545 thin_vec![ecx.expr_tuple(sig_span, tup_args)],
546 );
547
548 let mut body = ecx.block(span, ThinVec::new());
549 body.stmts.push(ecx.stmt_semi(unsf_expr));
550
551 if !errored {
553 body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone()));
554 }
555 body.stmts.push(ecx.stmt_semi(black_box_remaining_args));
556
557 (body, primal_call, black_box_primal_call, blackbox_call_expr)
558 }
559
    /// Generates the body of the derivative function.
    ///
    /// Starts from the statements produced by [`init_body_helper`] and, when
    /// the generated signature `d_sig` has a return value, appends a trailing
    /// expression of a matching shape. Which expression is used depends on the
    /// mode, the width, the return activity and the number of active inputs
    /// (`n_active`); see the inline comments.
    ///
    /// # Panics
    ///
    /// Panics if a return type that is expected to be a simple path or a tuple
    /// of simple paths has a different shape.
    fn gen_enzyme_body(
        ecx: &ExtCtxt<'_>,
        x: &AutoDiffAttrs,
        n_active: u32,
        sig: &ast::FnSig,
        d_sig: &ast::FnSig,
        primal: Ident,
        new_names: &[String],
        span: Span,
        sig_span: Span,
        idents: Vec<Ident>,
        errored: bool,
        generics: &Generics,
    ) -> P<ast::Block> {
        let new_decl_span = d_sig.span;

        let (mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper(
            ecx,
            span,
            primal,
            new_names,
            sig_span,
            new_decl_span,
            &idents,
            errored,
            generics,
        );

        // Nothing to return: the prefix statements are the whole body.
        if !has_ret(&d_sig.decl.output) {
            return body;
        }

        // Whether the primal's own return value is part of the generated return value.
        let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();

        // Reverse mode without active inputs: the return value is just the
        // black-boxed primal call.
        if primal_ret && n_active == 0 && x.mode.is_rev() {
            body.stmts.push(ecx.stmt_expr(bb_primal_call));
            return body;
        }

        // Primal value not returned and exactly one active input: return a
        // `default()` call on the (simple-path) return type.
        if !primal_ret && n_active == 1 {
            let ty = match d_sig.decl.output {
                FnRetTy::Ty(ref ty) => ty.clone(),
                FnRetTy::Default(span) => {
                    panic!("Did not expect Default ret ty: {:?}", span);
                }
            };
            let arg = ty.kind.is_simple_path().unwrap();
            let tmp = ecx.def_site_path(&[arg, kw::Default]);
            let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
            let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
            body.stmts.push(ecx.stmt_expr(default_call_expr));
            return body;
        }

        let mut exprs: P<ast::Expr> = primal_call;
        let d_ret_ty = match d_sig.decl.output {
            FnRetTy::Ty(ref ty) => ty.clone(),
            FnRetTy::Default(span) => {
                panic!("Did not expect Default ret ty: {:?}", span);
            }
        };
        if x.mode.is_fwd() {
            if x.ret_activity == DiffActivity::Const {
                // Const return: `black_box(primal(...))`.
                exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
            } else {
                // Otherwise: `black_box(<d_ret_ty>::default())`, built as a
                // qualified path with the generated return type as self type.
                let q = QSelf { ty: d_ret_ty, path_span: span, position: 0 };
                let y = ExprKind::Path(
                    Some(P(q)),
                    ecx.path_ident(span, Ident::with_dummy_span(kw::Default)),
                );
                let default_call_expr = ecx.expr(span, y);
                let default_call_expr =
                    ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
                exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![default_call_expr]);
            }
        } else if x.mode.is_rev() {
            if x.width == 1 {
                match d_ret_ty.kind {
                    TyKind::Tup(ref args) => {
                        // Tuple return: the primal call fills the first slot and
                        // every remaining slot gets a `default()` call on its
                        // (simple-path) type.
                        let mut exprs2 = thin_vec![exprs];
                        for arg in args.iter().skip(1) {
                            let arg = arg.kind.is_simple_path().unwrap();
                            let tmp = ecx.def_site_path(&[arg, kw::Default]);
                            let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
                            let default_call_expr =
                                ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
                            exprs2.push(default_call_expr);
                        }
                        exprs = ecx.expr_tuple(new_decl_span, exprs2);
                    }
                    _ => {
                        panic!("Unsupported return type: {:?}", d_ret_ty);
                    }
                }
            }
            // For any other width the bare primal call is wrapped as-is.
            exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
        } else {
            unreachable!("Unsupported mode: {:?}", x.mode);
        }

        body.stmts.push(ecx.stmt_expr(exprs));

        body
    }
699
700 fn gen_primal_call(
701 ecx: &ExtCtxt<'_>,
702 span: Span,
703 primal: Ident,
704 idents: &[Ident],
705 generics: &Generics,
706 ) -> P<ast::Expr> {
707 let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
708
709 if has_self {
710 let args: ThinVec<_> =
711 idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
712 let self_expr = ecx.expr_self(span);
713 ecx.expr_method_call(span, self_expr, primal, args)
714 } else {
715 let args: ThinVec<_> =
716 idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
717 let mut primal_path = ecx.path_ident(span, primal);
718
719 let is_generic = !generics.params.is_empty();
720
721 match (is_generic, primal_path.segments.last_mut()) {
722 (true, Some(function_path)) => {
723 let primal_generic_types = generics
724 .params
725 .iter()
726 .filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. }));
727
728 let generated_generic_types = primal_generic_types
729 .map(|type_param| {
730 let generic_param = TyKind::Path(
731 None,
732 ast::Path {
733 span,
734 segments: thin_vec![ast::PathSegment {
735 ident: type_param.ident,
736 args: None,
737 id: ast::DUMMY_NODE_ID,
738 }],
739 tokens: None,
740 },
741 );
742
743 ast::AngleBracketedArg::Arg(ast::GenericArg::Type(P(ast::Ty {
744 id: type_param.id,
745 span,
746 kind: generic_param,
747 tokens: None,
748 })))
749 })
750 .collect();
751
752 function_path.args =
753 Some(P(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs {
754 span,
755 args: generated_generic_types,
756 })));
757 }
758 _ => {}
759 }
760
761 let primal_call_expr = ecx.expr_path(primal_path);
762 ecx.expr_call(span, primal_call_expr, args)
763 }
764 }
765
    /// Generates the signature of the derivative function from the primal
    /// signature `sig` and the parsed attribute `x`.
    ///
    /// Returns a tuple of
    /// - the generated signature,
    /// - the names of the parameters that were added (shadow arguments and,
    ///   in reverse mode with an active return, `dret`),
    /// - the identifiers of the primal's own parameters,
    /// - a flag that is `true` when an error was reported; in that case the
    ///   returned signature is a plain copy of `sig`.
    ///
    /// # Panics
    ///
    /// Panics if a parameter pattern is not a plain identifier, if a
    /// `Duplicated`/`DuplicatedOnly` argument is neither a pointer nor a
    /// reference (via [`assure_mut_ref`]), or if an activity is `None` or
    /// `FakeActivitySize`.
    fn gen_enzyme_decl(
        ecx: &ExtCtxt<'_>,
        sig: &ast::FnSig,
        x: &AutoDiffAttrs,
        span: Span,
    ) -> (ast::FnSig, Vec<String>, Vec<Ident>, bool) {
        let dcx = ecx.sess.dcx();
        let has_ret = has_ret(&sig.decl.output);
        // The number of activities must equal the number of inputs, plus one
        // if the function returns a value.
        let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };
        let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 };
        if sig_args != num_activities {
            dcx.emit_err(errors::AutoDiffInvalidNumberActivities {
                span,
                expected: sig_args,
                found: num_activities,
            });
            return (sig.clone(), vec![], vec![], true);
        }
        assert!(sig.decl.inputs.len() == x.input_activity.len());
        assert!(has_ret == x.has_ret_activity());
        let mut d_decl = sig.decl.clone();
        // Parameters of the generated function: the primal's own plus the added ones.
        let mut d_inputs = Vec::new();
        // Names of the added parameters.
        let mut new_inputs = Vec::new();
        // Identifiers of the primal's own parameters.
        let mut idents = Vec::new();
        // Types of `Active` inputs; these are added to the return type below.
        let mut act_ret = ThinVec::new();

        // First pass: validate every (argument, activity) pair, reporting all
        // problems before giving up.
        let mut errors = false;
        for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
            if !valid_input_activity(x.mode, *activity) {
                dcx.emit_err(errors::AutoDiffInvalidApplicationModeAct {
                    span,
                    mode: x.mode.to_string(),
                    act: activity.to_string(),
                });
                errors = true;
            }
            if !valid_ty_for_activity(&arg.ty, *activity) {
                dcx.emit_err(errors::AutoDiffInvalidTypeForActivity {
                    span: arg.ty.span,
                    act: activity.to_string(),
                });
                errors = true;
            }
        }

        // NOTE(review): unlike the input checks above, this branch reports the
        // error without setting `errors`, so generation continues. This looks
        // deliberate; confirm before changing.
        if has_ret && !valid_ret_activity(x.mode, x.ret_activity) {
            dcx.emit_err(errors::AutoDiffInvalidRetAct {
                span,
                mode: x.mode.to_string(),
                act: x.ret_activity.to_string(),
            });
        }

        if errors {
            return (sig.clone(), new_inputs, idents, true);
        }

        // `DuplicatedOnly` / `DualOnly` inputs make the generated function `unsafe` (see below).
        let unsafe_activities = x
            .input_activity
            .iter()
            .any(|&act| matches!(act, DiffActivity::DuplicatedOnly | DiffActivity::DualOnly));
        // Second pass: build the parameter list.
        for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
            d_inputs.push(arg.clone());
            match activity {
                DiffActivity::Active => {
                    act_ret.push(arg.ty.clone());
                }
                DiffActivity::ActiveOnly => {
                    // No extra parameter and no contribution to the return type.
                }
                DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
                    // One shadow parameter per width step, named `d<name>_<i>`,
                    // with the argument's type forced to be mutable.
                    for i in 0..x.width {
                        let mut shadow_arg = arg.clone();
                        shadow_arg.ty = P(assure_mut_ref(&arg.ty));
                        let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
                            ident.name
                        } else {
                            debug!("{:#?}", &shadow_arg.pat);
                            panic!("not an ident?");
                        };
                        let name: String = format!("d{}_{}", old_name, i);
                        new_inputs.push(name.clone());
                        let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
                        shadow_arg.pat = P(ast::Pat {
                            id: ast::DUMMY_NODE_ID,
                            kind: PatKind::Ident(BindingMode::NONE, ident, None),
                            span: shadow_arg.pat.span,
                            tokens: shadow_arg.pat.tokens.clone(),
                        });
                        d_inputs.push(shadow_arg.clone());
                    }
                }
                DiffActivity::Dual
                | DiffActivity::DualOnly
                | DiffActivity::Dualv
                | DiffActivity::DualvOnly => {
                    // `Dualv*` always gets a single shadow parameter; `Dual*`
                    // gets one per width step. Shadows are named `b<name>_<i>`
                    // and keep the argument's type unchanged.
                    let iterations =
                        if matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) {
                            1
                        } else {
                            x.width
                        };
                    for i in 0..iterations {
                        let mut shadow_arg = arg.clone();
                        let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
                            ident.name
                        } else {
                            debug!("{:#?}", &shadow_arg.pat);
                            panic!("not an ident?");
                        };
                        let name: String = format!("b{}_{}", old_name, i);
                        new_inputs.push(name.clone());
                        let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
                        shadow_arg.pat = P(ast::Pat {
                            id: ast::DUMMY_NODE_ID,
                            kind: PatKind::Ident(BindingMode::NONE, ident, None),
                            span: shadow_arg.pat.span,
                            tokens: shadow_arg.pat.tokens.clone(),
                        });
                        d_inputs.push(shadow_arg.clone());
                    }
                }
                DiffActivity::Const => {
                    // Passed through unchanged; nothing is added.
                }
                DiffActivity::None | DiffActivity::FakeActivitySize(_) => {
                    panic!("Should not happen");
                }
            }
            if let PatKind::Ident(_, ident, _) = arg.pat.kind {
                idents.push(ident.clone());
            } else {
                panic!("not an ident?");
            }
        }

        let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly;
        if active_only_ret {
            assert!(x.mode.is_rev());
        }

        // Reverse mode with an active return: add a trailing `dret` parameter
        // of the primal's return type.
        if x.mode.is_rev() {
            match x.ret_activity {
                DiffActivity::Active | DiffActivity::ActiveOnly => {
                    let ty = match d_decl.output {
                        FnRetTy::Ty(ref ty) => ty.clone(),
                        FnRetTy::Default(span) => {
                            panic!("Did not expect Default ret ty: {:?}", span);
                        }
                    };
                    let name = "dret".to_string();
                    let ident = Ident::from_str_and_span(&name, ty.span);
                    let shadow_arg = ast::Param {
                        attrs: ThinVec::new(),
                        ty: ty.clone(),
                        pat: P(ast::Pat {
                            id: ast::DUMMY_NODE_ID,
                            kind: PatKind::Ident(BindingMode::NONE, ident, None),
                            span: ty.span,
                            tokens: None,
                        }),
                        id: ast::DUMMY_NODE_ID,
                        span: ty.span,
                        is_placeholder: false,
                    };
                    d_inputs.push(shadow_arg);
                    new_inputs.push(name);
                }
                _ => {}
            }
        }
        d_decl.inputs = d_inputs.into();

        // Forward mode: adjust the return type according to the return activity.
        if x.mode.is_fwd() {
            let ty = match d_decl.output {
                FnRetTy::Ty(ref ty) => ty.clone(),
                FnRetTy::Default(span) => {
                    // An omitted return type is made explicit as `()`.
                    let kind = TyKind::Tup(ThinVec::new());
                    let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
                    d_decl.output = FnRetTy::Ty(ty.clone());
                    assert!(matches!(x.ret_activity, DiffActivity::None));
                    ty
                }
            };

            if matches!(x.ret_activity, DiffActivity::Dual | DiffActivity::Dualv) {
                // Width 1 or `Dualv`: return `(T, T)`; otherwise `[T; 1 + width]`.
                let kind = if x.width == 1 || matches!(x.ret_activity, DiffActivity::Dualv) {
                    TyKind::Tup(thin_vec![ty.clone(), ty.clone()])
                } else {
                    let anon_const = rustc_ast::AnonConst {
                        id: ast::DUMMY_NODE_ID,
                        value: ecx.expr_usize(span, 1 + x.width as usize),
                    };
                    TyKind::Array(ty.clone(), anon_const)
                };
                let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
                d_decl.output = FnRetTy::Ty(ty);
            }
            if matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) {
                // Width > 1: return `[T; width]`; for width 1 the type stays `T`.
                if x.width > 1 {
                    let anon_const = rustc_ast::AnonConst {
                        id: ast::DUMMY_NODE_ID,
                        value: ecx.expr_usize(span, x.width as usize),
                    };
                    let kind = TyKind::Array(ty.clone(), anon_const);
                    let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
                    d_decl.output = FnRetTy::Ty(ty);
                }
            }
        }

        // With an `ActiveOnly` return the primal's return type is dropped.
        d_decl.output =
            if active_only_ret { FnRetTy::Default(span) } else { d_decl.output.clone() };

        trace!("act_ret: {:?}", act_ret);

        // Add the types of `Active` inputs to the return type.
        if act_ret.len() > 0 {
            let ret_ty = match d_decl.output {
                FnRetTy::Ty(ref ty) => {
                    // Existing return type goes first, followed by the active input types.
                    if !active_only_ret {
                        act_ret.insert(0, ty.clone());
                    }
                    let kind = TyKind::Tup(act_ret);
                    P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None })
                }
                FnRetTy::Default(span) => {
                    // No existing return type: a single active type is returned
                    // directly, several are returned as a tuple.
                    if act_ret.len() == 1 {
                        act_ret[0].clone()
                    } else {
                        let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect());
                        P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None })
                    }
                }
            };
            d_decl.output = FnRetTy::Ty(ret_ty);
        }

        // The header is copied from the primal; it is marked `unsafe` when any
        // input uses `DuplicatedOnly` or `DualOnly`.
        let mut d_header = sig.header.clone();
        if unsafe_activities {
            d_header.safety = rustc_ast::Safety::Unsafe(span);
        }
        let d_sig = FnSig { header: d_header, decl: d_decl, span };
        trace!("Generated signature: {:?}", d_sig);
        (d_sig, new_inputs, idents, false)
    }
1047}
1048
1049pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};