Skip to main content

rustc_builtin_macros/
offload.rs

1use rustc_ast::token::{Delimiter, Token, TokenKind};
2use rustc_ast::tokenstream::{DelimSpan, Spacing, TokenStream, TokenTree};
3use rustc_ast::{AttrItem, ast};
4use rustc_expand::base::{Annotatable, ExtCtxt};
5use rustc_session::config::Offload;
6use rustc_span::{Ident, Span, sym};
7use thin_vec::thin_vec;
8
9use crate::errors;
10
11fn compile_for_device(ecx: &mut ExtCtxt<'_>) -> bool {
12    ecx.sess.opts.unstable_opts.offload.contains(&Offload::Device)
13}
14
15fn outer_normal_attr(
16    kind: &Box<rustc_ast::NormalAttr>,
17    id: rustc_ast::AttrId,
18    span: Span,
19) -> rustc_ast::Attribute {
20    let style = rustc_ast::AttrStyle::Outer;
21    let kind = rustc_ast::AttrKind::Normal(kind.clone());
22    rustc_ast::Attribute { kind, id, style, span }
23}
24
25fn extract_fn(
26    item: &Annotatable,
27) -> Option<(ast::Visibility, ast::FnSig, Ident, ast::Generics, Option<Box<ast::Block>>)> {
28    match item {
29        Annotatable::Item(iitem) => match &iitem.kind {
30            ast::ItemKind::Fn(ast::Fn { sig, ident, generics, body, .. }) => {
31                Some((iitem.vis.clone(), sig.clone(), *ident, generics.clone(), body.clone()))
32            }
33            _ => None,
34        },
35        _ => None,
36    }
37}
38
39/// The `offload_kernel` macro expands the function into two separate definitions:
40/// one on the host to handle the call, and one on the device for executing the kernel.
41///
42/// ```
43/// #[offload_kernel]
44/// fn foo(a: &[f32], b: &[f32], c: *mut f32) {
45///     *c = a[0] + b[0];
46/// }
47/// ```
48///
49/// This expands to the host-side function:
50///
51/// ```
52/// #[unsafe(no_mangle)]
53/// #[inline(never)]
54/// fn foo(_: &[f32], _: &[f32], _: *mut f32) {
55///     ::core::panicking::panic("not implemented")
56/// }
57/// ```
58///
59/// And the device-side kernel:
60///
61/// ```
62/// #[rustc_offload_kernel]
63/// #[unsafe(no_mangle)]
64/// unsafe extern "gpu-kernel" fn foo(a: &[f32], b: &[f32], c: *mut f32) {
65///     *c = a[0] + b[0];
66/// }
67/// ```
68pub(crate) fn expand_kernel(
69    ecx: &mut ExtCtxt<'_>,
70    expand_span: Span,
71    _meta_item: &ast::MetaItem,
72    item: Annotatable,
73) -> Vec<Annotatable> {
74    let dcx = ecx.sess.dcx();
75
76    let Some((vis, sig, ident, generics, body)) = extract_fn(&item) else {
77        dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
78        return ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
        [item]))vec![item];
79    };
80
81    let span = ecx.with_def_site_ctxt(expand_span);
82
83    // device function
84    let mut device_fn = Box::new(ast::Fn {
85        defaultness: ast::Defaultness::Implicit,
86        sig: sig.clone(),
87        ident,
88        generics: generics.clone(),
89        contract: None,
90        body,
91        define_opaque: None,
92        eii_impls: Default::default(),
93    });
94
95    let extern_gpu_kernel = ast::Extern::from_abi(
96        Some(ast::StrLit {
97            symbol: sym::gpu_kernel,
98            suffix: None,
99            symbol_unescaped: sym::gpu_kernel,
100            style: ast::StrStyle::Cooked,
101            span,
102        }),
103        span,
104    );
105    device_fn.sig.header.ext = extern_gpu_kernel;
106    device_fn.sig.header.safety = ast::Safety::Unsafe(span);
107
108    // rustc_offload_kernel attr
109    let rustc_offload_kernel_attr =
110        Box::new(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_offload_kernel)));
111    let rustc_offload_kernel = outer_normal_attr(
112        &rustc_offload_kernel_attr,
113        ecx.sess.psess.attr_id_generator.mk_attr_id(),
114        span,
115    );
116
117    // unsafe(no_mangle) attr
118    let unsafe_item = AttrItem {
119        unsafety: ast::Safety::Unsafe(span),
120        path: ast::Path::from_ident(Ident::new(sym::no_mangle, span)),
121        args: ast::AttrItemKind::Unparsed(ast::AttrArgs::Empty),
122        tokens: None,
123    };
124
125    let no_mangle_attr = Box::new(ast::NormalAttr { item: unsafe_item, tokens: None });
126    let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
127    let unsafe_no_mangle = outer_normal_attr(&no_mangle_attr, new_id, span);
128
129    let device_item = {
130        let mut item = ecx.item(
131            span,
132            {
    let len = [(), ()].len();
    let mut vec = ::thin_vec::ThinVec::with_capacity(len);
    vec.push(rustc_offload_kernel);
    vec.push(unsafe_no_mangle);
    vec
}thin_vec![rustc_offload_kernel, unsafe_no_mangle],
133            ast::ItemKind::Fn(device_fn),
134        );
135        item.vis = vis.clone();
136        Annotatable::Item(item)
137    };
138
139    // unimplemented! body
140    let macro_expr = ecx.expr_macro_call(
141        span,
142        ecx.macro_call(
143            span,
144            ecx.path_global(
145                span,
146                [sym::std, sym::unimplemented].map(|s| Ident::new(s, span)).to_vec(),
147            ),
148            Delimiter::Parenthesis,
149            TokenStream::default(),
150        ),
151    );
152    let stmt = ecx.stmt_expr(macro_expr);
153    let body = ecx.block(span, {
    let len = [()].len();
    let mut vec = ::thin_vec::ThinVec::with_capacity(len);
    vec.push(stmt);
    vec
}thin_vec![stmt]);
154
155    // host function
156    let mut host_fn = Box::new(ast::Fn {
157        defaultness: ast::Defaultness::Implicit,
158        sig: sig.clone(),
159        ident,
160        generics: generics.clone(),
161        contract: None,
162        body: Some(body),
163        define_opaque: None,
164        eii_impls: Default::default(),
165    });
166
167    for param in host_fn.sig.decl.inputs.iter_mut() {
168        param.pat = Box::new(ecx.pat_wild(param.pat.span));
169    }
170
171    // inline(never) attr
172    let ts: Vec<TokenTree> = ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
        [TokenTree::Token(Token::new(TokenKind::Ident(sym::never,
                            false.into()), span), Spacing::Joint)]))vec![TokenTree::Token(
173        Token::new(TokenKind::Ident(sym::never, false.into()), span),
174        Spacing::Joint,
175    )];
176
177    let never_arg = ast::DelimArgs {
178        dspan: DelimSpan::from_single(span),
179        delim: Delimiter::Parenthesis,
180        tokens: TokenStream::from_iter(ts),
181    };
182
183    let inline_item = ast::AttrItem {
184        unsafety: ast::Safety::Default,
185        path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)),
186        args: rustc_ast::ast::AttrItemKind::Unparsed(ast::AttrArgs::Delimited(never_arg)),
187        tokens: None,
188    };
189    let inline_never_attr = Box::new(ast::NormalAttr { item: inline_item, tokens: None });
190
191    let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
192    let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);
193
194    let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
195    let unsafe_no_mangle = outer_normal_attr(&no_mangle_attr, new_id, span);
196
197    let host_item = {
198        let mut item =
199            ecx.item(span, {
    let len = [(), ()].len();
    let mut vec = ::thin_vec::ThinVec::with_capacity(len);
    vec.push(unsafe_no_mangle);
    vec.push(inline_never);
    vec
}thin_vec![unsafe_no_mangle, inline_never], ast::ItemKind::Fn(host_fn));
200        item.vis = vis.clone();
201        Annotatable::Item(item)
202    };
203
204    if compile_for_device(ecx) { ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
        [device_item]))vec![device_item] } else { ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
        [host_item]))vec![host_item] }
205}