diff --git a/.github/workflows/enzyme-ci.yml b/.github/workflows/enzyme-ci.yml new file mode 100644 index 0000000000000..edc48879479e8 --- /dev/null +++ b/.github/workflows/enzyme-ci.yml @@ -0,0 +1,38 @@ +name: Rust CI + +on: + push: + branches: + - master + pull_request: + branches: + - master + merge_group: + +jobs: + build: + name: Rust Integration CI LLVM ${{ matrix.llvm }} ${{ matrix.build }} ${{ matrix.os }} + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + os: [openstack22] + + timeout-minutes: 600 + steps: + - name: checkout the source code + uses: actions/checkout@v3 + with: + fetch-depth: 2 + - name: build + run: | + mkdir build + cd build + ../configure --enable-llvm-link-shared --enable-llvm-plugins --enable-llvm-enzyme --release-channel=nightly --enable-llvm-assertions --enable-lld --enable-option-checking --enable-ninja --disable-docs + ../x.py build --stage 1 library/std library/proc_macro library/test tools/rustdoc + rustup toolchain link enzyme `pwd`/build/`rustup target list --installed`/stage1 + rustup toolchain install nightly # enables -Z unstable-options + - name: test + run: | + cargo +enzyme test --examples diff --git a/.gitmodules b/.gitmodules index f5025097a18dc..7e217cb215dd8 100644 --- a/.gitmodules +++ b/.gitmodules @@ -43,3 +43,6 @@ path = library/backtrace url = https://github.com/rust-lang/backtrace-rs.git shallow = true +[submodule "src/tools/enzyme"] + path = src/tools/enzyme + url = https://github.com/EnzymeAD/Enzyme.git diff --git a/Cargo.lock b/Cargo.lock index 0761268c9d411..06be5561ceec4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4312,6 +4312,7 @@ dependencies = [ "rustc_middle", "rustc_session", "rustc_span", + "rustc_symbol_mangling", "rustc_target", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 9b11ae8744b4f..ab42109434320 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,6 +61,7 @@ exclude = [ "src/tools/x", # stdarch has its own Cargo workspace "library/stdarch", + 
"library/autodiff", ] [profile.release.package.compiler_builtins] diff --git a/README.md b/README.md index a88ee4b8bf061..7b880aeb22d4a 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,68 @@ -# The Rust Programming Language +# The Rust Programming Language +Enzyme (git history) [![Rust Community](https://img.shields.io/badge/Rust_Community%20-Join_us-brightgreen?style=plastic&logo=rust)](https://www.rust-lang.org/community) This is the main source code repository for [Rust]. It contains the compiler, -standard library, and documentation. +standard library, and documentation. It is modified to use Enzyme for AutoDiff. + +Please configure this fork using the following command: + +``` +mkdir build +cd build +../configure --enable-llvm-link-shared --enable-llvm-plugins --enable-llvm-enzyme --release-channel=nightly --enable-llvm-assertions --enable-clang --enable-lld --enable-option-checking --enable-ninja --disable-docs +``` + +Afterwards you can build rustc using: +``` +../x.py build --stage 1 library/std library/proc_macro library/test tools/rustdoc +``` + +Afterwards rustc toolchain link will allow you to use it through cargo: +``` +rustup toolchain link enzyme `pwd`/build/`rustup target list --installed`/stage1 +rustup toolchain install nightly # enables -Z unstable-options +``` + +You can then look at examples in the `library/autodiff/examples/*` folder and run them with + +```bash +# rosenbrock forward iteration +cargo +enzyme run --example rosenbrock_fwd_iter --release + +# or all of them +cargo +enzyme test --examples +``` + +## Enzyme Config +To help with debugging, Enzyme can be configured using environment variables. +```bash +export ENZYME_PRINT_TA=1 +export ENZYME_PRINT_AA=1 +export ENZYME_PRINT=1 +export ENZYME_PRINT_MOD=1 +export ENZYME_PRINT_MOD_AFTER=1 +``` +The first three will print TypeAnalysis, ActivityAnalysis and the llvm-ir on a function basis, respectively. 
+The last two variables will print the whole module directly before and after Enzyme has differentiated the functions. + +When experimenting with flags, please make sure that EnzymeStrictAliasing=0 +is not changed, since it is required for Enzyme to handle enums correctly. + +## Bug reporting +Bugs are pretty much expected at this point of the development process. +In order to help us, please minimize the Rust code as far as possible. +This tool might be a nice helper: https://github.com/Nilstrieb/cargo-minimize +If you have some knowledge of LLVM-IR, we also greatly appreciate it if you could help +us by compiling your minimized Rust code to LLVM-IR and reducing it further. + +The only exception to this strategy are errors based on "Can not deduce type of X", +where reducing your example will make it harder for us to understand the origin of the bug. +In this case, please just try to inline all dependencies into a single crate or even file, +without deleting used code. + + + [Rust]: https://www.rust-lang.org/ diff --git a/compiler/rustc_ast/src/ast.rs b/compiler/rustc_ast/src/ast.rs index 146a4db200caa..62d3c0227a9c6 100644 --- a/compiler/rustc_ast/src/ast.rs +++ b/compiler/rustc_ast/src/ast.rs @@ -1592,11 +1592,13 @@ impl MacCall { } } +/// Manuel /// Arguments passed to an attribute macro. #[derive(Clone, Encodable, Decodable, Debug)] pub enum AttrArgs { /// No arguments: `#[attr]`. Empty, + /// Manuel autodiff /// Delimited arguments: `#[attr()/[]/{}]`. Delimited(DelimArgs), /// Arguments of a key-value attribute: `#[attr = "value"]`. 
diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs new file mode 100644 index 0000000000000..2a8a120d62323 --- /dev/null +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -0,0 +1,149 @@ +use super::typetree::TypeTree; +use std::str::FromStr; +use rustc_data_structures::stable_hasher::{HashStable, StableHasher}; +use crate::HashStableContext; + +#[allow(dead_code)] +#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug)] +pub enum DiffMode { + Inactive, + Source, + Forward, + Reverse, +} + +#[allow(dead_code)] +#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug)] +pub enum DiffActivity { + None, + Active, + Const, + Duplicated, + DuplicatedNoNeed, +} +fn clause_diffactivity_discriminant(value: &DiffActivity) -> usize { + match value { + DiffActivity::None => 0, + DiffActivity::Active => 1, + DiffActivity::Const => 2, + DiffActivity::Duplicated => 3, + DiffActivity::DuplicatedNoNeed => 4, + } +} +fn clause_diffmode_discriminant(value: &DiffMode) -> usize { + match value { + DiffMode::Inactive => 0, + DiffMode::Source => 1, + DiffMode::Forward => 2, + DiffMode::Reverse => 3, + } +} + + +impl<CTX: HashStableContext> HashStable<CTX> for DiffMode { + fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { + clause_diffmode_discriminant(self).hash_stable(hcx, hasher); + } +} + +impl<CTX: HashStableContext> HashStable<CTX> for DiffActivity { + fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { + clause_diffactivity_discriminant(self).hash_stable(hcx, hasher); + } +} + + +impl FromStr for DiffActivity { + type Err = (); + + fn from_str(s: &str) -> Result<DiffActivity, ()> { + match s { + "None" => Ok(DiffActivity::None), + "Active" => Ok(DiffActivity::Active), + "Const" => Ok(DiffActivity::Const), + "Duplicated" => Ok(DiffActivity::Duplicated), + "DuplicatedNoNeed" => Ok(DiffActivity::DuplicatedNoNeed), + _ => Err(()), + } + } +} + +#[allow(dead_code)] +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug)] +pub
struct AutoDiffAttrs { + pub mode: DiffMode, + pub ret_activity: DiffActivity, + pub input_activity: Vec<DiffActivity>, +} + +impl<CTX: HashStableContext> HashStable<CTX> for AutoDiffAttrs { + fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { + self.mode.hash_stable(hcx, hasher); + self.ret_activity.hash_stable(hcx, hasher); + self.input_activity.hash_stable(hcx, hasher); + } +} + +impl AutoDiffAttrs { + pub fn inactive() -> Self { + AutoDiffAttrs { + mode: DiffMode::Inactive, + ret_activity: DiffActivity::None, + input_activity: Vec::new(), + } + } + + pub fn is_active(&self) -> bool { + match self.mode { + DiffMode::Inactive => false, + _ => true, + } + } + + pub fn is_source(&self) -> bool { + match self.mode { + DiffMode::Source => true, + _ => false, + } + } + pub fn apply_autodiff(&self) -> bool { + match self.mode { + DiffMode::Inactive => false, + DiffMode::Source => false, + _ => true, + } + } + + pub fn into_item( + self, + source: String, + target: String, + inputs: Vec<TypeTree>, + output: TypeTree, + ) -> AutoDiffItem { + AutoDiffItem { source, target, inputs, output, attrs: self } + } +} + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug)] +pub struct AutoDiffItem { + pub source: String, + pub target: String, + pub attrs: AutoDiffAttrs, + pub inputs: Vec<TypeTree>, + pub output: TypeTree, +} + +impl<CTX: HashStableContext> HashStable<CTX> for AutoDiffItem { + fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { + self.source.hash_stable(hcx, hasher); + self.target.hash_stable(hcx, hasher); + self.attrs.hash_stable(hcx, hasher); + for tt in &self.inputs { + tt.0.hash_stable(hcx, hasher); + } + //self.inputs.hash_stable(hcx, hasher); + self.output.0.hash_stable(hcx, hasher); + } +} + diff --git a/compiler/rustc_ast/src/expand/mod.rs b/compiler/rustc_ast/src/expand/mod.rs index 942347383ce31..b8434374a3594 100644 --- a/compiler/rustc_ast/src/expand/mod.rs +++ b/compiler/rustc_ast/src/expand/mod.rs @@ -5,6 +5,8 @@ use rustc_span::{def_id::DefId, symbol::Ident}; use crate::MetaItem; pub mod allocator; +pub
mod typetree; +pub mod autodiff_attrs; #[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)] pub struct StrippedCfgItem { diff --git a/compiler/rustc_ast/src/expand/typetree.rs b/compiler/rustc_ast/src/expand/typetree.rs new file mode 100644 index 0000000000000..2adcc724831a7 --- /dev/null +++ b/compiler/rustc_ast/src/expand/typetree.rs @@ -0,0 +1,67 @@ +use std::fmt; +use rustc_data_structures::stable_hasher::{HashStable, StableHasher}; +use crate::HashStableContext; + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug)] +pub enum Kind { + Anything, + Integer, + Pointer, + Half, + Float, + Double, + Unknown, +} +impl<CTX: HashStableContext> HashStable<CTX> for Kind { + fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { + clause_kind_discriminant(self).hash_stable(hcx, hasher); + } +} +fn clause_kind_discriminant(value: &Kind) -> usize { + match value { + Kind::Anything => 0, + Kind::Integer => 1, + Kind::Pointer => 2, + Kind::Half => 3, + Kind::Float => 4, + Kind::Double => 5, + Kind::Unknown => 6, + } +} + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug)] +pub struct TypeTree(pub Vec<Type>); + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug)] +pub struct Type { + pub offset: isize, + pub size: usize, + pub kind: Kind, + pub child: TypeTree, +} + +impl<CTX: HashStableContext> HashStable<CTX> for Type { + fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { + self.offset.hash_stable(hcx, hasher); + self.size.hash_stable(hcx, hasher); + self.kind.hash_stable(hcx, hasher); + self.child.0.hash_stable(hcx, hasher); + } +} + +impl Type { + pub fn add_offset(self, add: isize) -> Self { + let offset = match self.offset { + -1 => add, + x => add + x, + }; + + Self { size: self.size, kind: self.kind, child: self.child, offset } + } +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + <Self as fmt::Debug>::fmt(self, f) + } +} diff --git a/compiler/rustc_builtin_macros/messages.ftl b/compiler/rustc_builtin_macros/messages.ftl
index dda466b026d91..07c9b588e1faf 100644 --- a/compiler/rustc_builtin_macros/messages.ftl +++ b/compiler/rustc_builtin_macros/messages.ftl @@ -1,6 +1,8 @@ builtin_macros_alloc_error_must_be_fn = alloc_error_handler must be a function builtin_macros_alloc_must_statics = allocators must be statics +builtin_macros_autodiff = autodiff must be applied to function + builtin_macros_asm_clobber_abi = clobber_abi builtin_macros_asm_clobber_no_reg = asm with `clobber_abi` must specify explicit registers for outputs builtin_macros_asm_clobber_outputs = generic outputs diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs new file mode 100644 index 0000000000000..2e5c57e14f9f3 --- /dev/null +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -0,0 +1,178 @@ +use crate::errors; +use crate::util::check_builtin_macro_attribute; +//use rustc_ast::expand::allocator::{ +// global_fn_name, AllocatorMethod, AllocatorMethodInput, AllocatorTy, ALLOCATOR_METHODS, +//}; +//use rustc_middle::middle::autodiff_attrs::{AutoDiffItem, DiffActivity, DiffMode}; +//use rustc_ast::ptr::P; +use rustc_ast::{self as ast};//, AttrVec, Expr, FnHeader, FnSig, Generics, Param, StmtKind}; +//use rustc_ast::{Fn, ItemKind, Mutability, Stmt, Ty, TyKind, Unsafe}; +use rustc_ast::ItemKind; +use rustc_expand::base::{Annotatable, ExtCtxt}; +//use rustc_span::symbol::{kw, sym, Ident, Symbol}; +use rustc_span::symbol::sym; +use rustc_span::Span; +//use thin_vec::{thin_vec, ThinVec}; + + +pub fn expand( + ecx: &mut ExtCtxt<'_>, + _span: Span, + meta_item: &ast::MetaItem, + item: Annotatable, +) -> Vec { + check_builtin_macro_attribute(ecx, meta_item, sym::autodiff); + let orig_item = item.clone(); + // FnSig + // inner_tokens + //ItemKind::Fn(Box) + // Allow using `#[autodiff(...)]` on a Fn + let (item, _ty_span) = if let Annotatable::Item(item) = &item + && let ItemKind::Fn(box ast::Fn { sig, .. 
}) = &item.kind + { + (item, ecx.with_def_site_ctxt(sig.span)) + } else { + ecx.sess + .parse_sess + .span_diagnostic + .emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); + return vec![orig_item]; + }; + let _span = ecx.with_def_site_ctxt(item.span); +// let f = AllocFnFactory { span, ty_span, global: item.ident, cx: ecx }; + return vec![orig_item]; +} +// +// // Generate a bunch of new items using the AllocFnFactory +// let f = AllocFnFactory { span, ty_span, global: item.ident, cx: ecx }; +// +// // Generate item statements for the allocator methods. +// let stmts = ALLOCATOR_METHODS.iter().map(|method| f.allocator_fn(method)).collect(); +// +// // Generate anonymous constant serving as container for the allocator methods. +// let const_ty = ecx.ty(ty_span, TyKind::Tup(ThinVec::new())); +// let const_body = ecx.expr_block(ecx.block(span, stmts)); +// let const_item = ecx.item_const(span, Ident::new(kw::Underscore, span), const_ty, const_body); +// let const_item = if is_stmt { +// Annotatable::Stmt(P(ecx.stmt_item(span, const_item))) +// } else { +// Annotatable::Item(const_item) +// }; +// +// // Return the original item and the new methods. 
+// vec![orig_item, const_item] +//} + +//struct AllocFnFactory<'a, 'b> { +// span: Span, +// ty_span: Span, +// global: Ident, +// cx: &'b ExtCtxt<'a>, +//} + +//impl AllocFnFactory<'_, '_> { +// fn allocator_fn(&self, method: &AllocatorMethod) -> Stmt { +// let mut abi_args = ThinVec::new(); +// let args = method.inputs.iter().map(|input| self.arg_ty(input, &mut abi_args)).collect(); +// let result = self.call_allocator(method.name, args); +// let output_ty = self.ret_ty(&method.output); +// let decl = self.cx.fn_decl(abi_args, ast::FnRetTy::Ty(output_ty)); +// let header = FnHeader { unsafety: Unsafe::Yes(self.span), ..FnHeader::default() }; +// let sig = FnSig { decl, header, span: self.span }; +// let body = Some(self.cx.block_expr(result)); +// let kind = ItemKind::Fn(Box::new(Fn { +// defaultness: ast::Defaultness::Final, +// sig, +// generics: Generics::default(), +// body, +// })); +// let item = self.cx.item( +// self.span, +// Ident::from_str_and_span(&global_fn_name(method.name), self.span), +// self.attrs(), +// kind, +// ); +// self.cx.stmt_item(self.ty_span, item) +// } +// +// fn call_allocator(&self, method: Symbol, mut args: ThinVec>) -> P { +// let method = self.cx.std_path(&[sym::alloc, sym::GlobalAlloc, method]); +// let method = self.cx.expr_path(self.cx.path(self.ty_span, method)); +// let allocator = self.cx.path_ident(self.ty_span, self.global); +// let allocator = self.cx.expr_path(allocator); +// let allocator = self.cx.expr_addr_of(self.ty_span, allocator); +// args.insert(0, allocator); +// +// self.cx.expr_call(self.ty_span, method, args) +// } +// +// fn attrs(&self) -> AttrVec { +// thin_vec![self.cx.attr_word(sym::rustc_std_internal_symbol, self.span)] +// } +// +// fn arg_ty(&self, input: &AllocatorMethodInput, args: &mut ThinVec) -> P { +// match input.ty { +// AllocatorTy::Layout => { +// // If an allocator method is ever introduced having multiple +// // Layout arguments, these argument names need to be +// // disambiguated 
somehow. Currently the generated code would +// // fail to compile with "identifier is bound more than once in +// // this parameter list". +// let size = Ident::from_str_and_span("size", self.span); +// let align = Ident::from_str_and_span("align", self.span); +// +// let usize = self.cx.path_ident(self.span, Ident::new(sym::usize, self.span)); +// let ty_usize = self.cx.ty_path(usize); +// args.push(self.cx.param(self.span, size, ty_usize.clone())); +// args.push(self.cx.param(self.span, align, ty_usize)); +// +// let layout_new = +// self.cx.std_path(&[sym::alloc, sym::Layout, sym::from_size_align_unchecked]); +// let layout_new = self.cx.expr_path(self.cx.path(self.span, layout_new)); +// let size = self.cx.expr_ident(self.span, size); +// let align = self.cx.expr_ident(self.span, align); +// let layout = self.cx.expr_call(self.span, layout_new, thin_vec![size, align]); +// layout +// } +// +// AllocatorTy::Ptr => { +// let ident = Ident::from_str_and_span(input.name, self.span); +// args.push(self.cx.param(self.span, ident, self.ptr_u8())); +// self.cx.expr_ident(self.span, ident) +// } +// +// AllocatorTy::Usize => { +// let ident = Ident::from_str_and_span(input.name, self.span); +// args.push(self.cx.param(self.span, ident, self.usize())); +// self.cx.expr_ident(self.span, ident) +// } +// +// AllocatorTy::ResultPtr | AllocatorTy::Unit => { +// panic!("can't convert AllocatorTy to an argument") +// } +// } +// } +// +// fn ret_ty(&self, ty: &AllocatorTy) -> P { +// match *ty { +// AllocatorTy::ResultPtr => self.ptr_u8(), +// +// AllocatorTy::Unit => self.cx.ty(self.span, TyKind::Tup(ThinVec::new())), +// +// AllocatorTy::Layout | AllocatorTy::Usize | AllocatorTy::Ptr => { +// panic!("can't convert `AllocatorTy` to an output") +// } +// } +// } +// +// fn usize(&self) -> P { +// let usize = self.cx.path_ident(self.span, Ident::new(sym::usize, self.span)); +// self.cx.ty_path(usize) +// } +// +// fn ptr_u8(&self) -> P { +// let u8 = 
self.cx.path_ident(self.span, Ident::new(sym::u8, self.span)); +// let ty_u8 = self.cx.ty_path(u8); +// self.cx.ty_ptr(self.span, ty_u8, Mutability::Mut) +// } +//} diff --git a/compiler/rustc_builtin_macros/src/errors.rs b/compiler/rustc_builtin_macros/src/errors.rs index fde4270334b67..75c42fcee350e 100644 --- a/compiler/rustc_builtin_macros/src/errors.rs +++ b/compiler/rustc_builtin_macros/src/errors.rs @@ -157,6 +157,13 @@ pub(crate) struct TestArgs { pub(crate) span: Span, } +#[derive(Diagnostic)] +#[diag(builtin_macros_autodiff)] +pub(crate) struct AutoDiffInvalidApplication { + #[primary_span] + pub(crate) span: Span, +} + #[derive(Diagnostic)] #[diag(builtin_macros_alloc_must_statics)] pub(crate) struct AllocMustStatics { diff --git a/compiler/rustc_builtin_macros/src/lib.rs b/compiler/rustc_builtin_macros/src/lib.rs index d84742c9b8293..e6d22709c83ff 100644 --- a/compiler/rustc_builtin_macros/src/lib.rs +++ b/compiler/rustc_builtin_macros/src/lib.rs @@ -30,6 +30,7 @@ use rustc_fluent_macro::fluent_messages; use rustc_span::symbol::sym; mod alloc_error_handler; +mod autodiff; mod assert; mod cfg; mod cfg_accessible; @@ -107,6 +108,7 @@ pub fn register_builtin_macros(resolver: &mut dyn ResolverExpand) { register_attr! 
{ alloc_error_handler: alloc_error_handler::expand, + autodiff: autodiff::expand, bench: test::expand_bench, cfg_accessible: cfg_accessible::Expander, cfg_eval: cfg_eval::expand, diff --git a/compiler/rustc_codegen_llvm/messages.ftl b/compiler/rustc_codegen_llvm/messages.ftl index c0cfe39f1e0a4..efc4abc719ade 100644 --- a/compiler/rustc_codegen_llvm/messages.ftl +++ b/compiler/rustc_codegen_llvm/messages.ftl @@ -60,6 +60,9 @@ codegen_llvm_prepare_thin_lto_module_with_llvm_err = failed to prepare thin LTO codegen_llvm_run_passes = failed to run LLVM passes codegen_llvm_run_passes_with_llvm_err = failed to run LLVM passes: {$llvm_err} +codegen_llvm_prepare_autodiff = failed to prepare AutoDiff: src: {$src}, target: {$target}, {$error} +codegen_llvm_prepare_autodiff_with_llvm_err = failed to prepare AutoDiff: {$llvm_err}, src: {$src}, target: {$target}, {$error} + codegen_llvm_sanitizer_memtag_requires_mte = `-Zsanitizer=memtag` requires `-Ctarget-feature=+mte` diff --git a/compiler/rustc_codegen_llvm/src/attributes.rs b/compiler/rustc_codegen_llvm/src/attributes.rs index b6c01545f308c..a586559016cd5 100644 --- a/compiler/rustc_codegen_llvm/src/attributes.rs +++ b/compiler/rustc_codegen_llvm/src/attributes.rs @@ -285,6 +285,7 @@ pub fn from_fn_attrs<'ll, 'tcx>( instance: ty::Instance<'tcx>, ) { let codegen_fn_attrs = cx.tcx.codegen_fn_attrs(instance.def_id()); + let autodiff_attrs = cx.tcx.autodiff_attrs(instance.def_id()); let mut to_add = SmallVec::<[_; 16]>::new(); @@ -302,6 +303,8 @@ pub fn from_fn_attrs<'ll, 'tcx>( let inline = if codegen_fn_attrs.inline == InlineAttr::None && instance.def.requires_inline(cx.tcx) { InlineAttr::Hint + } else if autodiff_attrs.is_active() { + InlineAttr::Never } else { codegen_fn_attrs.inline }; diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index 8655aeec13dd6..a1b546b1922f2 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ 
b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -734,7 +734,7 @@ pub unsafe fn optimize_thin_module( let llcx = llvm::LLVMRustContextCreate(cgcx.fewer_names); let llmod_raw = parse_module(llcx, module_name, thin_module.data(), &diag_handler)? as *const _; let mut module = ModuleCodegen { - module_llvm: ModuleLlvm { llmod_raw, llcx, tm }, + module_llvm: ModuleLlvm { llmod_raw, llcx, tm, typetrees: Default::default() }, name: thin_module.name().to_string(), kind: ModuleKind::Regular, }; diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 9d5204034def0..a40da5f5c0606 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -10,11 +10,24 @@ use crate::errors::{ WithLlvmError, WriteBytecode, }; use crate::llvm::{self, DiagnosticInfo, PassManager}; +use crate::llvm::{LLVMReplaceAllUsesWith, LLVMVerifyFunction, Value}; use crate::llvm_util; use crate::type_::Type; +use crate::typetree::to_enzyme_typetree; use crate::LlvmCodegenBackend; use crate::ModuleLlvm; -use llvm::{ +use crate::DiffTypeTree; +use llvm::{LLVMRustGetEnumAttributeAtIndex, LLVMRustAddEnumAttributeAtIndex, LLVMRustRemoveEnumAttributeAtIndex, + enzyme_rust_forward_diff, enzyme_rust_reverse_diff, BasicBlock, CreateEnzymeLogic, + CreateTypeAnalysis, EnzymeLogicRef, EnzymeTypeAnalysisRef, LLVMAddFunction, + LLVMAppendBasicBlockInContext, LLVMBuildCall2, LLVMBuildExtractValue, LLVMBuildRet, + LLVMCountParams, LLVMCountStructElementTypes, LLVMCreateBuilderInContext, LLVMDeleteFunction, + LLVMDisposeBuilder, LLVMGetBasicBlockTerminator, LLVMGetElementType, LLVMGetModuleContext, + LLVMGetParams, LLVMGetReturnType, LLVMPositionBuilderAtEnd, LLVMSetValueName2, LLVMTypeOf, + LLVMVoidTypeInContext, LLVMGlobalGetValueType, LLVMGetStringAttributeAtIndex, + LLVMIsStringAttribute, LLVMRemoveStringAttributeAtIndex, AttributeKind, + LLVMGetFirstFunction, LLVMGetNextFunction, LLVMIsEnumAttribute, + 
LLVMCreateStringAttribute, LLVMRustAddFunctionAttributes, LLVMDumpModule, LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols, }; use rustc_codegen_ssa::back::link::ensure_removed; @@ -24,10 +37,12 @@ use rustc_codegen_ssa::back::write::{ }; use rustc_codegen_ssa::traits::*; use rustc_codegen_ssa::{CompiledModule, ModuleCodegen}; +use rustc_data_structures::fx::FxHashMap; use rustc_data_structures::profiling::SelfProfilerRef; use rustc_data_structures::small_c_str::SmallCStr; use rustc_errors::{FatalError, Handler, Level}; use rustc_fs_util::{link_or_copy, path_to_c_string}; +use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity, DiffMode}; use rustc_middle::ty::TyCtxt; use rustc_session::config::{self, Lto, OutputType, Passes, SplitDwarfKind, SwitchWithOptPath}; use rustc_session::Session; @@ -37,7 +52,7 @@ use rustc_target::spec::{CodeModel, RelocModel, SanitizerSet, SplitDebuginfo}; use crate::llvm::diagnostic::OptimizationDiagnosticKind; use libc::{c_char, c_int, c_uint, c_void, size_t}; -use std::ffi::CString; +use std::ffi::{CStr, CString}; use std::fs; use std::io::{self, Write}; use std::path::{Path, PathBuf}; @@ -513,8 +528,18 @@ pub(crate) unsafe fn llvm_optimize( opt_level: config::OptLevel, opt_stage: llvm::OptStage, ) -> Result<(), FatalError> { - let unroll_loops = - opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; + // Enzyme: + // We want to simplify / optimize functions before AD. + // However, benchmarks show that optimizations increasing the code size + // tend to reduce AD performance. Therefore deactivate them first, then differentiate the code + // and finally re-optimize the module, now with all optimizations available. + // RIP compile time. 
+ // let unroll_loops = + // opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; + let unroll_loops = false; + let vectorize_slp = false; + let vectorize_loop = false; + let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed(); let pgo_gen_path = get_pgo_gen_path(config); let pgo_use_path = get_pgo_use_path(config); @@ -569,8 +594,8 @@ pub(crate) unsafe fn llvm_optimize( using_thin_buffers, config.merge_functions, unroll_loops, - config.vectorize_slp, - config.vectorize_loop, + vectorize_slp, + vectorize_loop, config.no_builtins, config.emit_lifetime_markers, sanitizer_options.as_ref(), @@ -592,6 +617,257 @@ pub(crate) unsafe fn llvm_optimize( result.into_result().map_err(|()| llvm_err(diag_handler, LlvmError::RunLlvmPasses)) } +fn get_params(fnc: &Value) -> Vec<&Value> { + unsafe { + let param_num = LLVMCountParams(fnc) as usize; + // Allocate the exact capacity up front; LLVMGetParams writes into it before set_len. + let mut fnc_args: Vec<&Value> = Vec::with_capacity(param_num); + LLVMGetParams(fnc, fnc_args.as_mut_ptr()); + fnc_args.set_len(param_num); + fnc_args + } +} + +// TODO: Here we could start adding length checks for the shadow args. 
+unsafe fn create_wrapper<'a>( + llmod: &'a llvm::Module, + fnc: &'a Value, + u_type: &Type, + fnc_name: String, +) -> (&'a Value, &'a BasicBlock, Vec<&'a Value>, Vec<&'a Value>, CString) { + let context = LLVMGetModuleContext(llmod); + let inner_fnc_name = "inner_".to_string() + &fnc_name; + let c_inner_fnc_name = CString::new(inner_fnc_name.clone()).unwrap(); + LLVMSetValueName2(fnc, c_inner_fnc_name.as_ptr(), inner_fnc_name.len() as usize); + + let c_outer_fnc_name = CString::new(fnc_name).unwrap(); + let outer_fnc: &Value = + LLVMAddFunction(llmod, c_outer_fnc_name.as_ptr(), LLVMGetElementType(u_type) as &Type); + + let entry = "fnc_entry".to_string(); + let c_entry = CString::new(entry).unwrap(); + let basic_block = LLVMAppendBasicBlockInContext(context, outer_fnc, c_entry.as_ptr()); + + let outer_params: Vec<&Value> = get_params(outer_fnc); + let inner_params: Vec<&Value> = get_params(fnc); + + (outer_fnc, basic_block, outer_params, inner_params, c_inner_fnc_name) +} + +pub(crate) unsafe fn extract_return_type<'a>( + llmod: &'a llvm::Module, + fnc: &'a Value, + u_type: &Type, + fnc_name: String, +) -> &'a Value { + let context = llvm::LLVMGetModuleContext(llmod); + + let inner_param_num = LLVMCountParams(fnc); + let (outer_fnc, outer_bb, mut outer_args, _inner_args, c_inner_fnc_name) = + create_wrapper(llmod, fnc, u_type, fnc_name); + + if inner_param_num as usize != outer_args.len() { + panic!("Args len shouldn't differ. Please report this."); + } + + let builder = LLVMCreateBuilderInContext(context); + LLVMPositionBuilderAtEnd(builder, outer_bb); + let struct_ret = LLVMBuildCall2( + builder, + u_type, + fnc, + outer_args.as_mut_ptr(), + outer_args.len(), + c_inner_fnc_name.as_ptr(), + ); + // The name below is arbitrary: it only labels the temporary that holds the extracted value. 
+ let inner_grad_name = "foo".to_string(); + let c_inner_grad_name = CString::new(inner_grad_name).unwrap(); + let struct_ret = LLVMBuildExtractValue(builder, struct_ret, 0, c_inner_grad_name.as_ptr()); + let _ret = LLVMBuildRet(builder, struct_ret); + let _terminator = LLVMGetBasicBlockTerminator(outer_bb); + LLVMDisposeBuilder(builder); + let _fnc_ok = + LLVMVerifyFunction(outer_fnc, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction); + outer_fnc +} + +// Differentiates the function named `item.source` with Enzyme and substitutes it for `item.target`. +#[allow(unused_variables)] +#[allow(unused)] +pub(crate) unsafe fn enzyme_ad( + llmod: &llvm::Module, + llcx: &llvm::Context, + diag_handler: &rustc_errors::Handler, + item: AutoDiffItem, +) -> Result<(), FatalError> { + let autodiff_mode = item.attrs.mode; + let rust_name = item.source; + let rust_name2 = &item.target; + + let args_activity = item.attrs.input_activity.clone(); + let ret_activity: DiffActivity = item.attrs.ret_activity; + + // Look up the source and target functions by name in the LLVM module. + let name = CString::new(rust_name.to_owned()).unwrap(); + let name2 = CString::new(rust_name2.clone()).unwrap(); + let src_fnc_opt = llvm::LLVMGetNamedFunction(llmod, name.as_c_str().as_ptr()); + let src_fnc = match src_fnc_opt { + Some(x) => x, + None => { + return Err(llvm_err(diag_handler, LlvmError::PrepareAutoDiff{ + src: rust_name.to_owned(), + target: rust_name2.to_owned(), + error: "could not find src function".to_owned(), + })); + } + }; + let target_fnc_opt = llvm::LLVMGetNamedFunction(llmod, name2.as_ptr()); + let target_fnc = match target_fnc_opt { + Some(x) => x, + None => { + return Err(llvm_err(diag_handler, LlvmError::PrepareAutoDiff{ + src: rust_name.to_owned(), + target: rust_name2.to_owned(), + error: "could not find target function".to_owned(), + })); + } + }; + let src_num_args = llvm::LLVMCountParams(src_fnc); + let target_num_args = llvm::LLVMCountParams(target_fnc); + assert!(src_num_args <= target_num_args); + + // Translate the item's type trees into Enzyme type trees using the module's data layout. + let llvm_data_layout = unsafe { 
llvm::LLVMGetDataLayoutStr(&*llmod) }; + let llvm_data_layout = + std::str::from_utf8(unsafe { CStr::from_ptr(llvm_data_layout) }.to_bytes()) + .expect("got a non-UTF8 data-layout from LLVM"); + + let input_tts = + item.inputs.into_iter().map(|x| to_enzyme_typetree(x, llvm_data_layout, llcx)).collect(); + let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx); + + let opt = 1; + let ret_primary_ret = false; + let diff_primary_ret = false; + let logic_ref: EnzymeLogicRef = CreateEnzymeLogic(opt as u8); + let type_analysis: EnzymeTypeAnalysisRef = + CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0); + + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymeStrictAliasing), 0); + + if std::env::var("ENZYME_PRINT_TA").is_ok() { + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintType), 1); + } + if std::env::var("ENZYME_PRINT_AA").is_ok() { + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintActivity), 1); + } + if std::env::var("ENZYME_PRINT_PERF").is_ok() { + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintPerf), 1); + } + if std::env::var("ENZYME_PRINT").is_ok() { + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrint), 1); + } + + let mut res: &Value = match item.attrs.mode { + DiffMode::Forward => enzyme_rust_forward_diff( + logic_ref, + type_analysis, + src_fnc, + args_activity, + ret_activity, + ret_primary_ret, + input_tts, + output_tt, + ), + DiffMode::Reverse => enzyme_rust_reverse_diff( + logic_ref, + type_analysis, + src_fnc, + args_activity, + ret_activity, + ret_primary_ret, + diff_primary_ret, + input_tts, + output_tt, + ), + _ => unreachable!(), + }; + let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(res)); + + let void_type = LLVMVoidTypeInContext(llcx); + if item.attrs.mode == DiffMode::Reverse && f_return_type != void_type { + let num_elem_in_ret_struct = LLVMCountStructElementTypes(f_return_type); + if num_elem_in_ret_struct == 1 { + let u_type 
= LLVMTypeOf(target_fnc); + res = extract_return_type(llmod, res, u_type, rust_name2.clone()); // TODO: check if name or name2 + } + } + LLVMSetValueName2(res, name2.as_ptr(), rust_name2.len()); + LLVMReplaceAllUsesWith(target_fnc, res); + LLVMDeleteFunction(target_fnc); + + Ok(()) +} + +pub(crate) unsafe fn differentiate( + module: &ModuleCodegen, + cgcx: &CodegenContext, + diff_items: Vec, + _typetrees: FxHashMap, + _config: &ModuleConfig, +) -> Result<(), FatalError> { + let llmod = module.module_llvm.llmod(); + let llcx = &module.module_llvm.llcx; + let diag_handler = cgcx.create_diag_handler(); + + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymeStrictAliasing), 0); + + if std::env::var("ENZYME_PRINT_MOD").is_ok() { + unsafe {LLVMDumpModule(llmod);} + } + if std::env::var("ENZYME_TT_DEPTH").is_ok() { + let depth = std::env::var("ENZYME_TT_DEPTH").unwrap(); + let depth = depth.parse::().unwrap(); + assert!(depth >= 1); + llvm::EnzymeSetCLInteger(std::ptr::addr_of_mut!(llvm::EnzymeMaxTypeDepth), depth); + } + if std::env::var("ENZYME_TT_WIDTH").is_ok() { + let width = std::env::var("ENZYME_TT_WIDTH").unwrap(); + let width = width.parse::().unwrap(); + assert!(width >= 1); + llvm::EnzymeSetCLInteger(std::ptr::addr_of_mut!(llvm::MaxTypeOffset), width); + } + + for item in diff_items { + let res = enzyme_ad(llmod, llcx, &diag_handler, item); + assert!(res.is_ok()); + } + + let mut f = LLVMGetFirstFunction(llmod); + loop { + if let Some(lf) = f { + f = LLVMGetNextFunction(lf); + let myhwattr = "enzyme_hw"; + let attr = LLVMGetStringAttributeAtIndex(lf, c_uint::MAX, myhwattr.as_ptr() as *const c_char, myhwattr.as_bytes().len() as c_uint); + if LLVMIsStringAttribute(attr) { + LLVMRemoveStringAttributeAtIndex(lf, c_uint::MAX, myhwattr.as_ptr() as *const c_char, myhwattr.as_bytes().len() as c_uint); + } else { + LLVMRustRemoveEnumAttributeAtIndex(lf, c_uint::MAX, AttributeKind::SanitizeHWAddress); + } + + + } else { + break; + } + } + if 
std::env::var("ENZYME_PRINT_MOD_AFTER").is_ok() { + unsafe {LLVMDumpModule(llmod);} + } + + Ok(()) +} + // Unsafe due to LLVM calls. pub(crate) unsafe fn optimize( cgcx: &CodegenContext, @@ -615,6 +891,27 @@ pub(crate) unsafe fn optimize( llvm::LLVMWriteBitcodeToFile(llmod, out.as_ptr()); } + { + let mut f = LLVMGetFirstFunction(llmod); + loop { + if let Some(lf) = f { + f = LLVMGetNextFunction(lf); + let myhwattr = "enzyme_hw"; + let myhwv = ""; + let prevattr = LLVMRustGetEnumAttributeAtIndex(lf, c_uint::MAX, AttributeKind::SanitizeHWAddress); + if LLVMIsEnumAttribute(prevattr) { + let attr = LLVMCreateStringAttribute(llcx, myhwattr.as_ptr() as *const c_char, myhwattr.as_bytes().len() as c_uint, myhwv.as_ptr() as *const c_char, myhwv.as_bytes().len() as c_uint); + LLVMRustAddFunctionAttributes(lf, c_uint::MAX, &attr, 1); + } else { + LLVMRustAddEnumAttributeAtIndex(llcx, lf, c_uint::MAX, AttributeKind::SanitizeHWAddress); + } + + } else { + break; + } + } + } + if let Some(opt_level) = config.opt_level { let opt_stage = match cgcx.lto { Lto::Fat => llvm::OptStage::PreLinkFatLTO, diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs index b4b2ab1e1f8a9..2d16649ce17d2 100644 --- a/compiler/rustc_codegen_llvm/src/context.rs +++ b/compiler/rustc_codegen_llvm/src/context.rs @@ -624,6 +624,10 @@ impl<'ll, 'tcx> MiscMethods<'tcx> for CodegenCx<'ll, 'tcx> { None } } + + fn create_autodiff(&self) -> Vec { + return vec![]; + } } impl<'ll> CodegenCx<'ll, '_> { diff --git a/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs b/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs index 274e0aeaaba4f..13b5f02ab6e7c 100644 --- a/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs +++ b/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs @@ -409,7 +409,7 @@ fn add_unused_functions(cx: &CodegenCx<'_, '_>) { /// All items participating in code generation together with (instrumented) /// items inlined into them. 
fn codegenned_and_inlined_items(tcx: TyCtxt<'_>) -> DefIdSet { - let (items, cgus) = tcx.collect_and_partition_mono_items(()); + let (items, _, cgus) = tcx.collect_and_partition_mono_items(()); let mut visited = DefIdSet::default(); let mut result = items.clone(); diff --git a/compiler/rustc_codegen_llvm/src/errors.rs b/compiler/rustc_codegen_llvm/src/errors.rs index 665d195790c2d..6d317a400d4be 100644 --- a/compiler/rustc_codegen_llvm/src/errors.rs +++ b/compiler/rustc_codegen_llvm/src/errors.rs @@ -172,6 +172,12 @@ pub enum LlvmError<'a> { PrepareThinLtoModule, #[diag(codegen_llvm_parse_bitcode)] ParseBitcode, + #[diag(codegen_llvm_prepare_autodiff)] + PrepareAutoDiff { + src: String, + target: String, + error: String, + } } pub(crate) struct WithLlvmError<'a>(pub LlvmError<'a>, pub String); @@ -193,6 +199,7 @@ impl IntoDiagnostic<'_, EM> for WithLlvmError<'_> { } PrepareThinLtoModule => fluent::codegen_llvm_prepare_thin_lto_module_with_llvm_err, ParseBitcode => fluent::codegen_llvm_parse_bitcode_with_llvm_err, + PrepareAutoDiff { .. 
} => fluent::codegen_llvm_prepare_autodiff_with_llvm_err, }; let mut diag = self.0.into_diagnostic(sess); diag.set_primary_message(msg_with_llvm_err); diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index 8a6a5f79b3bb9..5b1b2fcd58b3d 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -30,6 +30,7 @@ use back::owned_target_machine::OwnedTargetMachine; use back::write::{create_informational_target_machine, create_target_machine}; use errors::ParseTargetMachineConfig; +use llvm::TypeTree; pub use llvm_util::target_features; use rustc_ast::expand::allocator::AllocatorKind; use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; @@ -39,11 +40,12 @@ use rustc_codegen_ssa::back::write::{ use rustc_codegen_ssa::traits::*; use rustc_codegen_ssa::ModuleCodegen; use rustc_codegen_ssa::{CodegenResults, CompiledModule}; -use rustc_data_structures::fx::FxIndexMap; +use rustc_data_structures::fx::{FxHashMap, FxIndexMap}; use rustc_errors::{DiagnosticMessage, ErrorGuaranteed, FatalError, Handler, SubdiagnosticMessage}; use rustc_fluent_macro::fluent_messages; use rustc_metadata::EncodedMetadata; use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_middle::ty::TyCtxt; use rustc_middle::util::Providers; use rustc_session::config::{OptLevel, OutputFilenames, PrintKind, PrintRequest}; @@ -77,6 +79,7 @@ mod debuginfo; mod declare; mod errors; mod intrinsic; +mod typetree; // The following is a workaround that replaces `pub mod llvm;` and that fixes issue 53912. 
#[path = "llvm/mod.rs"] @@ -172,6 +175,8 @@ impl WriteBackendMethods for LlvmCodegenBackend { type TargetMachineError = crate::errors::LlvmError<'static>; type ThinData = back::lto::ThinData; type ThinBuffer = back::lto::ThinBuffer; + type TypeTree = DiffTypeTree; + fn print_pass_timings(&self) { unsafe { let mut size = 0; @@ -254,6 +259,20 @@ impl WriteBackendMethods for LlvmCodegenBackend { fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer) { (module.name, back::lto::ModuleBuffer::new(module.module_llvm.llmod())) } + /// Generate autodiff rules + fn autodiff( + cgcx: &CodegenContext, + module: &ModuleCodegen, + diff_fncs: Vec, + typetrees: FxHashMap, + config: &ModuleConfig, + ) -> Result<(), FatalError> { + unsafe { back::write::differentiate(module, cgcx, diff_fncs, typetrees, config) } + } + + fn typetrees(module: &mut Self::Module) -> FxHashMap { + module.typetrees.drain().collect() + } } unsafe impl Send for LlvmCodegenBackend {} // Llvm is on a per-thread basis @@ -404,12 +423,20 @@ impl CodegenBackend for LlvmCodegenBackend { } } +#[derive(Clone, Debug)] +pub struct DiffTypeTree { + pub ret_tt: TypeTree, + pub input_tt: Vec, +} + +#[allow(dead_code)] pub struct ModuleLlvm { llcx: &'static mut llvm::Context, llmod_raw: *const llvm::Module, // independent from llcx and llmod_raw, resources get disposed by drop impl tm: OwnedTargetMachine, + typetrees: FxHashMap, } unsafe impl Send for ModuleLlvm {} @@ -420,7 +447,12 @@ impl ModuleLlvm { unsafe { let llcx = llvm::LLVMRustContextCreate(tcx.sess.fewer_names()); let llmod_raw = context::create_module(tcx, llcx, mod_name) as *const _; - ModuleLlvm { llmod_raw, llcx, tm: create_target_machine(tcx, mod_name) } + ModuleLlvm { + llmod_raw, + llcx, + tm: create_target_machine(tcx, mod_name), + typetrees: Default::default(), + } } } @@ -428,7 +460,12 @@ impl ModuleLlvm { unsafe { let llcx = llvm::LLVMRustContextCreate(tcx.sess.fewer_names()); let llmod_raw = context::create_module(tcx, llcx, 
mod_name) as *const _; - ModuleLlvm { llmod_raw, llcx, tm: create_informational_target_machine(tcx.sess) } + ModuleLlvm { + llmod_raw, + llcx, + tm: create_informational_target_machine(tcx.sess), + typetrees: Default::default(), + } } } @@ -449,7 +486,7 @@ impl ModuleLlvm { } }; - Ok(ModuleLlvm { llmod_raw, llcx, tm }) + Ok(ModuleLlvm { llmod_raw, llcx, tm, typetrees: Default::default() }) } } diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index a038b3af03dd6..026d12121e535 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1,6 +1,8 @@ #![allow(non_camel_case_types)] #![allow(non_upper_case_globals)] +use rustc_ast::expand::autodiff_attrs::DiffActivity; + use super::debuginfo::{ DIArray, DIBasicType, DIBuilder, DICompositeType, DIDerivedType, DIDescriptor, DIEnumerator, DIFile, DIFlags, DIGlobalVariableExpression, DILexicalBlock, DILocation, DINameSpace, @@ -11,6 +13,8 @@ use super::debuginfo::{ use libc::{c_char, c_int, c_uint, size_t}; use libc::{c_ulonglong, c_void}; +use core::fmt; +use std::ffi::{CStr, CString}; use std::marker::PhantomData; use super::RustString; @@ -819,10 +823,197 @@ pub type SelfProfileBeforePassCallback = unsafe extern "C" fn(*mut c_void, *const c_char, *const c_char); pub type SelfProfileAfterPassCallback = unsafe extern "C" fn(*mut c_void); +#[repr(C)] +pub enum LLVMVerifierFailureAction { + LLVMAbortProcessAction, + LLVMPrintMessageAction, + LLVMReturnStatusAction, +} + +pub(crate) unsafe fn enzyme_rust_forward_diff( + logic_ref: EnzymeLogicRef, + type_analysis: EnzymeTypeAnalysisRef, + fnc: &Value, + input_diffactivity: Vec, + ret_diffactivity: DiffActivity, + mut ret_primary_ret: bool, + input_tts: Vec, + output_tt: TypeTree, +) -> &Value { + let ret_activity = cdiffe_from(ret_diffactivity); + assert!(ret_activity != CDIFFE_TYPE::DFT_OUT_DIFF); + let mut input_activity: Vec = vec![]; + for input in input_diffactivity { 
+ let act = cdiffe_from(input); + assert!(act == CDIFFE_TYPE::DFT_CONSTANT || act == CDIFFE_TYPE::DFT_DUP_ARG || act == CDIFFE_TYPE::DFT_DUP_NONEED); + input_activity.push(act); + } + + if ret_activity == CDIFFE_TYPE::DFT_DUP_ARG { + if ret_primary_ret != true { + dbg!("overwriting ret_primary_ret!"); + } + ret_primary_ret = true; + } else if ret_activity == CDIFFE_TYPE::DFT_DUP_NONEED { + if ret_primary_ret != false { + dbg!("overwriting ret_primary_ret!"); + } + ret_primary_ret = false; + } + + let mut args_tree = input_tts.iter().map(|x| x.inner).collect::>(); + //let mut args_tree = vec![TypeTree::new().inner; typetree.input_tt.len()]; + + // We don't support volatile / extern / (global?) values. + // Just because I didn't had time to test them, and it seems less urgent. + let args_uncacheable = vec![0; input_activity.len()]; + + let kv_tmp = IntList { data: std::ptr::null_mut(), size: 0 }; + + let mut known_values = vec![kv_tmp; input_activity.len()]; + + let dummy_type = CFnTypeInfo { + Arguments: args_tree.as_mut_ptr(), + Return: output_tt.inner.clone(), + KnownValues: known_values.as_mut_ptr(), + }; + + EnzymeCreateForwardDiff( + logic_ref, // Logic + std::ptr::null(), + std::ptr::null(), + fnc, + ret_activity, // LLVM function, return type + input_activity.as_ptr(), + input_activity.len(), // constant arguments + type_analysis, // type analysis struct + ret_primary_ret as u8, + CDerivativeMode::DEM_ForwardMode, // return value, dret_used, top_level which was 1 + 1, // free memory + 1, // vector mode width + Option::None, + dummy_type, // additional_arg, type info (return + args) + args_uncacheable.as_ptr(), + args_uncacheable.len(), // uncacheable arguments + std::ptr::null_mut(), // write augmented function to this + ) +} + +pub(crate) unsafe fn enzyme_rust_reverse_diff( + logic_ref: EnzymeLogicRef, + type_analysis: EnzymeTypeAnalysisRef, + fnc: &Value, + input_activity: Vec, + ret_activity: DiffActivity, + mut ret_primary_ret: bool, + diff_primary_ret: 
bool, + input_tts: Vec, + output_tt: TypeTree, +) -> &Value { + let ret_activity = cdiffe_from(ret_activity); + assert!(ret_activity == CDIFFE_TYPE::DFT_CONSTANT || ret_activity == CDIFFE_TYPE::DFT_OUT_DIFF); + let input_activity: Vec = input_activity.iter().map(|&x| cdiffe_from(x)).collect(); + + dbg!(&fnc); + + if ret_activity == CDIFFE_TYPE::DFT_DUP_ARG { + if ret_primary_ret != true { + dbg!("overwriting ret_primary_ret!"); + } + ret_primary_ret = true; + } else if ret_activity == CDIFFE_TYPE::DFT_DUP_NONEED { + if ret_primary_ret != false { + dbg!("overwriting ret_primary_ret!"); + } + ret_primary_ret = false; + } + + let mut args_tree = input_tts.iter().map(|x| x.inner).collect::>(); + + // We don't support volatile / extern / (global?) values. + // Just because I didn't had time to test them, and it seems less urgent. + let args_uncacheable = vec![0; input_tts.len()]; + assert!(args_uncacheable.len() == input_activity.len()); + let num_fnc_args = LLVMCountParams(fnc); + println!("num_fnc_args: {}", num_fnc_args); + println!("input_activity.len(): {}", input_activity.len()); + assert!(num_fnc_args == input_activity.len() as u32); + let kv_tmp = IntList { data: std::ptr::null_mut(), size: 0 }; + + + let mut known_values = vec![kv_tmp; input_tts.len()]; + + let dummy_type = CFnTypeInfo { + Arguments: args_tree.as_mut_ptr(), + Return: output_tt.inner.clone(), + KnownValues: known_values.as_mut_ptr(), + }; + + let res = EnzymeCreatePrimalAndGradient( + logic_ref, // Logic + std::ptr::null(), + std::ptr::null(), + fnc, + ret_activity, // LLVM function, return type + input_activity.as_ptr(), + input_activity.len(), // constant arguments + type_analysis, // type analysis struct + ret_primary_ret as u8, + diff_primary_ret as u8, //0 + CDerivativeMode::DEM_ReverseModeCombined, // return value, dret_used, top_level which was 1 + 1, // vector mode width + 1, // free memory + Option::None, + 0, // do not force anonymous tape + dummy_type, // additional_arg, type info 
(return + args) + args_uncacheable.as_ptr(), + args_uncacheable.len(), // uncacheable arguments + std::ptr::null_mut(), // write augmented function to this + 0, + ); + dbg!(&res); + res +} pub type GetSymbolsCallback = unsafe extern "C" fn(*mut c_void, *const c_char) -> *mut c_void; pub type GetSymbolsErrorCallback = unsafe extern "C" fn(*const c_char) -> *mut c_void; extern "C" { + + // Enzyme + pub fn LLVMIsStructTy(ty: &Type) -> bool; + pub fn LLVMGetReturnType(T: &Type) -> &Type; + pub fn LLVMDumpModule(M: &Module); + pub fn LLVMCountStructElementTypes(T: &Type) -> c_uint; + pub fn LLVMDeleteFunction(V: &Value); + + pub fn LLVMCreateEnumAttribute(C : &Context, Kind: Attribute, val:u64) -> &Attribute; + pub fn LLVMRemoveStringAttributeAtIndex(F : &Value, Idx: c_uint, K: *const c_char, KLen : c_uint); + pub fn LLVMGetStringAttributeAtIndex(F : &Value, Idx: c_uint, K: *const c_char, KLen : c_uint) -> &Attribute; + + pub fn LLVMAddAttributeAtIndex(F : &Value, Idx: c_uint, K: &Attribute); + pub fn LLVMRemoveEnumAttributeAtIndex(F : &Value, Idx: c_uint, K: Attribute); + pub fn LLVMGetEnumAttributeAtIndex(F : &Value, Idx: c_uint, K: Attribute) -> &Attribute; + + pub fn LLVMIsEnumAttribute(A : &Attribute) -> bool; + pub fn LLVMIsStringAttribute(A : &Attribute) -> bool; + pub fn LLVMVerifyFunction(V: &Value, action: LLVMVerifierFailureAction) -> bool; + pub fn LLVMGetParams(Fnc: &Value, parms: *mut &Value); + pub fn LLVMBuildCall2<'a>( + arg1: &Builder<'a>, + ty: &Type, + func: &Value, + args: *mut &Value, + num_args: size_t, + name: *const c_char, + ) -> &'a Value; + pub fn LLVMGetBasicBlockTerminator(B: &BasicBlock) -> &Value; + pub fn LLVMAddFunction<'a>(M: &Module, Name: *const c_char, Ty: &Type) -> &'a Value; + pub fn LLVMGetFirstFunction(M: &Module) -> Option<&Value>; + pub fn LLVMGetNextFunction(V: &Value) -> Option<&Value>; + pub fn LLVMGetNamedFunction(M: &Module, Name: *const c_char) -> Option<&Value>; + pub fn LLVMGlobalGetValueType(val: &Value) -> &Type; + + 
pub fn LLVMRustGetFunctionType(fnc: &Value) -> &Type; pub fn LLVMRustInstallFatalErrorHandler(); pub fn LLVMRustDisableSystemDialogsOnCrash(); @@ -1018,6 +1209,9 @@ extern "C" { // Operations on attributes pub fn LLVMRustCreateAttrNoValue(C: &Context, attr: AttributeKind) -> &Attribute; + pub fn LLVMRustAddEnumAttributeAtIndex(C: &Context, V: &Value, index: c_uint, attr: AttributeKind); + pub fn LLVMRustRemoveEnumAttributeAtIndex(V: &Value, index: c_uint, attr: AttributeKind); + pub fn LLVMRustGetEnumAttributeAtIndex(V: &Value, index: c_uint, attr: AttributeKind) ->&Attribute; pub fn LLVMCreateStringAttribute( C: &Context, Name: *const c_char, @@ -2091,6 +2285,8 @@ extern "C" { #[allow(improper_ctypes)] pub fn LLVMRustWriteTypeToString(Type: &Type, s: &RustString); #[allow(improper_ctypes)] + pub fn LLVMRustWriteValueNameToString(value_ref: &Value, s: &RustString); + #[allow(improper_ctypes)] pub fn LLVMRustWriteValueToString(value_ref: &Value, s: &RustString); pub fn LLVMIsAConstantInt(value_ref: &Value) -> Option<&ConstantInt>; @@ -2362,7 +2558,6 @@ extern "C" { remark_file: *const c_char, pgo_available: bool, ); - #[allow(improper_ctypes)] pub fn LLVMRustGetMangledName(V: &Value, out: &RustString); @@ -2382,3 +2577,301 @@ extern "C" { error_callback: GetSymbolsErrorCallback, ) -> *mut c_void; } +// Manuel +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct EnzymeOpaqueTypeAnalysis { + _unused: [u8; 0], +} +pub type EnzymeTypeAnalysisRef = *mut EnzymeOpaqueTypeAnalysis; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct EnzymeOpaqueLogic { + _unused: [u8; 0], +} +pub type EnzymeLogicRef = *mut EnzymeOpaqueLogic; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct EnzymeOpaqueAugmentedReturn { + _unused: [u8; 0], +} +pub type EnzymeAugmentedReturnPtr = *mut EnzymeOpaqueAugmentedReturn; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct IntList { + pub data: *mut i64, + pub size: size_t, +} +#[repr(u32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, 
Eq)] +pub enum CConcreteType { + DT_Anything = 0, + DT_Integer = 1, + DT_Pointer = 2, + DT_Half = 3, + DT_Float = 4, + DT_Double = 5, + DT_Unknown = 6, +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct EnzymeTypeTree { + _unused: [u8; 0], +} +pub type CTypeTreeRef = *mut EnzymeTypeTree; +extern "C" { + fn EnzymeNewTypeTree() -> CTypeTreeRef; +} +extern "C" { + fn EnzymeFreeTypeTree(CTT: CTypeTreeRef); +} +extern "C" { + pub fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8); +} +extern "C" { + pub fn EnzymeSetCLInteger(arg1: *mut ::std::os::raw::c_void, arg2: i64); +} + +extern "C" { + pub static mut MaxIntOffset: c_void; + pub static mut MaxTypeOffset: c_void; + pub static mut EnzymeMaxTypeDepth: c_void; + + pub static mut EnzymePrintPerf: c_void; + pub static mut EnzymePrintActivity: c_void; + pub static mut EnzymePrintType: c_void; + pub static mut EnzymePrint: c_void; + pub static mut EnzymeStrictAliasing: c_void; +} + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct CFnTypeInfo { + #[doc = " Types of arguments, assumed of size len(Arguments)"] + pub Arguments: *mut CTypeTreeRef, + #[doc = " Type of return"] + pub Return: CTypeTreeRef, + #[doc = " The specific constant(s) known to represented by an argument, if constant"] + pub KnownValues: *mut IntList, +} +#[repr(u32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum CDIFFE_TYPE { + DFT_OUT_DIFF = 0, + DFT_DUP_ARG = 1, + DFT_CONSTANT = 2, + DFT_DUP_NONEED = 3, +} + +fn cdiffe_from(act: DiffActivity) -> CDIFFE_TYPE { + return match act { + DiffActivity::None => CDIFFE_TYPE::DFT_CONSTANT, + DiffActivity::Active => CDIFFE_TYPE::DFT_OUT_DIFF, + DiffActivity::Const => CDIFFE_TYPE::DFT_CONSTANT, + DiffActivity::Duplicated => CDIFFE_TYPE::DFT_DUP_ARG, + DiffActivity::DuplicatedNoNeed => CDIFFE_TYPE::DFT_DUP_NONEED, + }; +} + +#[repr(u32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum CDerivativeMode { + DEM_ForwardMode = 0, + DEM_ReverseModePrimal = 1, + 
DEM_ReverseModeGradient = 2, + DEM_ReverseModeCombined = 3, + DEM_ForwardModeSplit = 4, +} +extern "C" { + fn EnzymeCreatePrimalAndGradient<'a>( + arg1: EnzymeLogicRef, + _builderCtx: *const u8, // &'a Builder<'_>, + _callerCtx: *const u8,// &'a Value, + todiff: &'a Value, + retType: CDIFFE_TYPE, + constant_args: *const CDIFFE_TYPE, + constant_args_size: size_t, + TA: EnzymeTypeAnalysisRef, + returnValue: u8, + dretUsed: u8, + mode: CDerivativeMode, + width: ::std::os::raw::c_uint, + freeMemory: u8, + additionalArg: Option<&Type>, + forceAnonymousTape: u8, + typeInfo: CFnTypeInfo, + _uncacheable_args: *const u8, + uncacheable_args_size: size_t, + augmented: EnzymeAugmentedReturnPtr, + AtomicAdd: u8, + ) -> &'a Value; + //) -> LLVMValueRef; +} +extern "C" { + fn EnzymeCreateForwardDiff<'a>( + arg1: EnzymeLogicRef, + _builderCtx: *const u8,// &'a Builder<'_>, + _callerCtx: *const u8,// &'a Value, + todiff: &'a Value, + retType: CDIFFE_TYPE, + constant_args: *const CDIFFE_TYPE, + constant_args_size: size_t, + TA: EnzymeTypeAnalysisRef, + returnValue: u8, + mode: CDerivativeMode, + freeMemory: u8, + width: ::std::os::raw::c_uint, + additionalArg: Option<&Type>, + typeInfo: CFnTypeInfo, + _uncacheable_args: *const u8, + uncacheable_args_size: size_t, + augmented: EnzymeAugmentedReturnPtr, + ) -> &'a Value; +} +pub type CustomRuleType = ::std::option::Option< + unsafe extern "C" fn( + direction: ::std::os::raw::c_int, + ret: CTypeTreeRef, + args: *mut CTypeTreeRef, + known_values: *mut IntList, + num_args: size_t, + fnc: &Value, + ta: *const ::std::os::raw::c_void, + ) -> u8, +>; +extern "C" { + pub fn CreateTypeAnalysis( + Log: EnzymeLogicRef, + customRuleNames: *mut *mut ::std::os::raw::c_char, + customRules: *mut CustomRuleType, + numRules: size_t, + ) -> EnzymeTypeAnalysisRef; +} +extern "C" { + pub fn ClearTypeAnalysis(arg1: EnzymeTypeAnalysisRef); +} +extern "C" { + pub fn FreeTypeAnalysis(arg1: EnzymeTypeAnalysisRef); +} +extern "C" { + pub fn 
CreateEnzymeLogic(PostOpt: u8) -> EnzymeLogicRef; +} +extern "C" { + pub fn ClearEnzymeLogic(arg1: EnzymeLogicRef); +} +extern "C" { + pub fn FreeEnzymeLogic(arg1: EnzymeLogicRef); +} + +extern "C" { + fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef; + fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef; + fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool; + fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64); + fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef); + fn EnzymeTypeTreeShiftIndiciesEq( + arg1: CTypeTreeRef, + data_layout: *const c_char, + offset: i64, + max_size: i64, + add_offset: u64, + ); + fn EnzymeTypeTreeToStringFree(arg1: *const c_char); + fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char; +} + +pub struct TypeTree { + pub inner: CTypeTreeRef, +} + +impl TypeTree { + pub fn new() -> TypeTree { + let inner = unsafe { EnzymeNewTypeTree() }; + + TypeTree { inner } + } + + #[must_use] + pub fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree { + let inner = unsafe { EnzymeNewTypeTreeCT(t, ctx) }; + + TypeTree { inner } + } + + #[must_use] + pub fn only(self, idx: isize) -> TypeTree { + unsafe { + EnzymeTypeTreeOnlyEq(self.inner, idx as i64); + } + self + } + + #[must_use] + pub fn data0(self) -> TypeTree { + unsafe { + EnzymeTypeTreeData0Eq(self.inner); + } + self + } + + pub fn merge(self, other: Self) -> Self { + unsafe { + EnzymeMergeTypeTree(self.inner, other.inner); + } + drop(other); + + self + } + + #[must_use] + pub fn shift(self, layout: &str, offset: isize, max_size: isize, add_offset: usize) -> Self { + let layout = CString::new(layout).unwrap(); + + unsafe { + EnzymeTypeTreeShiftIndiciesEq( + self.inner, + layout.as_ptr(), + offset as i64, + max_size as i64, + add_offset as u64, + ) + } + + self + } +} + +impl Clone for TypeTree { + fn clone(&self) -> Self { + let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) }; + TypeTree { inner } + } +} + +impl fmt::Display for 
TypeTree { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let ptr = unsafe { EnzymeTypeTreeToString(self.inner) }; + let cstr = unsafe { CStr::from_ptr(ptr) }; + match cstr.to_str() { + Ok(x) => write!(f, "{}", x)?, + Err(err) => write!(f, "could not parse: {}", err)?, + } + + // delete C string pointer + unsafe { EnzymeTypeTreeToStringFree(ptr) } + + Ok(()) + } +} + +impl fmt::Debug for TypeTree { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + ::fmt(self, f) + } +} + +impl Drop for TypeTree { + fn drop(&mut self) { + unsafe { EnzymeFreeTypeTree(self.inner) } + } +} diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs new file mode 100644 index 0000000000000..c45f5ec5e7005 --- /dev/null +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -0,0 +1,33 @@ +use crate::llvm; +use rustc_ast::expand::typetree::{Kind, TypeTree}; + +pub fn to_enzyme_typetree( + tree: TypeTree, + llvm_data_layout: &str, + llcx: &llvm::Context, +) -> llvm::TypeTree { + tree.0.iter().fold(llvm::TypeTree::new(), |obj, x| { + let scalar = match x.kind { + Kind::Integer => llvm::CConcreteType::DT_Integer, + Kind::Float => llvm::CConcreteType::DT_Float, + Kind::Double => llvm::CConcreteType::DT_Double, + Kind::Pointer => llvm::CConcreteType::DT_Pointer, + _ => panic!("Unknown kind {:?}", x.kind), + }; + + let tt = llvm::TypeTree::from_type(scalar, llcx).only(-1); + + let tt = if !x.child.0.is_empty() { + let inner_tt = to_enzyme_typetree(x.child.clone(), llvm_data_layout, llcx); + tt.merge(inner_tt.only(-1)) + } else { + tt + }; + + if x.offset != -1 { + obj.merge(tt.shift(llvm_data_layout, 0, x.size as isize, x.offset as usize)) + } else { + obj.merge(tt) + } + }) +} diff --git a/compiler/rustc_codegen_ssa/src/assert_module_sources.rs b/compiler/rustc_codegen_ssa/src/assert_module_sources.rs index 16bb7b12bd3c1..a4ba7bfe7d20b 100644 --- a/compiler/rustc_codegen_ssa/src/assert_module_sources.rs +++ 
b/compiler/rustc_codegen_ssa/src/assert_module_sources.rs @@ -46,7 +46,7 @@ pub fn assert_module_sources(tcx: TyCtxt<'_>, set_reuse: &dyn Fn(&mut CguReuseTr } let available_cgus = - tcx.collect_and_partition_mono_items(()).1.iter().map(|cgu| cgu.name()).collect(); + tcx.collect_and_partition_mono_items(()).2.iter().map(|cgu| cgu.name()).collect(); let mut ams = AssertModuleSource { tcx, diff --git a/compiler/rustc_codegen_ssa/src/back/lto.rs b/compiler/rustc_codegen_ssa/src/back/lto.rs index cb6244050df24..b8f75f71a112a 100644 --- a/compiler/rustc_codegen_ssa/src/back/lto.rs +++ b/compiler/rustc_codegen_ssa/src/back/lto.rs @@ -1,9 +1,11 @@ use super::write::CodegenContext; +use crate::back::write::ModuleConfig; use crate::traits::*; use crate::ModuleCodegen; -use rustc_data_structures::memmap::Mmap; +use rustc_data_structures::{fx::FxHashMap, memmap::Mmap}; use rustc_errors::FatalError; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use std::ffi::CString; use std::sync::Arc; @@ -76,6 +78,26 @@ impl LtoModuleCodegen { } } + /// Run autodiff on Fat LTO module + pub unsafe fn autodiff( + self, + cgcx: &CodegenContext, + diff_fncs: Vec, + typetrees: FxHashMap, + config: &ModuleConfig, + ) -> Result, FatalError> { + match &self { + LtoModuleCodegen::Fat { ref module, .. } => { + { + B::autodiff(cgcx, &module, diff_fncs, typetrees, config)?; + } + }, + _ => {}, + } + + Ok(self) + } + /// A "gauge" of how costly it is to optimize this module, used to sort /// biggest modules first. pub fn cost(&self) -> u64 { diff --git a/compiler/rustc_codegen_ssa/src/back/symbol_export.rs b/compiler/rustc_codegen_ssa/src/back/symbol_export.rs index 9cd4394108a4a..5fd525dd56e03 100644 --- a/compiler/rustc_codegen_ssa/src/back/symbol_export.rs +++ b/compiler/rustc_codegen_ssa/src/back/symbol_export.rs @@ -317,7 +317,7 @@ fn exported_symbols_provider_local( // external linkage is enough for monomorphization to be linked to. 
let need_visibility = tcx.sess.target.dynamic_linking && !tcx.sess.target.only_cdylib; - let (_, cgus) = tcx.collect_and_partition_mono_items(()); + let (_, _, cgus) = tcx.collect_and_partition_mono_items(()); for (mono_item, data) in cgus.iter().flat_map(|cgu| cgu.items().iter()) { if data.linkage != Linkage::External { diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index 3d6a212433463..649421bd3929c 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -24,6 +24,7 @@ use rustc_incremental::{ use rustc_metadata::fs::copy_to_stdout; use rustc_metadata::EncodedMetadata; use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_middle::middle::exported_symbols::SymbolExportInfo; use rustc_middle::ty::TyCtxt; use rustc_session::config::{self, CrateType, Lto, OutFileName, OutputFilenames, OutputType}; @@ -385,6 +386,8 @@ impl CodegenContext { fn generate_lto_work( cgcx: &CodegenContext, + autodiff: Vec, + typetrees: FxHashMap, needs_fat_lto: Vec>, needs_thin_lto: Vec<(String, B::ThinBuffer)>, import_only_modules: Vec<(SerializedModule, WorkProduct)>, @@ -393,15 +396,19 @@ fn generate_lto_work( if !needs_fat_lto.is_empty() { assert!(needs_thin_lto.is_empty()); - let module = + let mut module = B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise()); + if cgcx.lto == Lto::Fat { + let config = cgcx.config(ModuleKind::Regular); + module = unsafe { module.autodiff(cgcx, autodiff, typetrees, config).unwrap() }; + } // We are adding a single work item, so the cost doesn't matter. 
vec![(WorkItem::LTO(module), 0)] } else { assert!(needs_fat_lto.is_empty()); - let (lto_modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules) + let (modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules) .unwrap_or_else(|e| e.raise()); - lto_modules + modules .into_iter() .map(|module| { let cost = module.cost(); @@ -985,6 +992,8 @@ pub(crate) enum Message { work_product: WorkProduct, }, + AddAutoDiffItems(Vec), + /// The frontend has finished generating everything for all codegen units. /// Sent from the main thread. CodegenComplete, @@ -1287,6 +1296,8 @@ fn start_executing_work( let mut needs_link = Vec::new(); let mut needs_fat_lto = Vec::new(); let mut needs_thin_lto = Vec::new(); + let mut autodiff_items = Vec::new(); + let mut typetrees = FxHashMap::::default(); let mut lto_import_only_modules = Vec::new(); let mut started_lto = false; @@ -1393,9 +1404,14 @@ fn start_executing_work( let needs_thin_lto = mem::take(&mut needs_thin_lto); let import_only_modules = mem::take(&mut lto_import_only_modules); - for (work, cost) in - generate_lto_work(&cgcx, needs_fat_lto, needs_thin_lto, import_only_modules) - { + for (work, cost) in generate_lto_work( + &cgcx, + autodiff_items.clone(), + typetrees.clone(), + needs_fat_lto, + needs_thin_lto, + import_only_modules, + ) { let insertion_index = work_items .binary_search_by_key(&cost, |&(_, cost)| cost) .unwrap_or_else(|e| e); @@ -1508,7 +1524,16 @@ fn start_executing_work( } } - Message::CodegenDone { llvm_work_item, cost } => { + Message::CodegenDone { mut llvm_work_item, cost } => { + //// extract build typetrees + match &mut llvm_work_item { + WorkItem::Optimize(module) => { + let tt = B::typetrees(&mut module.module_llvm); + typetrees.extend(tt); + } + _ => {}, + } + // We keep the queue sorted by estimated processing cost, // so that more expensive items are processed earlier. 
This // is good for throughput as it gives the main thread more @@ -1549,6 +1574,10 @@ fn start_executing_work( codegen_state = Aborted; } + Message::AddAutoDiffItems(mut items) => { + autodiff_items.append(&mut items); + } + Message::WorkItem { result, worker_id } => { free_worker(worker_id); @@ -2000,6 +2029,10 @@ impl OngoingCodegen { drop(self.coordinator.sender.send(Box::new(Message::CodegenComplete::))); } + pub fn submit_autodiff_items(&self, items: Vec) { + drop(self.coordinator.sender.send(Box::new(Message::::AddAutoDiffItems(items)))); + } + pub fn check_for_errors(&self, sess: &Session) { self.shared_emitter_main.check(sess, false); } diff --git a/compiler/rustc_codegen_ssa/src/base.rs b/compiler/rustc_codegen_ssa/src/base.rs index 198e5696357af..a8641ba9fbb30 100644 --- a/compiler/rustc_codegen_ssa/src/base.rs +++ b/compiler/rustc_codegen_ssa/src/base.rs @@ -590,7 +590,8 @@ pub fn codegen_crate( // Run the monomorphization collector and partition the collected items into // codegen units. - let codegen_units = tcx.collect_and_partition_mono_items(()).1; + let (_, autodiff_fncs, codegen_units) = tcx.collect_and_partition_mono_items(()); + let autodiff_fncs = autodiff_fncs.to_vec(); // Force all codegen_unit queries so they are already either red or green // when compile_codegen_unit accesses them. We are not able to re-execute @@ -659,6 +660,10 @@ pub fn codegen_crate( ); } + if !autodiff_fncs.is_empty() { + ongoing_codegen.submit_autodiff_items(autodiff_fncs); + } + // For better throughput during parallel processing by LLVM, we used to sort // CGUs largest to smallest. 
This would lead to better thread utilization // by, for example, preventing a large CGU from being processed last and @@ -982,7 +987,7 @@ pub fn provide(providers: &mut Providers) { config::OptLevel::SizeMin => config::OptLevel::Default, }; - let (defids, _) = tcx.collect_and_partition_mono_items(cratenum); + let (defids, _, _) = tcx.collect_and_partition_mono_items(cratenum); let any_for_speed = defids.items().any(|id| { let CodegenFnAttrs { optimize, .. } = tcx.codegen_fn_attrs(*id); diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index 2e0840f2d1bc3..231ea4c8211ea 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -1,10 +1,11 @@ -use rustc_ast::{ast, attr, MetaItemKind, NestedMetaItem}; +use rustc_ast::{ast, attr, MetaItem, MetaItemKind, NestedMetaItem}; use rustc_attr::{list_contains_name, InlineAttr, InstructionSetAttr, OptimizeAttr}; use rustc_errors::struct_span_err; use rustc_hir as hir; use rustc_hir::def::DefKind; use rustc_hir::def_id::{DefId, LocalDefId, LOCAL_CRATE}; use rustc_hir::{lang_items, weak_lang_items::WEAK_LANG_ITEMS, LangItem}; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs}; use rustc_middle::mir::mono::Linkage; use rustc_middle::query::Providers; @@ -13,6 +14,7 @@ use rustc_session::{lint, parse::feature_err}; use rustc_span::symbol::Ident; use rustc_span::{sym, Span}; use rustc_target::spec::{abi, SanitizerSet}; +use std::str::FromStr; use crate::errors; use crate::target_features::from_target_feature; @@ -697,6 +699,175 @@ fn check_link_name_xor_ordinal( } } +/// Parses the `#[autodiff_into(...)]` attribute of `id` into its [`AutoDiffAttrs`]. +/// +/// Without the attribute, or when it is malformed (an error is emitted first), the inactive +/// attributes are returned. An attribute without arguments marks a source function; otherwise +/// the arguments are read as `mode, return activity, input activities...`. +fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { + let attrs = tcx.get_attrs(id, sym::autodiff_into); + + let attrs = attrs + .into_iter() + .filter(|attr| attr.name_or_empty() == sym::autodiff_into) + .collect::<Vec<_>>(); + + //
check for exactly one autodiff attribute on extern block + let msg_once = "autodiff attribute can only be applied once"; + let attr = match &attrs[..] { + &[] => return AutoDiffAttrs::inactive(), + &[elm] => elm, + x => { + tcx.sess + .struct_span_err(x[1].span, msg_once) + .span_label(x[1].span, "more than one") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + let list = attr.meta_item_list().unwrap_or_default(); + + // empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions + if list.is_empty() { + return AutoDiffAttrs { + mode: DiffMode::Source, + ret_activity: DiffActivity::None, + input_activity: Vec::new(), + }; + } + + let msg_ad_mode = "autodiff attribute must contain autodiff mode"; + let mode = match &list[0] { + NestedMetaItem::MetaItem(MetaItem { path: ref p2, kind: MetaItemKind::Word, .. }) => { + p2.segments.first().unwrap().ident + } + _ => { + tcx.sess + .struct_span_err(attr.span, msg_ad_mode) + .span_label(attr.span, "mode must be a plain identifier") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + // parse mode + let msg_mode = "mode should be either `Forward` or `Reverse`"; + let mode = match mode.as_str() { + "Forward" => DiffMode::Forward, + "Reverse" => DiffMode::Reverse, + _ => { + tcx.sess + .struct_span_err(attr.span, msg_mode) + .span_label(attr.span, "invalid mode") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + let msg_ret_activity = "autodiff attribute must contain the return activity"; + // `get` rather than indexing: an attribute naming only the mode must be diagnosed, not panic + let ret_symbol = match list.get(1) { + Some(NestedMetaItem::MetaItem(MetaItem { path: ref p2, kind: MetaItemKind::Word, ..
})) => { + p2.segments.first().unwrap().ident + } + _ => { + tcx.sess + .struct_span_err(attr.span, msg_ret_activity) + .span_label(attr.span, "missing return activity") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + let msg_unknown_ret_activity = "unknown return activity"; + let ret_activity = match DiffActivity::from_str(ret_symbol.as_str()) { + Ok(x) => x, + Err(_) => { + tcx.sess + .struct_span_err(attr.span, msg_unknown_ret_activity) + .span_label(attr.span, "invalid return activity") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + let msg_arg_activity = "autodiff attribute must contain valid input activities"; + let msg_unknown_arg_activity = "unknown input activity"; + let mut arg_activities: Vec<DiffActivity> = vec![]; + for arg in &list[2..] { + let arg_symbol = match arg { + NestedMetaItem::MetaItem(MetaItem { + path: ref p2, kind: MetaItemKind::Word, .. + }) => p2.segments.first().unwrap().ident, + _ => { + tcx.sess + .struct_span_err( + attr.span, msg_arg_activity, + ) + .span_label(attr.span, "invalid input activity") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + match DiffActivity::from_str(arg_symbol.as_str()) { + Ok(arg_activity) => arg_activities.push(arg_activity), + Err(_) => { + tcx.sess + .struct_span_err(attr.span, msg_unknown_arg_activity) + .span_label(attr.span, "invalid input activity") + .emit(); + + return AutoDiffAttrs::inactive(); + } + } + } + + let msg_fwd_incompatible_ret = "Forward Mode is incompatible with Active ret"; + let msg_fwd_incompatible_arg = "Forward Mode is incompatible with Active arg"; + let msg_rev_incompatible_ret = "Reverse Mode is only compatible with Active, None, or Const ret"; + if mode == DiffMode::Forward { + if ret_activity == DiffActivity::Active { + tcx.sess + .struct_span_err(attr.span, msg_fwd_incompatible_ret) + .span_label(attr.span, "invalid return activity") + .emit(); + return AutoDiffAttrs::inactive(); + } + if arg_activities.contains(&DiffActivity::Active) { + tcx.sess + .struct_span_err(attr.span,
msg_fwd_incompatible_arg) + .span_label(attr.span, "invalid input activity") + .emit(); + return AutoDiffAttrs::inactive(); + } + } + + if mode == DiffMode::Reverse { + if ret_activity == DiffActivity::Duplicated + || ret_activity == DiffActivity::DuplicatedNoNeed + { + tcx.sess + .struct_span_err( + attr.span, msg_rev_incompatible_ret, + ) + .span_label(attr.span, "invalid return activity") + .emit(); + return AutoDiffAttrs::inactive(); + } + } + + AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities } +} + pub fn provide(providers: &mut Providers) { - *providers = Providers { codegen_fn_attrs, should_inherit_track_caller, ..*providers }; + *providers = + Providers { codegen_fn_attrs, should_inherit_track_caller, autodiff_attrs, ..*providers }; } diff --git a/compiler/rustc_codegen_ssa/src/traits/misc.rs b/compiler/rustc_codegen_ssa/src/traits/misc.rs index 04e2b8796c46a..5f64dd3367661 100644 --- a/compiler/rustc_codegen_ssa/src/traits/misc.rs +++ b/compiler/rustc_codegen_ssa/src/traits/misc.rs @@ -19,4 +19,5 @@ pub trait MiscMethods<'tcx>: BackendTypes { fn apply_target_cpu_attr(&self, llfn: Self::Function); /// Declares the extern "C" main function for the entry point. Returns None if the symbol already exists.
fn declare_c_main(&self, fn_type: Self::Type) -> Option; + fn create_autodiff(&self) -> Vec; } diff --git a/compiler/rustc_codegen_ssa/src/traits/write.rs b/compiler/rustc_codegen_ssa/src/traits/write.rs index ecf5095d8a335..05995062435f5 100644 --- a/compiler/rustc_codegen_ssa/src/traits/write.rs +++ b/compiler/rustc_codegen_ssa/src/traits/write.rs @@ -2,8 +2,10 @@ use crate::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; use crate::back::write::{CodegenContext, FatLtoInput, ModuleConfig}; use crate::{CompiledModule, ModuleCodegen}; +use rustc_data_structures::fx::FxHashMap; use rustc_errors::{FatalError, Handler}; use rustc_middle::dep_graph::WorkProduct; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; pub trait WriteBackendMethods: 'static + Sized + Clone { type Module: Send + Sync; @@ -12,6 +14,7 @@ pub trait WriteBackendMethods: 'static + Sized + Clone { type ModuleBuffer: ModuleBufferMethods; type ThinData: Send + Sync; type ThinBuffer: ThinBufferMethods; + type TypeTree: Clone; /// Merge all modules into main_module and returning it fn run_link( @@ -58,6 +61,15 @@ pub trait WriteBackendMethods: 'static + Sized + Clone { ) -> Result; fn prepare_thin(module: ModuleCodegen) -> (String, Self::ThinBuffer); fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer); + /// Generate autodiff rules + fn autodiff( + cgcx: &CodegenContext, + module: &ModuleCodegen, + diff_fncs: Vec, + typetrees: FxHashMap, + config: &ModuleConfig, + ) -> Result<(), FatalError>; + fn typetrees(module: &mut Self::Module) -> FxHashMap; } pub trait ThinBufferMethods: Send + Sync { diff --git a/compiler/rustc_expand/src/expand.rs b/compiler/rustc_expand/src/expand.rs index f87f4aba2b9ea..93282c8b2b571 100644 --- a/compiler/rustc_expand/src/expand.rs +++ b/compiler/rustc_expand/src/expand.rs @@ -406,6 +406,7 @@ impl<'a, 'b> MacroExpander<'a, 'b> { } /// Recursively expand all macro invocations in this AST fragment. 
+ /// Manuel: Add autodiff pub fn fully_expand_fragment(&mut self, input_fragment: AstFragment) -> AstFragment { let orig_expansion_data = self.cx.current_expansion.clone(); let orig_force_mode = self.cx.force_mode; diff --git a/compiler/rustc_feature/src/builtin_attrs.rs b/compiler/rustc_feature/src/builtin_attrs.rs index e808e4815fe0b..c0c883915b3c7 100644 --- a/compiler/rustc_feature/src/builtin_attrs.rs +++ b/compiler/rustc_feature/src/builtin_attrs.rs @@ -353,6 +353,18 @@ pub const BUILTIN_ATTRIBUTES: &[BuiltinAttribute] = &[ ungated!(used, Normal, template!(Word, List: "compiler|linker"), WarnFollowing, @only_local: true), ungated!(link_ordinal, Normal, template!(List: "ordinal"), ErrorPreceding), + // Autodiff + ungated!( + autodiff_into, Normal, + template!(Word, List: r#""...""#), + DuplicatesOk, + ), + ungated!( + autodiff, Normal, + template!(List: r#""...""#), + DuplicatesOk, + ), + // Limits: ungated!(recursion_limit, CrateLevel, template!(NameValueStr: "N"), FutureWarnFollowing), ungated!(type_length_limit, CrateLevel, template!(NameValueStr: "N"), FutureWarnFollowing), diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 4390486b0deb1..f357127794438 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -96,6 +96,11 @@ extern "C" char *LLVMRustGetLastError(void) { return Ret; } +extern "C" LLVMTypeRef LLVMRustGetFunctionType(LLVMValueRef Fn) { + auto Ftype = unwrap(Fn)->getFunctionType(); + return wrap(Ftype); +} + extern "C" void LLVMRustSetLastError(const char *Err) { free((void *)LastError); LastError = strdup(Err); @@ -314,6 +319,19 @@ extern "C" LLVMAttributeRef LLVMRustCreateAttrNoValue(LLVMContextRef C, return wrap(Attribute::get(*unwrap(C), fromRust(RustAttr))); } +extern "C" void LLVMRustRemoveEnumAttributeAtIndex(LLVMValueRef F, size_t index, LLVMRustAttribute RustAttr) { + LLVMRemoveEnumAttributeAtIndex(F, 
index, fromRust(RustAttr)); +} + +extern "C" void LLVMRustAddEnumAttributeAtIndex(LLVMContextRef C, LLVMValueRef F, size_t index, LLVMRustAttribute RustAttr) { + LLVMAddAttributeAtIndex(F, index, LLVMRustCreateAttrNoValue(C, RustAttr)); +} + +extern "C" LLVMAttributeRef LLVMRustGetEnumAttributeAtIndex(LLVMValueRef F, size_t index, + LLVMRustAttribute RustAttr) { + return LLVMGetEnumAttributeAtIndex(F, index, fromRust(RustAttr)); +} + extern "C" LLVMAttributeRef LLVMRustCreateAlignmentAttr(LLVMContextRef C, uint64_t Bytes) { return wrap(Attribute::getWithAlignment(*unwrap(C), llvm::Align(Bytes))); diff --git a/compiler/rustc_middle/src/arena.rs b/compiler/rustc_middle/src/arena.rs index 1d573a746b918..f341a645e7d81 100644 --- a/compiler/rustc_middle/src/arena.rs +++ b/compiler/rustc_middle/src/arena.rs @@ -97,6 +97,7 @@ macro_rules! arena_types { [] upvars_mentioned: rustc_data_structures::fx::FxIndexMap, [] object_safety_violations: rustc_middle::traits::ObjectSafetyViolation, [] codegen_unit: rustc_middle::mir::mono::CodegenUnit<'tcx>, + [] autodiff_item: rustc_ast::expand::autodiff_attrs::AutoDiffItem, [decode] attribute: rustc_ast::Attribute, [] name_set: rustc_data_structures::unord::UnordSet, [] ordered_name_set: rustc_data_structures::fx::FxIndexSet, diff --git a/compiler/rustc_middle/src/query/erase.rs b/compiler/rustc_middle/src/query/erase.rs index e20e9d9312c1b..7f33139d67b8d 100644 --- a/compiler/rustc_middle/src/query/erase.rs +++ b/compiler/rustc_middle/src/query/erase.rs @@ -190,6 +190,10 @@ impl EraseType for (&'_ T0, &'_ [T1]) { type Result = [u8; size_of::<(&'static (), &'static [()])>()]; } +impl EraseType for (&'_ T0, &'_ [T1], &'_ [T2]) { + type Result = [u8; size_of::<(&'static (), &'static [()], &'static [()])>()]; +} + macro_rules! trivial { ($($ty:ty),+ $(,)?) 
=> { $( diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs index 062b03e71fdc1..b16cfaf0e01a6 100644 --- a/compiler/rustc_middle/src/query/mod.rs +++ b/compiler/rustc_middle/src/query/mod.rs @@ -10,6 +10,7 @@ use crate::dep_graph; use crate::infer::canonical::{self, Canonical}; use crate::lint::LintExpectation; use crate::metadata::ModChild; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem}; use crate::middle::codegen_fn_attrs::CodegenFnAttrs; use crate::middle::debugger_visualizer::DebuggerVisualizerFile; use crate::middle::exported_symbols::{ExportedSymbol, SymbolExportInfo}; @@ -1229,6 +1230,13 @@ rustc_queries! { separate_provide_extern } + /// The list autodiff extern functions in current crate + query autodiff_attrs(def_id: DefId) -> &'tcx AutoDiffAttrs { + desc { |tcx| "computing autodiff attributes of `{}`", tcx.def_path_str(def_id) } + arena_cache + cache_on_disk_if { def_id.is_local() } + } + query asm_target_features(def_id: DefId) -> &'tcx FxIndexSet { desc { |tcx| "computing target features for inline asm of `{}`", tcx.def_path_str(def_id) } } @@ -1878,7 +1886,7 @@ rustc_queries! { separate_provide_extern } - query collect_and_partition_mono_items(_: ()) -> (&'tcx DefIdSet, &'tcx [CodegenUnit<'tcx>]) { + query collect_and_partition_mono_items(_: ()) -> (&'tcx DefIdSet, &'tcx [AutoDiffItem], &'tcx [CodegenUnit<'tcx>]) { eval_always desc { "collect_and_partition_mono_items" } } diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 739d4fa886ec3..31d60e97cded7 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -176,6 +176,8 @@ pub struct ResolverGlobalCtxt { /// Mapping from ident span to path span for paths that don't exist as written, but that /// exist under `std`. For example, wrote `str::from_utf8` instead of `std::str::from_utf8`. 
pub confused_type_with_std_module: FxHashMap, + /// Mapping of autodiff function IDs + pub autodiff_map: FxHashMap, pub doc_link_resolutions: FxHashMap, pub doc_link_traits_in_scope: FxHashMap>, pub all_macro_rules: FxHashMap>, diff --git a/compiler/rustc_monomorphize/Cargo.toml b/compiler/rustc_monomorphize/Cargo.toml index fe097424e8ad4..e2758cc330d8d 100644 --- a/compiler/rustc_monomorphize/Cargo.toml +++ b/compiler/rustc_monomorphize/Cargo.toml @@ -9,6 +9,7 @@ edition = "2021" serde = "1" serde_json = "1" tracing = "0.1" +rustc_ast = { path = "../rustc_ast" } rustc_data_structures = { path = "../rustc_data_structures" } rustc_errors = { path = "../rustc_errors" } rustc_hir = { path = "../rustc_hir" } @@ -18,3 +19,4 @@ rustc_middle = { path = "../rustc_middle" } rustc_session = { path = "../rustc_session" } rustc_span = { path = "../rustc_span" } rustc_target = { path = "../rustc_target" } +rustc_symbol_mangling = { path = "../rustc_symbol_mangling" } diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index 82fee7c8dfe58..db3d3fc88ac05 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ b/compiler/rustc_monomorphize/src/collector.rs @@ -204,7 +204,7 @@ pub enum MonoItemCollectionMode { pub struct UsageMap<'tcx> { // Maps every mono item to the mono items used by it. - used_map: FxHashMap, Vec>>, + pub used_map: FxHashMap, Vec>>, // Maps every mono item to the mono items that use it. user_map: FxHashMap, Vec>>, @@ -1244,6 +1244,7 @@ impl<'v> RootCollector<'_, 'v> { /// monomorphized copy of the start lang item based on /// the return type of `main`. This is not needed when /// the user writes their own `start` manually. + /// TODO: remove annotations after automatic differentation pass fn push_extra_entry_roots(&mut self) { let Some((main_def_id, EntryFnType::Main { .. 
})) = self.entry_fn else { return; diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index 4009e28924068..e30eec6787f06 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -92,6 +92,10 @@ //! source-level module, functions from the same module will be available for //! inlining, even when they are not marked `#[inline]`. +use rustc_symbol_mangling::symbol_name_for_instance_in_crate; +use rustc_ast::expand::typetree::{Kind, Type, TypeTree}; +use rustc_target::abi::FieldsShape; + use std::cmp; use std::collections::hash_map::Entry; use std::fs::{self, File}; @@ -103,6 +107,7 @@ use rustc_data_structures::sync; use rustc_hir::def::DefKind; use rustc_hir::def_id::{DefId, DefIdSet, LOCAL_CRATE}; use rustc_hir::definitions::DefPathDataName; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrFlags; use rustc_middle::middle::exported_symbols::{SymbolExportInfo, SymbolExportLevel}; use rustc_middle::mir::mono::{ @@ -111,7 +116,7 @@ use rustc_middle::mir::mono::{ }; use rustc_middle::query::Providers; use rustc_middle::ty::print::{characteristic_def_id_of_type, with_no_trimmed_paths}; -use rustc_middle::ty::{self, visit::TypeVisitableExt, InstanceDef, TyCtxt}; +use rustc_middle::ty::{self, visit::TypeVisitableExt, InstanceDef, TyCtxt, ParamEnv, ParamEnvAnd, Adt, Ty}; use rustc_session::config::{DumpMonoStatsFormat, SwitchWithOptPath}; use rustc_session::CodegenUnits; use rustc_span::symbol::Symbol; @@ -242,7 +247,14 @@ where &mut can_be_internalized, export_generics, ); - if visibility == Visibility::Hidden && can_be_internalized { + //if visibility == Visibility::Hidden && can_be_internalized { + + //dbg!(&characteristic_def_id); + let autodiff_active = characteristic_def_id + .map(|x| cx.tcx.autodiff_attrs(x).is_active()) + .unwrap_or(false); + + if !autodiff_active && visibility == 
Visibility::Hidden && can_be_internalized { internalization_candidates.insert(mono_item); } let size_estimate = mono_item.size_estimate(cx.tcx); @@ -1078,7 +1090,7 @@ where } } -fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[CodegenUnit<'_>]) { +fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[AutoDiffItem], &[CodegenUnit<'_>]) { let collection_mode = match tcx.sess.opts.unstable_opts.print_mono_items { Some(ref s) => { let mode = s.to_lowercase(); @@ -1137,6 +1149,68 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Co }) .collect(); + let autodiff_items = items + .iter() + .filter_map(|item| match *item { + MonoItem::Fn(ref instance) => Some((item, instance)), + _ => None, + }) + .filter_map(|(item, instance)| { + let target_id = instance.def_id(); + let target_attrs = tcx.autodiff_attrs(target_id); + if !target_attrs.apply_autodiff() { + return None; + } + //println!("target_id: {:?}", target_id); + + let target_symbol = + symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE); + //let range = usage_map.used_map.get(&item).unwrap(); + //TODO: check if last and next line are correct after rebasing + + println!("target_symbol: {:?}", target_symbol); + println!("target_attrs: {:?}", target_attrs); + println!("target_id: {:?}", target_id); + //print item + println!("item: {:?}", item); + let source = usage_map.used_map.get(&item).unwrap() + .into_iter() + .filter_map(|item| match *item { + MonoItem::Fn(ref instance_s) => { + let source_id = instance_s.def_id(); + println!("source_id_inner: {:?}", source_id); + println!("instance_s: {:?}", instance_s); + + if tcx.autodiff_attrs(source_id).is_active() { + println!("source_id is active"); + return Some(instance_s); + } + //target_symbol: "_ZN14rosenbrock_rev12d_rosenbrock17h3352c4f00c3082daE" + //target_attrs: AutoDiffAttrs { mode: Reverse, ret_activity: Active, input_activity: [Duplicated] } + //target_id: 
DefId(0:8 ~ rosenbrock_rev[2708]::d_rosenbrock) + //item: Fn(Instance { def: Item(DefId(0:8 ~ rosenbrock_rev[2708]::d_rosenbrock)), args: [] }) + //source_id_inner: DefId(0:4 ~ rosenbrock_rev[2708]::main) + //instance_s: Instance { def: Item(DefId(0:4 ~ rosenbrock_rev[2708]::main)), args: [] } + + + None + } + _ => None, + }) + .next(); + println!("source: {:?}", source); + + source.map(|inst| { + println!("source_id: {:?}", inst.def_id()); + let (inputs, output) = fnc_typetrees(inst.ty(tcx, ParamEnv::empty()), tcx); + let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE); + + target_attrs.clone().into_item(symb, target_symbol, inputs, output) + }) + }); + + let autodiff_items = tcx.arena.alloc_from_iter(autodiff_items); + // Output monomorphization stats per def_id if let SwitchWithOptPath::Enabled(ref path) = tcx.sess.opts.unstable_opts.dump_mono_stats { if let Err(err) = @@ -1197,7 +1271,145 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Co } } - (tcx.arena.alloc(mono_items), codegen_units) + (tcx.arena.alloc(mono_items), autodiff_items, codegen_units) +} + +pub fn typetree_empty() -> TypeTree { + TypeTree(vec![]) +} + +pub fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize) -> TypeTree { + if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() { + if ty.is_fn_ptr() { + unimplemented!("what to do whith fn ptr?"); + } + let inner_ty = ty.builtin_deref(true).unwrap().ty; + let child = typetree_from_ty(inner_ty, tcx, depth + 1); + let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child }; + return TypeTree(vec![tt]); + } + + if ty.is_scalar() { + assert!(!ty.is_any_ptr()); + let (kind, size) = if ty.is_integral() { + (Kind::Integer, 8) + } else { + assert!(ty.is_floating_point()); + match ty { + x if x == tcx.types.f32 => (Kind::Float, 4), + x if x == tcx.types.f64 => (Kind::Double, 8), + _ => panic!("floatTy scalar that is neither f32 nor f64"), + } + }; + return TypeTree(vec![Type { 
offset: -1, child: typetree_empty(), kind, size }]); + } + + let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: ty }; + + let layout = tcx.layout_of(param_env_and); + assert!(layout.is_ok()); + + let layout = layout.unwrap().layout; + let fields = layout.fields(); + let max_size = layout.size(); + + if ty.is_adt() { + let adt_def = ty.ty_adt_def().unwrap(); + let substs = match ty.kind() { + Adt(_, subst_ref) => subst_ref, + _ => panic!(""), + }; + + if adt_def.is_struct() { + let (offsets, _memory_index) = match fields { + FieldsShape::Arbitrary { offsets: o, memory_index: m } => (o, m), + _ => panic!(""), + }; + + let fields = adt_def.all_fields(); + let fields = fields + .into_iter() + .zip(offsets.into_iter()) + .filter_map(|(field, offset)| { + let field_ty: Ty<'_> = field.ty(tcx, substs); + let field_ty: Ty<'_> = + tcx.normalize_erasing_regions(ParamEnv::empty(), field_ty); + + if field_ty.is_phantom_data() { + return None; + } + + let mut child = typetree_from_ty(field_ty, tcx, depth + 1).0; + + for c in &mut child { + if c.offset == -1 { + c.offset = offset.bytes() as isize + } else { + c.offset += offset.bytes() as isize; + } + } + + Some(child) + }) + .flatten() + .collect::>(); + + let ret_tt = TypeTree(fields); + return ret_tt; + } else { + unimplemented!("adt that isn't a struct"); + } + } + + if ty.is_array() { + let (stride, count) = match fields { + FieldsShape::Array { stride: s, count: c } => (s, c), + _ => panic!(""), + }; + let byte_stride = stride.bytes_usize(); + let byte_max_size = max_size.bytes_usize(); + + assert!(byte_stride * *count as usize == byte_max_size); + assert!(*count > 0); // return empty TT for empty? 
+ let sub_ty = ty.builtin_index().unwrap(); + let subtt = typetree_from_ty(sub_ty, tcx, depth + 1); + + // calculate size of subtree + let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: sub_ty }; + let size = tcx.layout_of(param_env_and).unwrap().size.bytes() as usize; + let tt = TypeTree( + std::iter::repeat(subtt) + .take(*count as usize) + .enumerate() + .map(|(idx, x)| x.0.into_iter().map(move |x| x.add_offset((idx * size) as isize))) + .flatten() + .collect(), + ); + + return tt; + } + + if ty.is_slice() { + let sub_ty = ty.builtin_index().unwrap(); + let subtt = typetree_from_ty(sub_ty, tcx, depth + 1); + + return subtt; + } + + typetree_empty() +} + +pub fn fnc_typetrees<'tcx>(fn_ty: Ty<'tcx>, tcx: TyCtxt<'tcx>) -> (Vec, TypeTree) { + let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx); + + // TODO: verify. + let x: ty::FnSig<'_> = fnc_binder.skip_binder(); + + let inputs = x.inputs().into_iter().map(|x| typetree_from_ty(*x, tcx, 0)).collect(); + + let output = typetree_from_ty(x.output(), tcx, 0); + + (inputs, output) } /// Outputs stats about instantiation counts and estimated size, per `MonoItem`'s @@ -1282,12 +1494,12 @@ pub fn provide(providers: &mut Providers) { providers.collect_and_partition_mono_items = collect_and_partition_mono_items; providers.is_codegened_item = |tcx, def_id| { - let (all_mono_items, _) = tcx.collect_and_partition_mono_items(()); + let (all_mono_items, _, _) = tcx.collect_and_partition_mono_items(()); all_mono_items.contains(&def_id) }; providers.codegen_unit = |tcx, name| { - let (_, all) = tcx.collect_and_partition_mono_items(()); + let (_, _, all) = tcx.collect_and_partition_mono_items(()); all.iter() .find(|cgu| cgu.name() == name) .unwrap_or_else(|| panic!("failed to find cgu with name {name:?}")) diff --git a/compiler/rustc_passes/src/check_attr.rs b/compiler/rustc_passes/src/check_attr.rs index a8a27e761cb3f..19aa8a308b3a2 100644 --- a/compiler/rustc_passes/src/check_attr.rs +++ 
b/compiler/rustc_passes/src/check_attr.rs @@ -233,6 +233,7 @@ impl CheckAttrVisitor<'_> { self.check_generic_attr(hir_id, attr, target, Target::Fn); self.check_proc_macro(hir_id, target, ProcMacroKind::Derive) } + sym::autodiff_into => self.check_autodiff(hir_id, attr, span, target), _ => {} } @@ -2394,6 +2395,20 @@ impl CheckAttrVisitor<'_> { self.abort.set(true); } } + + /// Placeholder for validating `#[autodiff_into]` targets: the `extern`-block check below is commented out, so every target is currently accepted. + fn check_autodiff(&self, _hir_id: HirId, _attr: &Attribute, _span: Span, _target: Target) { + //match target { + // Target::ForeignMod => {} + // _ => { + // self.tcx + // .sess + // .struct_span_err(attr.span, "attribute should be applied to an `extern` block") + // .span_label(span, "not an `extern` block") + // .emit(); + // } + //} + } } impl<'tcx> Visitor<'tcx> for CheckAttrVisitor<'tcx> { diff --git a/compiler/rustc_resolve/src/lib.rs b/compiler/rustc_resolve/src/lib.rs index 501747df5c908..58e6d82595e7a 100644 --- a/compiler/rustc_resolve/src/lib.rs +++ b/compiler/rustc_resolve/src/lib.rs @@ -1522,6 +1522,7 @@ impl<'a, 'tcx> Resolver<'a, 'tcx> { trait_impls: self.trait_impls, proc_macros, confused_type_with_std_module, + autodiff_map: Default::default(), doc_link_resolutions: self.doc_link_resolutions, doc_link_traits_in_scope: self.doc_link_traits_in_scope, all_macro_rules: self.all_macro_rules, diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 3f99d2a4b1ffb..306860e30453f 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -437,6 +437,8 @@ symbols! { attributes, augmented_assignments, auto_traits, + autodiff, + autodiff_into, automatically_derived, avx, avx512_target_feature, @@ -1022,6 +1024,7 @@ symbols!
{ miri, misc, mmx_reg, + mode, modifiers, module, module_path, diff --git a/config.example.toml b/config.example.toml index 66fa91d4bad15..6050848cb3a05 100644 --- a/config.example.toml +++ b/config.example.toml @@ -142,6 +142,9 @@ change-id = 116998 # Whether or not to specify `-DLLVM_TEMPORARILY_ALLOW_OLD_TOOLCHAIN=YES` #allow-old-toolchain = false +# Whether to build enzyme +#enzyme = false + # Whether to include the Polly optimizer. #polly = false diff --git a/library/autodiff/Cargo.toml b/library/autodiff/Cargo.toml new file mode 100644 index 0000000000000..cbbff8d375e3d --- /dev/null +++ b/library/autodiff/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "autodiff" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + + +[profile.release] +lto = "fat" + +[profile.dev] +lto = "fat" + +[lib] +name = "autodiff" +proc-macro = true + +[dependencies] +quote = "1.0" +proc-macro2 = "1" +proc-macro-error = "1" +syn = { version = "1", features = ["extra-traits", "full", "visit", "visit-mut"]} + +[dev-dependencies] +macrotest = "1" +trybuild = "1" +ndarray = "0.15" diff --git a/library/autodiff/examples/array.rs b/library/autodiff/examples/array.rs new file mode 100644 index 0000000000000..60c6b63fd84cb --- /dev/null +++ b/library/autodiff/examples/array.rs @@ -0,0 +1,23 @@ +use autodiff::autodiff; + +#[autodiff(d_array, Reverse, Active, Duplicated)] +fn array(arr: &[[[f32; 2]; 2]; 2]) -> f32 { + arr[0][0][0] * arr[1][1][1] +} + +fn main() { + let arr = [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]; + let mut d_arr = [[[0.0; 2]; 2]; 2]; + + d_array(&arr, &mut d_arr, 1.0); + + dbg!(&d_arr); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/box.rs b/library/autodiff/examples/box.rs new file mode 100644 index 0000000000000..5d4f114830bf4 --- /dev/null +++ b/library/autodiff/examples/box.rs @@ -0,0 +1,24 @@ +use 
autodiff::autodiff; + +#[autodiff(cos_box, Reverse, Active, Duplicated)] +fn sin(x: &Box) -> f32 { + f32::sin(**x) +} + +fn main() { + let x = Box::::new(3.14); + let mut df_dx = Box::::new(0.0); + cos_box(&x, &mut df_dx, 1.0); + + dbg!(&df_dx); + + assert!(*df_dx == f32::cos(*x)); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/broken_matvec.rs b/library/autodiff/examples/broken_matvec.rs new file mode 100644 index 0000000000000..0c4b2cfe6e927 --- /dev/null +++ b/library/autodiff/examples/broken_matvec.rs @@ -0,0 +1,34 @@ +use autodiff::autodiff; + +type Matrix = Vec>; +type Vector = Vec; + +#[autodiff(d_matvec, Forward, Const)] +fn matvec(#[dup] mat: &Matrix, vec: &Vector, #[dup] out: &mut Vector) { + for i in 0..mat.len() - 1 { + for j in 0..mat[0].len() - 1 { + out[i] += mat[i][j] * vec[j]; + } + } +} + +fn main() { + let mat = vec![vec![1.0, 1.0], vec![1.0, 1.0]]; + let mut d_mat = vec![vec![0.0, 0.0], vec![0.0, 0.0]]; + let inp = vec![1.0, 1.0]; + let mut out = vec![0.0, 0.0]; + let mut out_tang = vec![0.0, 1.0]; + + //matvec(&mat, &inp, &mut out); + d_matvec(&mat, &mut d_mat, &inp, &mut out, &mut out_tang); + + dbg!(&out); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/hessian_sin.rs b/library/autodiff/examples/hessian_sin.rs new file mode 100644 index 0000000000000..6b1e776476fd2 --- /dev/null +++ b/library/autodiff/examples/hessian_sin.rs @@ -0,0 +1,28 @@ +use autodiff::autodiff; + +fn sin(x: &Vec, y: &mut f32) { + *y = x.into_iter().map(|x| f32::sin(*x)).sum() +} + +#[autodiff(sin, Reverse, Const, Duplicated, Duplicated)] +fn jac(x: &Vec, d_x: &mut Vec, y: &mut f32, y_t: &f32); + +#[autodiff(jac, Forward, Const, Duplicated, Const, Const, Const)] +fn hessian(x: &Vec, y_x: &Vec, d_x: &mut Vec, y: &mut f32, y_t: &f32); + +fn main() { + let inp = vec![3.1415 / 2., 1.0, 0.5]; + let mut d_inp = vec![0.0, 0.0, 0.0]; + let 
mut y = 0.0; + let tang = vec![1.0, 0.0, 0.0]; + hessian(&inp, &tang, &mut d_inp, &mut y, &1.0); + dbg!(&d_inp); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/ndarray.rs b/library/autodiff/examples/ndarray.rs new file mode 100644 index 0000000000000..34402c43cb3e6 --- /dev/null +++ b/library/autodiff/examples/ndarray.rs @@ -0,0 +1,25 @@ +use autodiff::autodiff; + +use ndarray::Array1; + +#[autodiff(d_collect, Reverse, Active)] +fn collect(#[dup] x: &Array1) -> f32 { + x[0] +} + +fn main() { + let a = Array1::zeros(19); + let mut d_a = Array1::zeros(19); + + d_collect(&a, &mut d_a, 1.0); + + dbg!(&d_a); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/rosenbrock_fwd.rs b/library/autodiff/examples/rosenbrock_fwd.rs new file mode 100644 index 0000000000000..a3ab7a47578d0 --- /dev/null +++ b/library/autodiff/examples/rosenbrock_fwd.rs @@ -0,0 +1,34 @@ +use autodiff::autodiff; + +#[autodiff(d_rosenbrock, Forward, DuplicatedNoNeed)] +fn rosenbrock(#[dup] x: &[f64; 2]) -> f64 { + let mut res = 0.0; + for i in 0..(x.len() - 1) { + let a = x[i + 1] - x[i] * x[i]; + let b = x[i] - 1.0; + res += 100.0 * a * a + b * b; + } + res +} + +fn main() { + let x = [3.14, 2.4]; + let output = rosenbrock(&x); + println!("{output}"); + let df_dx = d_rosenbrock(&x, &[1.0, 0.0]); + let df_dy = d_rosenbrock(&x, &[0.0, 1.0]); + + dbg!(&df_dx, &df_dy); + + // https://www.wolframalpha.com/input?i2d=true&i=x%3D3.14%3B+y%3D2.4%3B+D%5Brosenbrock+function%5C%2840%29x%5C%2844%29+y%5C%2841%29+%2Cy%5D + assert!((df_dx - 9373.54).abs() < 0.1); + assert!((df_dy - (-1491.92)).abs() < 0.1); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/rosenbrock_fwd_iter.rs b/library/autodiff/examples/rosenbrock_fwd_iter.rs new file mode 100644 index 0000000000000..1648014392f19 --- /dev/null +++ 
b/library/autodiff/examples/rosenbrock_fwd_iter.rs @@ -0,0 +1,34 @@ +use autodiff::autodiff; + +#[autodiff(d_rosenbrock, Forward, DuplicatedNoNeed)] +fn rosenbrock(#[dup] x: &[f64; 2]) -> f64 { + (0..x.len() - 1) + .map(|i| { + let (a, b) = (x[i + 1] - x[i] * x[i], x[i] - 1.0); + 100.0 * a * a + b * b + }) + .sum() +} + +fn main() { + let x = [3.14f64, 2.4]; + let output = rosenbrock(&x); + println!("{output}"); + + let df_dx = d_rosenbrock(&x, &[1.0, 0.0]); + let df_dy = d_rosenbrock(&x, &[0.0, 1.0]); + + dbg!(&df_dx, &df_dy); + + // https://www.wolframalpha.com/input?i2d=true&i=x%3D3.14%3B+y%3D2.4%3B+D%5Brosenbrock+function%5C%2840%29x%5C%2844%29+y%5C%2841%29+%2Cy%5D + assert!((df_dx - 9373.54).abs() < 0.1); + assert!((df_dy - (-1491.92)).abs() < 0.1); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/rosenbrock_rev.rs b/library/autodiff/examples/rosenbrock_rev.rs new file mode 100644 index 0000000000000..b4ce00b5afe9d --- /dev/null +++ b/library/autodiff/examples/rosenbrock_rev.rs @@ -0,0 +1,33 @@ +use autodiff::autodiff; + +#[autodiff(d_rosenbrock, Reverse, Active)] +fn rosenbrock(#[dup] x: &[f64; 2]) -> f64 { + let mut res = 0.0; + for i in 0..(x.len() - 1) { + let a = x[i + 1] - x[i] * x[i]; + let b = x[i] - 1.0; + res += 100.0 * a * a + b * b; + } + res +} + +fn main() { + let x = [3.14, 2.4]; + let output = rosenbrock(&x); + println!("{output}"); + + let mut df_dx = [0.0f64; 2]; + d_rosenbrock(&x, &mut df_dx, 1.0); + + // https://www.wolframalpha.com/input?i2d=true&i=x%3D3.14%3B+y%3D2.4%3B+D%5Brosenbrock+function%5C%2840%29x%5C%2844%29+y%5C%2841%29+%2Cy%5D + assert!((df_dx[0] - 9373.54).abs() < 0.01); + assert!((df_dx[1] - (-1491.92)).abs() < 0.01); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/sin.rs b/library/autodiff/examples/sin.rs new file mode 100644 index 0000000000000..1655b1e7ecd09 --- /dev/null +++ 
b/library/autodiff/examples/sin.rs @@ -0,0 +1,36 @@ +use autodiff::autodiff; + +#[autodiff(cos_inplace, Reverse, Const)] +fn sin_inplace(#[dup] x: &f32, #[dup] y: &mut f32) { + *y = x.sin(); +} + + +fn main() { + // Here we can use ==, even though we work on f32. + // Enzyme will recognize the sin function and replace it with llvm's cos function (see below). + // Calling f32::cos directly will also result in calling llvm's cos function. + let a = 3.1415; + let mut da = 0.0; + let mut y = 0.0; + cos_inplace(&a, &mut da, &mut y, &mut 1.0); + + dbg!(&a, &da, &y); + assert!(da - f32::cos(a) == 0.0); +} + +// Just for curious readers, this is the (inner) function that Enzyme does generate: +// define internal { float } @diffe_ZN3sin3sin17h18f17f71fe94e58fE(float %0, float %1) unnamed_addr #35 { +// %3 = call fast float @llvm.cos.f32(float %0) +// %4 = fmul fast float %1, %3 +// %5 = insertvalue { float } undef, float %4, 0 +// ret { float } %5 +// } + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/sqrt.rs b/library/autodiff/examples/sqrt.rs new file mode 100644 index 0000000000000..d15c6f5ec2051 --- /dev/null +++ b/library/autodiff/examples/sqrt.rs @@ -0,0 +1,21 @@ +use autodiff::autodiff; + +#[autodiff(d_sqrt, Reverse, Active)] +fn sqrt(#[active] a: f32, #[dup] b: &f32, c: &f32, #[active] d: f32) -> f32 { + a * (b * b + c*c*d*d).sqrt() +} + +fn main() { + let mut d_b = 0.0; + + let (d_a, d_d) = d_sqrt(1.0, &1.0, &mut d_b, &1.0, 1.0, 1.0); + dbg!(d_a, d_b, d_d); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/struct.rs b/library/autodiff/examples/struct.rs new file mode 100644 index 0000000000000..1235307fdbcbf --- /dev/null +++ b/library/autodiff/examples/struct.rs @@ -0,0 +1,33 @@ +use autodiff::autodiff; + +use std::io; + +// Will be represented as {f32, i16, i16} when passed by reference +// will be represented as i64 if passed by 
value +struct Foo { + c1: i16, + a: f32, + c2: i16, +} + +#[autodiff(cos, Reverse, Active, Duplicated)] +fn sin(x: &Foo) -> f32 { + assert!(x.c1 < x.c2); + f32::sin(x.a) +} + +fn main() { + let mut s = String::new(); + println!("Please enter a value for c1"); + io::stdin().read_line(&mut s).unwrap(); + let c2 = s.trim_end().parse::().unwrap(); + dbg!(c2); + + let foo = Foo { c1: 4, a: 3.14, c2 }; + let mut df_dfoo = Foo { c1: 4, a: 0.0, c2 }; + + dbg!(df_dfoo.a); + dbg!(cos(&foo, &mut df_dfoo, 1.0)); + dbg!(df_dfoo.a); + dbg!(f32::cos(foo.a)); +} diff --git a/library/autodiff/examples/vec.rs b/library/autodiff/examples/vec.rs new file mode 100644 index 0000000000000..e82618fac4dac --- /dev/null +++ b/library/autodiff/examples/vec.rs @@ -0,0 +1,24 @@ +use autodiff::autodiff; + +#[autodiff(d_sum, Forward, Duplicated)] +fn sum(#[dup] x: &Vec>) -> f32 { + x.into_iter().map(|x| x.into_iter().map(|x| x.sqrt())).flatten().sum() +} + +fn main() { + let a = vec![vec![1.0, 2.0, 4.0, 8.0]]; + //let mut b = vec![vec![0.0, 0.0, 0.0, 0.0]]; + let b = vec![vec![1.0, 0.0, 0.0, 0.0]]; + + dbg!(&d_sum(&a, &b)); + + dbg!(&b); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples_broken/biquad.rs b/library/autodiff/examples_broken/biquad.rs new file mode 100644 index 0000000000000..7689b1cd1fc51 --- /dev/null +++ b/library/autodiff/examples_broken/biquad.rs @@ -0,0 +1,54 @@ +use autodiff::autodiff; + +#[derive(Debug)] +struct Biquad { + coeffs: [[f32; 5]; N], +} + +impl Biquad { + pub fn new() -> Self { + Biquad { coeffs: [[0.0; 5]; N] } + } + + pub fn process(&self, samples: &[f32], target: &[f32]) -> f32 { + // do some horrible inefficient biquad filtering + let mut samples = samples.to_vec(); + let mut samples_out = vec![0.0; samples.len()]; + + for coeff_set in self.coeffs { + for idx in 0..samples.len() { + samples_out[idx] = coeff_set[0] * samples[idx]; + + if idx > 0 { + samples_out[idx] += coeff_set[1] * 
samples[idx - 1] - + coeff_set[3] * samples_out[idx - 1]; + } + if idx > 1 { + samples_out[idx] += coeff_set[2] * samples[idx - 2] - + coeff_set[4] * samples_out[idx - 2]; + } + } + + (samples, samples_out) = (samples_out, samples); + } + + samples_out.into_iter().zip(target.into_iter()).map(|(a, b)| a - b).sum() + } + + #[autodiff(Self::process, Reverse, Active)] + pub fn deriv(#[dup] &self, params: &mut Self, samples: &[f32], target: &[f32], ret_adj: f32); +} + +fn main() { + let biquad = Biquad::<10>::new(); + let mut dbiquad = Biquad::<10>::new(); + + // create ramp and pulse train + let signal = (0..1024).map(|x| (x as f32) / 1024.0).collect::>(); + let target = (0..1024).map(|x| if x % 2 == 0 { 0.0 } else { 1.0 }).collect::>(); + + dbg!(&biquad.process(&signal, &target)); + biquad.deriv(&mut dbiquad, &signal, &target, 1.0); + + dbg!(&dbiquad); +} diff --git a/library/autodiff/examples_broken/broken_iter.rs b/library/autodiff/examples_broken/broken_iter.rs new file mode 100644 index 0000000000000..16d205f7373c8 --- /dev/null +++ b/library/autodiff/examples_broken/broken_iter.rs @@ -0,0 +1,20 @@ +#![feature(bench_black_box)] +use autodiff::autodiff; +use std::ptr; + +#[autodiff(sin_vec, Reverse, Active)] +fn cos_vec(#[dup] x: &Vec) -> f32 { + // uses enum internally and breaks + let res = x.into_iter().collect::>(); + + *res[0] +} + +fn main() { + let x = vec![1.0, 1.0, 1.0]; + let mut d_x = vec![0.0; 3]; + + sin_vec(&x, &mut d_x, 1.0); + + dbg!(&d_x, &x); +} diff --git a/library/autodiff/examples_broken/broken_recursive.rs b/library/autodiff/examples_broken/broken_recursive.rs new file mode 100644 index 0000000000000..a1f3ff25eb511 --- /dev/null +++ b/library/autodiff/examples_broken/broken_recursive.rs @@ -0,0 +1,66 @@ +#![feature(bench_black_box)] +use autodiff::autodiff; + +// TODO: As seen by the bloated code generated for the iterative version, +// we definetly have to disable unroll, slpvec, loop-vec before AD. 
+// We also should check if we have other opts that Julia, C++, Fortran etc. don't have +// and which could make our input code more "complex". +// We then however have to start doing whole-module opt after AD to re-include them, +// instead of just using enzyme to optimize the generated function. + +#[autodiff(d_power_recursive, Forward, DuplicatedNoNeed)] +fn power_recursive(#[dup] a: f64, n: i32) -> f64 { + if n == 0 { + return 1.0; + } + return a * power_recursive(a, n - 1); +} + +#[autodiff(d_power_iterative, Reverse, DuplicatedNoNeed)] +fn power_iterative(#[active] a: f64, n: i32) -> f64 { + let mut res = 1.0; + for _ in 0..n { + res *= a; + } + res +} + +fn main() { + // d/dx x^n = n * x^(n-1) + let n = 4; + let nf = n as f64; + let a = 1.337; + assert!(power_recursive(a, n) == power_iterative(a, n)); + let dpr = d_power_recursive(a, 1.0, n); + let dpi = d_power_iterative(a, n, 1.0); + let control = nf * a.powi(n - 1); + dbg!(dpr); + dbg!(dpi); + dbg!(control); + assert!(dpr == control); + assert!(dpi == control); +} + +// Again, for the curious. We can find n * x^(n-1) nicely in the LLVM-IR +// +// define internal double @fwddiffe_ZN9recursive15power_recursive17h789de751cfc6154dE(double %0, double %1, i32 %2) unnamed_addr #8 { +// => if (n == 0) goto 5: and return 0. 
Correct, since for n==0 we have 0 * x ^ (0-1) = 0 +// => if (n != 0) goto 7: +// %4 = icmp eq i32 %2, 0 +// br i1 %4, label %5, label %7 +// +// 5: ; preds = %7, %3 +// %6 = phi fast double [ %14, %7 ], [ 0.000000e+00, %3 ] +// ret double %6 +// +// 7: ; preds = %3 +// => reduce n by 1, +// %8 = add i32 %2, -1 +// %9 = call { double, double } @fwddiffe_ZN9recursive15power_recursive17h789de751cfc6154dE.1229(double %0, double %1, i32 %8) +// %10 = extractvalue { double, double } %9, 0 +// %11 = extractvalue { double, double } %9, 1 +// %12 = fmul fast double %11, %0 +// %13 = fmul fast double %1, %10 +// %14 = fadd fast double %12, %13 +// br label %5 +// } diff --git a/library/autodiff/examples_broken/broken_second_order.rs b/library/autodiff/examples_broken/broken_second_order.rs new file mode 100644 index 0000000000000..8b427d7dae36a --- /dev/null +++ b/library/autodiff/examples_broken/broken_second_order.rs @@ -0,0 +1,17 @@ +#![feature(bench_black_box)] +use autodiff::autodiff; + +fn sin(x: &f32) -> f32 { + f32::sin(*x) +} + +#[autodiff(sin, Reverse, Active, Active)] +fn cos(x: &f32, adj: f32) -> f32; + +//#[autodiff(cos, Reverse, Active, Active, Const)] +//fn neg_sin(x: &f32, adj: f32, adj_sec: f32) -> f32; + +fn main() { + dbg!(&cos(&1.0, 1.0)); + //dbg!(&neg_sin(&1.0, 1.0, 1.0)); +} diff --git a/library/autodiff/src/gen.rs b/library/autodiff/src/gen.rs new file mode 100644 index 0000000000000..59b37b997e8cf --- /dev/null +++ b/library/autodiff/src/gen.rs @@ -0,0 +1,217 @@ +use crate::parser::{is_ref_mut, PrimalSig}; +use crate::parser::{Activity, DiffItem, Mode}; +use proc_macro2::TokenStream; +use proc_macro_error::abort; +use quote::{format_ident, quote}; +use syn::{parse_quote, FnArg, Ident, Pat, ReturnType, Type}; + +pub(crate) fn generate_header(item: &DiffItem) -> TokenStream { + let mode = match item.header.mode { + Mode::Forward => format_ident!("Forward"), + Mode::Reverse => format_ident!("Reverse"), + }; + let ret_act = 
item.header.ret_act.to_ident(); + let param_act = item.params.iter().map(|x| x.to_ident()); + + quote!(#[autodiff_into(#mode, #ret_act, #( #param_act, )*)]) +} + +pub(crate) fn primal_fnc(item: &mut DiffItem) -> TokenStream { + // construct body of primal if not given + let body = item.block.clone().map(|x| quote!(#x)).unwrap_or_else(|| { + let header_fnc = &item.header.name; + //let primal_wrapper = format_ident!("primal_{}", item.primal.ident); + //item.primal.ident = primal_wrapper.clone(); + let inputs = item.primal.inputs.iter().map(|x| only_ident(x)).collect::>(); + + quote!({ + #header_fnc(#(#inputs,)*) + }) + }); + + let sig = &item.primal; + let PrimalSig { ident, inputs, output } = sig; + + let ident = + if item.block.is_some() { ident.clone() } else { format_ident!("primal_{}", ident) }; + + let sig = quote!(fn #ident(#(#inputs,)*) #output); + + quote!( + #[autodiff_into] + #sig + #body + ) +} + +fn only_ident(arg: &FnArg) -> Ident { + match arg { + FnArg::Receiver(_) => format_ident!("self"), + FnArg::Typed(t) => match &*t.pat { + Pat::Ident(ident) => ident.ident.clone(), + _ => panic!(""), + }, + } +} + +fn only_type(arg: &FnArg) -> Type { + match arg { + FnArg::Receiver(_) => parse_quote!(Self), + FnArg::Typed(t) => match &*t.ty { + Type::Reference(t) => *t.elem.clone(), + x => x.clone(), + }, + } +} + +fn as_ref_mut(arg: &FnArg, name: &str, mutable: bool) -> FnArg { + match arg { + FnArg::Receiver(_) => { + let name = format_ident!("{}_self", name); + if mutable { parse_quote!(#name: &mut Self) } else { parse_quote!(#name: &Self) } + } + FnArg::Typed(t) => { + let inner = match &*t.ty { + Type::Reference(t) => &t.elem, + _ => panic!(""), // should not be reachable, as we checked mutability before + }; + + let pat_name = match &*t.pat { + Pat::Ident(x) => &x.ident, + _ => panic!(""), + }; + + let name = format_ident!("{}_{}", name, pat_name); + if mutable { parse_quote!(#name: &mut #inner) } else { parse_quote!(#name: &#inner) } + } + } +} + 
+pub(crate) fn adjoint_fnc(item: &DiffItem) -> TokenStream { + let mut res_inputs: Vec = Vec::new(); + let mut add_inputs: Vec = Vec::new(); + let out_type = match &item.primal.output { + ReturnType::Type(_, x) => Some(*x.clone()), + _ => None, + }; + + let mut outputs = if item.header.ret_act == Activity::Duplicated { + vec![out_type.clone().unwrap()] + } else { + vec![] + }; + + let PrimalSig { ident, inputs, .. } = &item.primal; + + for (input, activity) in inputs.iter().zip(item.params.iter()) { + res_inputs.push(input.clone()); + + match (item.header.mode, activity, is_ref_mut(&input)) { + (Mode::Forward, Activity::Duplicated | Activity::DuplicatedNoNeed, Some(true)) => { + res_inputs.push(as_ref_mut(&input, "grad", true)); + add_inputs.push(as_ref_mut(&input, "grad", true)); + } + (Mode::Forward, Activity::Duplicated | Activity::DuplicatedNoNeed, Some(false)) => { + res_inputs.push(as_ref_mut(&input, "dual", false)); + add_inputs.push(as_ref_mut(&input, "dual", false)); + out_type.clone().map(|x| outputs.push(x)); + } + (Mode::Forward, Activity::Duplicated, None) => outputs.push(only_type(&input)), + (Mode::Reverse, Activity::Duplicated, Some(false)) => { + res_inputs.push(as_ref_mut(&input, "grad", true)); + add_inputs.push(as_ref_mut(&input, "grad", true)); + } + (Mode::Reverse, Activity::Duplicated | Activity::DuplicatedNoNeed, Some(true)) => { + res_inputs.push(as_ref_mut(&input, "grad", false)); + add_inputs.push(as_ref_mut(&input, "grad", false)); + } + (Mode::Reverse, Activity::Active, None) => outputs.push(only_type(&input)), + _ => {} + } + } + + match (item.header.mode, item.header.ret_act) { + (Mode::Reverse, Activity::Active) => { + let t: FnArg = match &item.primal.output { + ReturnType::Type(_, ty) => parse_quote!(tang_y: #ty), + _ => panic!(""), + }; + res_inputs.push(t.clone()); + add_inputs.push(t); + } + _ => {} + } + + // for adjoint function -> take header if primal + // -> take ident of primal function + let adjoint_ident = if 
item.block.is_some() { + if let Some(ident) = item.header.name.get_ident() { + ident.clone() + } else { + abort!( + item.header.name, + "not a function name"; + help = "`#[autodiff]` function name should be a single word instead of path" + ); + } + } else { + item.primal.ident.clone() + }; + + let output = match outputs.len() { + 0 => quote!(), + 1 => { + let output = outputs.first().unwrap(); + + quote!(-> #output) + } + _ => quote!(-> (#(#outputs,)*)), + }; + + let sig = quote!(fn #adjoint_ident(#(#res_inputs,)*) #output); + let inputs = inputs + .iter() + .map(|x| match x { + FnArg::Typed(ty) => { + let pat = &ty.pat; + quote!(#pat) + } + FnArg::Receiver(_) => quote!(self), + }) + .collect::>(); + let add_inputs = add_inputs + .iter() + .map(|x| match x { + FnArg::Typed(ty) => { + let pat = &ty.pat; + quote!(#pat) + } + FnArg::Receiver(_) => quote!(self), + }) + .collect::>(); + + let call_ident = match item.block.is_some() { + false => { + let ident = format_ident!("primal_{}", ident); + if item.header.name.segments.first().unwrap().ident == "Self" { + quote!(Self::#ident) + } else { + quote!(#ident) + } + } + true => quote!(#ident), + }; + + let body = quote!({ + core::hint::black_box((#call_ident(#(#inputs,)*), #(#add_inputs,)*)); + + core::hint::black_box(unsafe { core::mem::zeroed() }) + }); + let header = generate_header(&item); + + quote!( + #header + #sig + #body + ) +} diff --git a/library/autodiff/src/lib.rs b/library/autodiff/src/lib.rs new file mode 100644 index 0000000000000..b1d265fa9c59b --- /dev/null +++ b/library/autodiff/src/lib.rs @@ -0,0 +1,31 @@ +use proc_macro::TokenStream; +use proc_macro_error::proc_macro_error; +use quote::quote; + +mod gen; +mod parser; + +#[proc_macro_attribute] +#[proc_macro_error] +pub fn autodiff(args: TokenStream, input: TokenStream) -> TokenStream { + let mut params = parser::parse(args.into(), input.clone().into()); + let (primal, adjoint) = (gen::primal_fnc(&mut params), gen::adjoint_fnc(¶ms)); + + let res = 
quote!( + #primal + #adjoint + ); + + res.into() +} + +#[test] +pub fn expanding() { + macrotest::expand("tests/expand/*.rs"); +} + +#[test] +fn ui() { + let t = trybuild::TestCases::new(); + t.compile_fail("tests/ui/*.rs"); +} diff --git a/library/autodiff/src/parser.rs b/library/autodiff/src/parser.rs new file mode 100644 index 0000000000000..d11eea24d5015 --- /dev/null +++ b/library/autodiff/src/parser.rs @@ -0,0 +1,464 @@ +use proc_macro2::TokenStream; +use proc_macro_error::abort; +use quote::{format_ident, quote}; +use syn::{ + parse::Parser, parse_quote, punctuated::Punctuated, Attribute, Block, FnArg, ForeignItemFn, + Ident, Item, Path, ReturnType, Signature, Token, Type, +}; + +#[derive(Debug)] +pub struct PrimalSig { + pub(crate) ident: Ident, + pub(crate) inputs: Vec, + pub(crate) output: ReturnType, +} + +#[derive(Debug)] +pub struct DiffItem { + pub(crate) header: Header, + pub(crate) params: Vec, + pub(crate) primal: PrimalSig, + pub(crate) block: Option>, +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub(crate) enum Mode { + Forward, + Reverse, +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub(crate) enum Activity { + Const, + Active, + Duplicated, + DuplicatedNoNeed, +} + +impl Activity { + fn from_header(name: Option<&Ident>) -> Activity { + if name.is_none() { + return Activity::Const; + } + + match name.unwrap().to_string().as_str() { + "Const" => Activity::Const, + "Active" => Activity::Active, + "Duplicated" => Activity::Duplicated, + "DuplicatedNoNeed" => Activity::DuplicatedNoNeed, + _ => { + abort!( + name, + "unknown activity"; + help = "`#[autodiff]` should use activities (Const|Active|Duplicated|DuplicatedNoNeed)" + ); + } + } + } + + fn from_inline(name: Attribute) -> Activity { + let name = name.path.segments.first().unwrap(); + match name.ident.to_string().as_str() { + "const" => Activity::Const, + "active" => Activity::Active, + "dup" => Activity::Duplicated, + "dup_noneed" => Activity::DuplicatedNoNeed, + _ => { + abort!( + name, 
+ "unknown activity"; + help = "`#[autodiff]` should use activities (const|active|dup|dup_noneed)" + ); + } + } + } + + pub(crate) fn to_ident(&self) -> Ident { + format_ident!( + "{}", + match self { + Activity::Const => "Const", + Activity::Active => "Active", + Activity::Duplicated => "Duplicated", + Activity::DuplicatedNoNeed => "DuplicatedNoNeed", + } + ) + } +} + +#[derive(Debug)] +pub(crate) struct Header { + pub name: Path, + pub mode: Mode, + pub ret_act: Activity, +} + +impl Header { + fn from_params(name: &Path, mode: Option<&Ident>, ret_activity: Option<&Ident>) -> Self { + // parse mode and return activity + let mode = mode + .map(|x| match x.to_string().as_str() { + "forward" | "Forward" => Mode::Forward, + "reverse" | "Reverse" => Mode::Reverse, + _ => { + abort!( + mode, + "should be forward or reverse"; + help = "`#[autodiff]` modes should be either forward or reverse" + ); + } + }) + .unwrap_or(Mode::Forward); + let ret_act = Activity::from_header(ret_activity); + + // check for invalid mode and return activity combinations + match (mode, ret_act) { + (Mode::Forward, Activity::Active) => abort!( + ret_activity, + "active return for forward mode"; + help = "`#[autodiff]` return should be Const, Duplicated or DuplicatedNoNeed in forward mode" + ), + (Mode::Reverse, Activity::Duplicated | Activity::DuplicatedNoNeed) => abort!( + ret_activity, + "duplicated return for reverse mode"; + help = "`#[autodiff]` return should be Const or Active in reverse mode" + ), + + _ => {} + } + + Header { name: name.clone(), mode, ret_act } + } + + fn parse(args: TokenStream) -> (Header, Vec) { + let args_parsed: Vec<_> = + match Punctuated::::parse_terminated.parse(args.clone().into()) { + Ok(x) => x.into_iter().collect(), + Err(_) => abort!( + args, + "duplicated return for reverse mode"; + help = "`#[autodiff]` return should be Const or Active in reverse mode" + ), + }; + + match &args_parsed[..] 
{ + [name] => (Self::from_params(&name, None, None), vec![]), + [name, mode] => { + (Self::from_params(&name, Some(&mode.get_ident().unwrap()), None), vec![]) + } + [name, mode, ret_act, rem @ ..] => { + let params = Self::from_params( + &name, + Some(&mode.get_ident().unwrap()), + Some(&ret_act.get_ident().unwrap()), + ); + let rem = rem.into_iter() + .map(|x| x.get_ident().unwrap()) + .map(|x| Activity::from_header(Some(x))) + .map(|x| match (params.mode, x) { + (Mode::Forward, Activity::Active) => { + abort!( + args, + "active argument in forward mode"; + help = "`#[autodiff]` forward mode should be either Const, Duplicated" + ); + }, + (_, x) => x, + }) + .collect(); + + (params, rem) + } + _ => { + abort!( + args, + "please specify the autodiff function"; + help = "`#[autodiff]` needs a function name for primal or adjoint" + ); + } + } + } +} + +pub(crate) fn is_ref_mut(t: &FnArg) -> Option { + match t { + FnArg::Receiver(pat) => Some(pat.mutability.is_some()), + FnArg::Typed(pat) => match &*pat.ty { + Type::Reference(t) => Some(t.mutability.is_some()), + _ => None, + }, + } +} + +fn is_scalar(t: &Type) -> bool { + let t_f32: Type = parse_quote!(f32); + let t_f64: Type = parse_quote!(f64); + t == &t_f32 || t == &t_f64 +} + +fn ret_arg(arg: &FnArg) -> Type { + match arg { + FnArg::Receiver(_) => parse_quote!(Self), + FnArg::Typed(t) => match &*t.ty { + Type::Reference(t) => *t.elem.clone(), + x => x.clone(), + }, + } +} + +pub(crate) fn reduce_params( + mut sig: Signature, + header_acts: Vec, + is_adjoint: bool, + header: &Header, +) -> (PrimalSig, Vec) { + let mut args = Vec::new(); + let mut ret = Vec::new(); + let mut acts = Vec::new(); + let mut last_arg: Option = None; + + let mut arg_it = sig.inputs.iter_mut(); + let mut header_acts_it = header_acts.iter(); + + while let Some(arg) = arg_it.next() { + // Compare current with last argument when parsing duplicated rules. 
This only + // happens when we parse the signature of adjoint/augmented primal function + if let Some(prev_arg) = last_arg.take() { + match (header.mode, is_ref_mut(&prev_arg), is_ref_mut(&arg)) { + (Mode::Forward, Some(false), Some(true) | None) => abort!( + arg, + "should be an immutable reference"; + help = "`#[autodiff]` input parameter should duplicate tangent into second parameter for forward mode" + ), + (Mode::Forward, Some(true), Some(false) | None) => abort!( + arg, + "should be a mutable reference"; + help = "`#[autodiff]` output parameter should duplicate derivative into second parameter for forward mode" + ), + (Mode::Reverse, Some(false), Some(false) | None) => abort!( + arg, + "should be a mutable reference"; + help = "`#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode" + ), + (Mode::Reverse, Some(true), Some(true) | None) => abort!( + arg, + "should be an immutable reference"; + help = "`#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode" + ), + _ => {} + } + + continue; + } + + // parse current attribute macro + let attrs: Vec<_> = match arg { + FnArg::Typed(pat) => pat.attrs.drain(..).collect(), + FnArg::Receiver(pat) => pat.attrs.drain(..).collect(), + }; + let attr = attrs.first(); + let act: Activity = match (header_acts.is_empty(), attr) { + (false, None) => header_acts_it.next().map(|x| *x).unwrap_or(Activity::Const), + (true, Some(x)) => Activity::from_inline(x.clone()), + (true, None) => Activity::Const, + _ => { + abort!( + arg, + "inline activity"; + help = "`#[autodiff]` should have activities either specified in header or as inline attributes" + ); + } + }; + + // compare indirection with activity + match (header.mode, is_ref_mut(&arg), act) { + (Mode::Forward, None, Activity::Duplicated) => abort!( + arg, + "type not behind reference"; + help = "`#[autodiff]` duplicated types should be behind a reference" + ), + (Mode::Forward, Some(false), 
Activity::DuplicatedNoNeed) => abort!( + arg, + "should be mutable reference"; + help = "`#[autodiff]` parameter should be output for DuplicatedNoNeed activity" + ), + (Mode::Reverse, Some(_), Activity::Active) => abort!( + arg, + "type behind reference"; + help = "`#[autodiff]` active parameter should be concrete in reverse mode" + ), + (Mode::Reverse, None, Activity::Duplicated | Activity::DuplicatedNoNeed) => abort!( + arg, + "type not behind reference"; + help = "`#[autodiff]` duplicated parameters should be behind reference in reverse mode" + ), + (Mode::Reverse, Some(false), Activity::DuplicatedNoNeed) => abort!( + arg, + "use duplicated instead"; + help = "`#[autodiff]` input parameter cannot be declared as duplicatednoneed" + ), + (Mode::Forward, Some(false), Activity::Duplicated) + if header.ret_act != Activity::Const => + { + ret.push(ret_arg(&arg)) + } + (Mode::Reverse, None, Activity::Active) => ret.push(ret_arg(&arg)), + (Mode::Forward, Some(_), Activity::Duplicated | Activity::DuplicatedNoNeed) + | (Mode::Reverse, _, Activity::Duplicated | Activity::DuplicatedNoNeed) + if is_adjoint => + { + last_arg = Some(arg.clone()) + } + _ => {} + } + + args.push(arg.clone()); + acts.push(act); + } + + // if we have adjoint signature and are in forward mode + // if duplicated -> return type * (n + 1) times + // if duplicated_no_need -> return type * n times + // if const -> no return + + // if we have adjoint signature and are in reverse mode + // if active -> input type * n times + // construct return type based on mode + let ret = if is_adjoint { + let ret_typs = match &sig.output { + ReturnType::Type(_, ref x) => match &**x { + Type::Tuple(x) => x.elems.iter().cloned().collect(), + x => vec![x.clone()], + }, + ReturnType::Default => vec![], + }; + + match (header.mode, header.ret_act) { + (Mode::Forward, Activity::Duplicated) => { + let expected = ret_typs[0].clone(); + let list = vec![expected.clone(); ret.len() + 1]; + + if list != ret_typs { + let ret = 
quote!((#(#list,)*)); + abort!( + sig.output, + "invalid output"; + help = format!("`#[autodiff]` expected {}", ret) + ); + } + + parse_quote!(-> #expected) + } + (Mode::Forward, Activity::DuplicatedNoNeed) => { + let expected = ret_typs[0].clone(); + let list = vec![expected.clone(); ret.len()]; + + if list != ret_typs { + let ret = quote!((#(#list,)*)); + abort!( + sig.output, + "invalid output"; + help = format!("`#[autodiff]` expected {}", ret) + ); + } + + parse_quote!(-> #expected) + } + (Mode::Reverse, Activity::Active) => { + // tangent of output is latest in parameter list + let ret_typ = match (args.pop(), acts.pop()) { + (Some(x), Some(y)) => { + let x = ret_arg(&x); + if !is_scalar(&x) { + abort!( + x, + "output tangent not a floating point"; + help = "`#[autodiff]` the output tangent should be a floating point" + ); + } else if y != Activity::Const { + abort!( + x, + "output tangent not const"; + help = "`#[autodiff]` the last parameter of an adjoint with active return should be a constant tangent" + ); + } else { + parse_quote!(-> #x) + } + } + (None, None) => abort!( + sig, + "missing output tangent parameter"; + help = "`#[autodiff]` the last parameter of an adjoint with active return should exist" + ), + _ => unreachable!(), + }; + + // check that the return tuple confirms with return types + if ret_typs != ret { + let ret = quote!((#(#ret,)*)); + abort!( + sig.output, + "invalid output"; + help = format!("`#[autodiff]` expected {}", ret) + ) + } + + ret_typ + } + (_, Activity::Const) if ret.len() > 0 => { + abort!( + ret[0], + "constant return but more than one return"; + help = "`#[autodiff]` adjoint should have a return type when active" + ) + } + _ => ReturnType::Default, + } + } else { + if header.ret_act != Activity::Const && sig.output == ReturnType::Default { + abort!( + sig, + "no return type"; + help = "`#[autodiff]` non-const return activity but no return type" + ) + } + + sig.output.clone() + }; + + let sig = if is_adjoint { + // header 
is used for calling if we are adjoint + format_ident!("{}", sig.ident) + } else { + sig.ident.clone() + }; + + (PrimalSig { ident: sig, inputs: args, output: ret }, acts) +} + +pub(crate) fn parse(args: TokenStream, input: TokenStream) -> DiffItem { + // first parse function + let (_attrs, _, sig, block) = match syn::parse2::<Item>(input) { + Ok(Item::Fn(item)) => (item.attrs, item.vis, item.sig, Some(item.block)), + Ok(Item::Verbatim(x)) => match syn::parse2::<ForeignItemFn>(x) { + Ok(item) => (item.attrs, item.vis, item.sig, None), + Err(err) => panic!("Could not parse item {}", err), + }, + Ok(item) => { + abort!( + item, + "item is not a function"; + help = "`#[autodiff]` can only be used on primal or adjoint functions" + ) + } + Err(err) => panic!("Could not parse item: {}", err), + }; + + // then parse attributes + let (header, param_attrs) = Header::parse(args); + + // reduce parameters to primal parameter set + let (primal, params) = reduce_params(sig, param_attrs, !block.is_some(), &header); + + DiffItem { header, primal, params, block } +} diff --git a/library/autodiff/tests/expand/forward_duplicated.expanded.rs b/library/autodiff/tests/expand/forward_duplicated.expanded.rs new file mode 100644 index 0000000000000..c3e30939f92d4 --- /dev/null +++ b/library/autodiff/tests/expand/forward_duplicated.expanded.rs @@ -0,0 +1,10 @@ +use autodiff::autodiff; +#[autodiff_into] +fn square(a: &Vec, b: &mut f32) { + *b = a.into_iter().map(f32::square).sum(); +} +#[autodiff_into(Forward, Const, Duplicated, Duplicated)] +fn d_square(a: &Vec, dual_a: &Vec, b: &mut f32, grad_b: &mut f32) { + core::hint::black_box((square(a, b), dual_a, grad_b)); + core::hint::black_box(unsafe { core::mem::zeroed() }) +} diff --git a/library/autodiff/tests/expand/forward_duplicated.rs b/library/autodiff/tests/expand/forward_duplicated.rs new file mode 100644 index 0000000000000..9a0bfc6c13a47 --- /dev/null +++ b/library/autodiff/tests/expand/forward_duplicated.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + 
+#[autodiff(d_square, Forward, Const)] +fn square(#[dup] a: &Vec, #[dup] b: &mut f32) { + *b = a.into_iter().map(f32::square).sum(); +} diff --git a/library/autodiff/tests/expand/forward_duplicated_return.expanded.rs b/library/autodiff/tests/expand/forward_duplicated_return.expanded.rs new file mode 100644 index 0000000000000..12b3cb898797f --- /dev/null +++ b/library/autodiff/tests/expand/forward_duplicated_return.expanded.rs @@ -0,0 +1,15 @@ +use autodiff::autodiff; +#[autodiff_into] +fn square2(a: &Vec, b: &Vec) -> f32 { + a.into_iter().map(f32::square).sum() +} +#[autodiff_into(Forward, Duplicated, Duplicated, Duplicated)] +fn d_square2( + a: &Vec, + dual_a: &Vec, + b: &Vec, + dual_b: &Vec, +) -> (f32, f32, f32) { + core::hint::black_box((square2(a, b), dual_a, dual_b)); + core::hint::black_box(unsafe { core::mem::zeroed() }) +} diff --git a/library/autodiff/tests/expand/forward_duplicated_return.rs b/library/autodiff/tests/expand/forward_duplicated_return.rs new file mode 100644 index 0000000000000..3397e5309ea96 --- /dev/null +++ b/library/autodiff/tests/expand/forward_duplicated_return.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_square2, Forward, Duplicated)] +fn square2(#[dup] a: &Vec, #[dup] b: &Vec) -> f32 { + a.into_iter().map(f32::square).sum() +} diff --git a/library/autodiff/tests/expand/reverse_duplicated.expanded.rs b/library/autodiff/tests/expand/reverse_duplicated.expanded.rs new file mode 100644 index 0000000000000..04f462a09b2dd --- /dev/null +++ b/library/autodiff/tests/expand/reverse_duplicated.expanded.rs @@ -0,0 +1,10 @@ +use autodiff::autodiff; +#[autodiff_into] +fn square(a: &Vec, b: &mut f32) { + *b = a.into_iter().map(f32::square).sum(); +} +#[autodiff_into(Reverse, Const, Duplicated, Duplicated)] +fn d_square(a: &Vec, grad_a: &mut Vec, b: &mut f32, grad_b: &f32) { + core::hint::black_box((square(a, b), grad_a, grad_b)); + core::hint::black_box(unsafe { core::mem::zeroed() }) +} diff --git 
a/library/autodiff/tests/expand/reverse_duplicated.rs b/library/autodiff/tests/expand/reverse_duplicated.rs new file mode 100644 index 0000000000000..107a708bec848 --- /dev/null +++ b/library/autodiff/tests/expand/reverse_duplicated.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_square, Reverse, Const)] +fn square(#[dup] a: &Vec, #[dup] b: &mut f32) { + *b = a.into_iter().map(f32::square).sum(); +} diff --git a/library/autodiff/tests/expand/reverse_return_array.expanded.rs b/library/autodiff/tests/expand/reverse_return_array.expanded.rs new file mode 100644 index 0000000000000..48e0d99fd2797 --- /dev/null +++ b/library/autodiff/tests/expand/reverse_return_array.expanded.rs @@ -0,0 +1,10 @@ +use autodiff::autodiff; +#[autodiff_into] +fn array(arr: &[[[f32; 2]; 2]; 2]) -> f32 { + arr[0][0][0] * arr[1][1][1] +} +#[autodiff_into(Reverse, Active, Duplicated)] +fn d_array(arr: &[[[f32; 2]; 2]; 2], grad_arr: &mut [[[f32; 2]; 2]; 2], tang_y: f32) { + core::hint::black_box((array(arr), grad_arr, tang_y)); + core::hint::black_box(unsafe { core::mem::zeroed() }) +} diff --git a/library/autodiff/tests/expand/reverse_return_array.rs b/library/autodiff/tests/expand/reverse_return_array.rs new file mode 100644 index 0000000000000..da080a6b3a860 --- /dev/null +++ b/library/autodiff/tests/expand/reverse_return_array.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_array, Reverse, Active)] +fn array(#[dup] arr: &[[[f32; 2]; 2]; 2]) -> f32 { + arr[0][0][0] * arr[1][1][1] +} diff --git a/library/autodiff/tests/expand/reverse_return_mixed.expanded.rs b/library/autodiff/tests/expand/reverse_return_mixed.expanded.rs new file mode 100644 index 0000000000000..3517912222615 --- /dev/null +++ b/library/autodiff/tests/expand/reverse_return_mixed.expanded.rs @@ -0,0 +1,17 @@ +use autodiff::autodiff; +#[autodiff_into] +fn sqrt(a: f32, b: &f32, c: &f32, d: f32) -> f32 { + a * (b * b + c * c * d * d).sqrt() +} +#[autodiff_into(Reverse, Active, Active, Duplicated, Const, 
Active)] +fn d_sqrt( + a: f32, + b: &f32, + grad_b: &mut f32, + c: &f32, + d: f32, + tang_y: f32, +) -> (f32, f32) { + core::hint::black_box((sqrt(a, b, c, d), grad_b, tang_y)); + core::hint::black_box(unsafe { core::mem::zeroed() }) +} diff --git a/library/autodiff/tests/expand/reverse_return_mixed.rs b/library/autodiff/tests/expand/reverse_return_mixed.rs new file mode 100644 index 0000000000000..3260c3560d523 --- /dev/null +++ b/library/autodiff/tests/expand/reverse_return_mixed.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_sqrt, Reverse, Active)] +fn sqrt(#[active] a: f32, #[dup] b: &f32, c: &f32, #[active] d: f32) -> f32 { + a * (b * b + c*c*d*d).sqrt() +} diff --git a/library/autodiff/tests/ui/active_in_forward_mode.rs b/library/autodiff/tests/ui/active_in_forward_mode.rs new file mode 100644 index 0000000000000..10366b1b422b8 --- /dev/null +++ b/library/autodiff/tests/ui/active_in_forward_mode.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Forward, DuplicatedNoNeed, Active)] +fn sin(x: f32) -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/active_in_forward_mode.stderr b/library/autodiff/tests/ui/active_in_forward_mode.stderr new file mode 100644 index 0000000000000..cd413564068ae --- /dev/null +++ b/library/autodiff/tests/ui/active_in_forward_mode.stderr @@ -0,0 +1,7 @@ +error: active argument in forward mode + --> tests/ui/active_in_forward_mode.rs:3:12 + | +3 | #[autodiff(d_sin, Forward, DuplicatedNoNeed, Active)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` forward mode should be either Const, Duplicated diff --git a/library/autodiff/tests/ui/activities_inline_and_header.rs b/library/autodiff/tests/ui/activities_inline_and_header.rs new file mode 100644 index 0000000000000..1ecf37ec60a8f --- /dev/null +++ b/library/autodiff/tests/ui/activities_inline_and_header.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Reverse, Active, Active)] +fn sin(#[active] x: f32) -> 
f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/activities_inline_and_header.stderr b/library/autodiff/tests/ui/activities_inline_and_header.stderr new file mode 100644 index 0000000000000..b4d50d02a26a4 --- /dev/null +++ b/library/autodiff/tests/ui/activities_inline_and_header.stderr @@ -0,0 +1,7 @@ +error: inline activity + --> tests/ui/activities_inline_and_header.rs:4:18 + | +4 | fn sin(#[active] x: f32) -> f32; + | ^^^^^^ + | + = help: `#[autodiff]` should have activities either specified in header or as inline attributes diff --git a/library/autodiff/tests/ui/invalid_indirection.rs b/library/autodiff/tests/ui/invalid_indirection.rs new file mode 100644 index 0000000000000..627a7cb0fc6f9 --- /dev/null +++ b/library/autodiff/tests/ui/invalid_indirection.rs @@ -0,0 +1,19 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Reverse, Const)] +fn duplicated_without_reference(#[dup] x: f32) { +} + +#[autodiff(d_sin, Reverse, Const)] +fn active_with_reference(#[active] x: &f32) { +} + +#[autodiff(d_sin, Forward, Const)] +fn duplicated_forward(#[dup] x: f32) { +} + +#[autodiff(d_sin, Forward, Const)] +fn duplicated_no_need_forward(#[dup_noneed] x: &f32) { +} + +fn main() {} diff --git a/library/autodiff/tests/ui/invalid_indirection.stderr b/library/autodiff/tests/ui/invalid_indirection.stderr new file mode 100644 index 0000000000000..cb27c542018e5 --- /dev/null +++ b/library/autodiff/tests/ui/invalid_indirection.stderr @@ -0,0 +1,31 @@ +error: type not behind reference + --> tests/ui/invalid_indirection.rs:4:40 + | +4 | fn duplicated_without_reference(#[dup] x: f32) { + | ^^^^^^ + | + = help: `#[autodiff]` duplicated parameters should be behind reference in reverse mode + +error: type behind reference + --> tests/ui/invalid_indirection.rs:8:36 + | +8 | fn active_with_reference(#[active] x: &f32) { + | ^^^^^^^ + | + = help: `#[autodiff]` active parameter should be concrete in reverse mode + +error: type not behind reference + --> 
tests/ui/invalid_indirection.rs:12:30 + | +12 | fn duplicated_forward(#[dup] x: f32) { + | ^^^^^^ + | + = help: `#[autodiff]` duplicated types should be behind a reference + +error: should be mutable reference + --> tests/ui/invalid_indirection.rs:16:45 + | +16 | fn duplicated_no_need_forward(#[dup_noneed] x: &f32) { + | ^^^^^^^ + | + = help: `#[autodiff]` parameter should be output for DuplicatedNoNeed activity diff --git a/library/autodiff/tests/ui/invalid_mutability_pairs.rs b/library/autodiff/tests/ui/invalid_mutability_pairs.rs new file mode 100644 index 0000000000000..708ecc597a5be --- /dev/null +++ b/library/autodiff/tests/ui/invalid_mutability_pairs.rs @@ -0,0 +1,24 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Forward, Duplicated)] +fn fwd_output_no_reference(#[dup] x: &mut f32, y: f32) -> f32; + +#[autodiff(d_sin, Forward, Duplicated)] +fn output_immutable(#[dup] x: &mut f32, y: &f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn rev_input_no_reference(#[dup] x: &f32, y: f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn rev_output_no_reference(#[dup] x: &mut f32, y: f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn input_immutable(#[dup] x: &f32, y: &f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn output_mutable(#[dup] x: &mut f32, y: &mut f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn dupnoneed_input(#[dup_noneed] x: &f32, y: &f32) -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/invalid_mutability_pairs.stderr b/library/autodiff/tests/ui/invalid_mutability_pairs.stderr new file mode 100644 index 0000000000000..37af0c2ad52ee --- /dev/null +++ b/library/autodiff/tests/ui/invalid_mutability_pairs.stderr @@ -0,0 +1,55 @@ +error: should be a mutable reference + --> tests/ui/invalid_mutability_pairs.rs:4:48 + | +4 | fn fwd_output_no_reference(#[dup] x: &mut f32, y: f32) -> f32; + | ^^^^^^ + | + = help: `#[autodiff]` output parameter should duplicate derivative into second parameter for forward mode + +error: 
should be a mutable reference + --> tests/ui/invalid_mutability_pairs.rs:7:41 + | +7 | fn output_immutable(#[dup] x: &mut f32, y: &f32) -> f32; + | ^^^^^^^ + | + = help: `#[autodiff]` output parameter should duplicate derivative into second parameter for forward mode + +error: should be a mutable reference + --> tests/ui/invalid_mutability_pairs.rs:10:43 + | +10 | fn rev_input_no_reference(#[dup] x: &f32, y: f32) -> f32; + | ^^^^^^ + | + = help: `#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode + +error: should be an immutable reference + --> tests/ui/invalid_mutability_pairs.rs:13:48 + | +13 | fn rev_output_no_reference(#[dup] x: &mut f32, y: f32) -> f32; + | ^^^^^^ + | + = help: `#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode + +error: should be a mutable reference + --> tests/ui/invalid_mutability_pairs.rs:16:36 + | +16 | fn input_immutable(#[dup] x: &f32, y: &f32) -> f32; + | ^^^^^^^ + | + = help: `#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode + +error: should be an immutable reference + --> tests/ui/invalid_mutability_pairs.rs:19:39 + | +19 | fn output_mutable(#[dup] x: &mut f32, y: &mut f32) -> f32; + | ^^^^^^^^^^^ + | + = help: `#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode + +error: use duplicated instead + --> tests/ui/invalid_mutability_pairs.rs:22:34 + | +22 | fn dupnoneed_input(#[dup_noneed] x: &f32, y: &f32) -> f32; + | ^^^^^^^ + | + = help: `#[autodiff]` input parameter cannot be declared as duplicatednoneed diff --git a/library/autodiff/tests/ui/invalid_return.rs b/library/autodiff/tests/ui/invalid_return.rs new file mode 100644 index 0000000000000..b3c8bce1166bf --- /dev/null +++ b/library/autodiff/tests/ui/invalid_return.rs @@ -0,0 +1,12 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Forward, Active)] +fn sin1(x: f32) -> f32; + +#[autodiff(d_sin, 
Reverse, Duplicated)] +fn sin2(x: f32) -> f32; + +#[autodiff(d_sin, Reverse, DuplicatedNoNeed)] +fn sin3(x: f32) -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/invalid_return.stderr b/library/autodiff/tests/ui/invalid_return.stderr new file mode 100644 index 0000000000000..4ddaccdba0f72 --- /dev/null +++ b/library/autodiff/tests/ui/invalid_return.stderr @@ -0,0 +1,23 @@ +error: active return for forward mode + --> tests/ui/invalid_return.rs:3:28 + | +3 | #[autodiff(d_sin, Forward, Active)] + | ^^^^^^ + | + = help: `#[autodiff]` return should be Const, Duplicated or DuplicatedNoNeed in forward mode + +error: duplicated return for reverse mode + --> tests/ui/invalid_return.rs:6:28 + | +6 | #[autodiff(d_sin, Reverse, Duplicated)] + | ^^^^^^^^^^ + | + = help: `#[autodiff]` return should be Const or Active in reverse mode + +error: duplicated return for reverse mode + --> tests/ui/invalid_return.rs:9:28 + | +9 | #[autodiff(d_sin, Reverse, DuplicatedNoNeed)] + | ^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` return should be Const or Active in reverse mode diff --git a/library/autodiff/tests/ui/invalid_return_type.rs b/library/autodiff/tests/ui/invalid_return_type.rs new file mode 100644 index 0000000000000..7b91ccd2d650a --- /dev/null +++ b/library/autodiff/tests/ui/invalid_return_type.rs @@ -0,0 +1,16 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Reverse, Active)] +fn active_but_no_return(#[active] x: f32) { +} + +#[autodiff(d_sin, Reverse, Active)] +fn invalid_primal_value(#[active] x: f32, #[active] y: Vec, #[active] z: Tensor, y_tang: f32) -> (i32, f32); + +#[autodiff(d_sin, Forward, Duplicated)] +fn invalid_forward_return(#[dup] x: &f32, tx: &f32, #[dup] y: &Vec, ty: &Vec, #[dup] z: &Tensor, tz: &Tensor) -> (f32, f32, f32); + +#[autodiff(d_sin, Forward, DuplicatedNoNeed)] +fn invalid_forward_return(#[dup] x: &f32, tx: &f32, #[dup] y: &Vec, ty: &Vec, #[dup] z: &Tensor, tz: &Tensor) -> (f32, f32); + +fn main() {} diff --git 
a/library/autodiff/tests/ui/invalid_return_type.stderr b/library/autodiff/tests/ui/invalid_return_type.stderr new file mode 100644 index 0000000000000..90e5e47a2a33d --- /dev/null +++ b/library/autodiff/tests/ui/invalid_return_type.stderr @@ -0,0 +1,31 @@ +error: no return type + --> tests/ui/invalid_return_type.rs:4:1 + | +4 | fn active_but_no_return(#[active] x: f32) { + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` non-const return activity but no return type + +error: invalid output + --> tests/ui/invalid_return_type.rs:8:100 + | +8 | fn invalid_primal_value(#[active] x: f32, #[active] y: Vec, #[active] z: Tensor, y_tang: f32) -> (i32, f32); + | ^^^^^^^^^^^^^ + | + = help: `#[autodiff]` expected (f32, Vec < f32 >, Tensor,) + +error: invalid output + --> tests/ui/invalid_return_type.rs:11:121 + | +11 | fn invalid_forward_return(#[dup] x: &f32, tx: &f32, #[dup] y: &Vec, ty: &Vec, #[dup] z: &Tensor, tz: &Tensor) -> (f32, f32, f32); + | ^^^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` expected (f32, f32, f32, f32,) + +error: invalid output + --> tests/ui/invalid_return_type.rs:14:121 + | +14 | fn invalid_forward_return(#[dup] x: &f32, tx: &f32, #[dup] y: &Vec, ty: &Vec, #[dup] z: &Tensor, tz: &Tensor) -> (f32, f32); + | ^^^^^^^^^^^^^ + | + = help: `#[autodiff]` expected (f32, f32, f32,) diff --git a/library/autodiff/tests/ui/no_function_name.rs b/library/autodiff/tests/ui/no_function_name.rs new file mode 100644 index 0000000000000..8222ca4aaf37d --- /dev/null +++ b/library/autodiff/tests/ui/no_function_name.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff] +fn sin(x: f32) -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/no_function_name.stderr b/library/autodiff/tests/ui/no_function_name.stderr new file mode 100644 index 0000000000000..e98add3164c9f --- /dev/null +++ b/library/autodiff/tests/ui/no_function_name.stderr @@ -0,0 +1,8 @@ +error: please specify the autodiff function + --> tests/ui/no_function_name.rs:3:1 + 
| +3 | #[autodiff] + | ^^^^^^^^^^^ + | + = help: `#[autodiff]` needs a function name for primal or adjoint + = note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/library/autodiff/tests/ui/not_a_function.rs b/library/autodiff/tests/ui/not_a_function.rs new file mode 100644 index 0000000000000..0a3c11725a086 --- /dev/null +++ b/library/autodiff/tests/ui/not_a_function.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff] +struct NotAFunction; + +fn main() {} diff --git a/library/autodiff/tests/ui/not_a_function.stderr b/library/autodiff/tests/ui/not_a_function.stderr new file mode 100644 index 0000000000000..c681841532a5e --- /dev/null +++ b/library/autodiff/tests/ui/not_a_function.stderr @@ -0,0 +1,7 @@ +error: item is not a function + --> tests/ui/not_a_function.rs:4:1 + | +4 | struct NotAFunction; + | ^^^^^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` can only be used on primal or adjoint functions diff --git a/library/autodiff/tests/ui/reverse_tangent.rs b/library/autodiff/tests/ui/reverse_tangent.rs new file mode 100644 index 0000000000000..603f7fd1789ce --- /dev/null +++ b/library/autodiff/tests/ui/reverse_tangent.rs @@ -0,0 +1,12 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Reverse, Active)] +fn invalid_output_tangent_type(#[active] x: f32, y_tang: i32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn active_output_tangent(#[active] x: f32, #[active] y_tang: f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn tangent_missing() -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/reverse_tangent.stderr b/library/autodiff/tests/ui/reverse_tangent.stderr new file mode 100644 index 0000000000000..a7b4b6e3d97d6 --- /dev/null +++ b/library/autodiff/tests/ui/reverse_tangent.stderr @@ -0,0 +1,23 @@ +error: output tangent not a floating point + --> tests/ui/reverse_tangent.rs:4:58 + | +4 | fn invalid_output_tangent_type(#[active] x: f32, y_tang: i32) -> f32; + 
| ^^^ + | + = help: `#[autodiff]` the output tangent should be a floating point + +error: output tangent not const + --> tests/ui/reverse_tangent.rs:7:62 + | +7 | fn active_output_tangent(#[active] x: f32, #[active] y_tang: f32) -> f32; + | ^^^ + | + = help: `#[autodiff]` the last parameter of an adjoint with active return should be a constant tangent + +error: missing output tangent parameter + --> tests/ui/reverse_tangent.rs:10:1 + | +10 | fn tangent_missing() -> f32; + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` the last parameter of an adjoint with active return should exist diff --git a/library/autodiff/tests/ui/wrong_mode.rs b/library/autodiff/tests/ui/wrong_mode.rs new file mode 100644 index 0000000000000..1b500711de109 --- /dev/null +++ b/library/autodiff/tests/ui/wrong_mode.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, WrongMode)] +fn sin(x: f32) -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/wrong_mode.stderr b/library/autodiff/tests/ui/wrong_mode.stderr new file mode 100644 index 0000000000000..ca18d81abb306 --- /dev/null +++ b/library/autodiff/tests/ui/wrong_mode.stderr @@ -0,0 +1,7 @@ +error: should be forward or reverse + --> tests/ui/wrong_mode.rs:3:19 + | +3 | #[autodiff(d_sin, WrongMode)] + | ^^^^^^^^^ + | + = help: `#[autodiff]` modes should be either forward or reverse diff --git a/library/core/src/macros/mod.rs b/library/core/src/macros/mod.rs index 125a6f57bfbaa..658b640ae4ce3 100644 --- a/library/core/src/macros/mod.rs +++ b/library/core/src/macros/mod.rs @@ -1416,6 +1416,18 @@ pub(crate) mod builtin { }; } + // Differentiate function + //#[unstable( + // feature = "autodiff", + // issue = "29598", + // reason = "autodiff is not stable enough" + //)] + //#[rustc_builtin_macro] + //#[macro_export] + //pub macro autodiff($item:item) { + // /* compiler built-in */ + //} + /// Parses a file as an expression or an item according to the context. 
/// /// **Warning**: For multi-file Rust projects, the `include!` macro is probably not what you diff --git a/src/bootstrap/configure.py b/src/bootstrap/configure.py index bfef3e672407d..6369a1a557a8f 100755 --- a/src/bootstrap/configure.py +++ b/src/bootstrap/configure.py @@ -70,6 +70,7 @@ def v(*args): # channel, etc. o("optimize-llvm", "llvm.optimize", "build optimized LLVM") o("llvm-assertions", "llvm.assertions", "build LLVM with assertions") +o("llvm-enzyme", "llvm.enzyme", "build LLVM with Enzyme") o("llvm-plugins", "llvm.plugins", "build LLVM with plugin interface") o("debug-assertions", "rust.debug-assertions", "build with debugging assertions") o("debug-assertions-std", "rust.debug-assertions-std", "build the standard library with debugging assertions") diff --git a/src/bootstrap/src/core/build_steps/compile.rs b/src/bootstrap/src/core/build_steps/compile.rs index 441931e415cc6..775b86628be79 100644 --- a/src/bootstrap/src/core/build_steps/compile.rs +++ b/src/bootstrap/src/core/build_steps/compile.rs @@ -1539,6 +1539,7 @@ pub struct Assemble { pub target_compiler: Compiler, } +#[allow(unreachable_code)] impl Step for Assemble { type Output = Compiler; const ONLY_HOSTS: bool = true; @@ -1599,6 +1600,24 @@ impl Step for Assemble { return target_compiler; } + // Build enzyme + let enzyme_install = if builder.config.llvm_enzyme { + Some(builder.ensure(llvm::Enzyme { target: build_compiler.host })) + } else { + None + }; + + if let Some(enzyme_install) = enzyme_install { + let src_lib = enzyme_install.join("build/Enzyme/LLVMEnzyme-17.so"); + + let libdir = builder.sysroot_libdir(build_compiler, build_compiler.host); + let target_libdir = builder.sysroot_libdir(target_compiler, target_compiler.host); + let dst_lib = libdir.join("libLLVMEnzyme-17.so"); + let target_dst_lib = target_libdir.join("libLLVMEnzyme-17.so"); + builder.copy(&src_lib, &dst_lib); + builder.copy(&src_lib, &target_dst_lib); + } + // Build the libraries for this compiler to link to (i.e., 
the libraries // it uses at runtime). NOTE: Crates the target compiler compiles don't // link to these. (FIXME: Is that correct? It seems to be correct most diff --git a/src/bootstrap/src/core/build_steps/llvm.rs b/src/bootstrap/src/core/build_steps/llvm.rs index 24351118a5aa1..11e377be92e24 100644 --- a/src/bootstrap/src/core/build_steps/llvm.rs +++ b/src/bootstrap/src/core/build_steps/llvm.rs @@ -802,6 +802,72 @@ fn get_var(var_base: &str, host: &str, target: &str) -> Option<OsString> { .or_else(|| env::var_os(var_base)) } +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub struct Enzyme { + pub target: TargetSelection, +} + +impl Step for Enzyme { + type Output = PathBuf; + const ONLY_HOSTS: bool = true; + + fn should_run(run: ShouldRun<'_>) -> ShouldRun<'_> { + run.path("src/tools/enzyme/enzyme") + } + + fn make_run(run: RunConfig<'_>) { + run.builder.ensure(Enzyme { target: run.target }); + } + + /// Compile Enzyme for `target`. + fn run(self, builder: &Builder<'_>) -> PathBuf { + if builder.config.dry_run() { + let out_dir = builder.enzyme_out(self.target); + return out_dir; + } + let target = self.target; + + let LlvmResult { llvm_config, .. } = builder.ensure(Llvm { target: self.target }); + + let out_dir = builder.enzyme_out(target); + let done_stamp = out_dir.join("enzyme-finished-building"); + if done_stamp.exists() { + return out_dir; + } + + builder.info(&format!("Building Enzyme for {}", target)); + let _time = helpers::timeit(&builder); + t!(fs::create_dir_all(&out_dir)); + + builder.update_submodule(&Path::new("src").join("tools").join("enzyme")); + let mut cfg = cmake::Config::new(builder.src.join("src/tools/enzyme/enzyme/")); + // TODO: Find a nicer way to use Enzyme Debug builds + //cfg.profile("Debug"); + //cfg.define("CMAKE_BUILD_TYPE", "Debug"); + configure_cmake(builder, target, &mut cfg, true, LdFlags::default(), &[]); + + // Re-use the same flags as llvm to control the level of debug information + // generated for Enzyme. 
+ let profile = match (builder.config.llvm_optimize, builder.config.llvm_release_debuginfo) { + (false, _) => "Debug", + (true, false) => "Release", + (true, true) => "RelWithDebInfo", + }; + + cfg.out_dir(&out_dir) + .profile(profile) + .env("LLVM_CONFIG_REAL", &llvm_config) + .define("LLVM_ENABLE_ASSERTIONS", "ON") + .define("ENZYME_EXTERNAL_SHARED_LIB", "OFF") + .define("LLVM_DIR", builder.llvm_out(target)); + + cfg.build(); + + t!(File::create(&done_stamp)); + out_dir + } +} + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] pub struct Lld { pub target: TargetSelection, diff --git a/src/bootstrap/src/core/builder.rs b/src/bootstrap/src/core/builder.rs index 90e09d12a9d50..b2fa044cd0da3 100644 --- a/src/bootstrap/src/core/builder.rs +++ b/src/bootstrap/src/core/builder.rs @@ -1371,6 +1371,11 @@ impl<'a> Builder<'a> { } } + // TODO: adjust -17 ending for Enzyme + // https://rust-lang.zulipchat.com/#narrow/stream/182449-t-compiler.2Fhelp/topic/.E2.9C.94.20link.20new.20library.20into.20stage1.2Frustc + rustflags.arg("-l"); + rustflags.arg("LLVMEnzyme-17"); + let use_new_symbol_mangling = match self.config.rust_new_symbol_mangling { Some(setting) => { // If an explicit setting is given, use that diff --git a/src/bootstrap/src/core/config/config.rs b/src/bootstrap/src/core/config/config.rs index 5b5334b0a5572..7034be88aed48 100644 --- a/src/bootstrap/src/core/config/config.rs +++ b/src/bootstrap/src/core/config/config.rs @@ -173,6 +173,7 @@ pub struct Config { // llvm codegen options pub llvm_assertions: bool, pub llvm_tests: bool, + pub llvm_enzyme: bool, pub llvm_plugins: bool, pub llvm_optimize: bool, pub llvm_thin_lto: bool, @@ -823,6 +824,7 @@ define_config! 
{ release_debuginfo: Option = "release-debuginfo", assertions: Option = "assertions", tests: Option = "tests", + enzyme: Option = "enzyme", plugins: Option = "plugins", ccache: Option = "ccache", static_libstdcpp: Option = "static-libstdcpp", @@ -1356,6 +1358,7 @@ impl Config { // we'll infer default values for them later let mut llvm_assertions = None; let mut llvm_tests = None; + let mut llvm_enzyme = None; let mut llvm_plugins = None; let mut debug = None; let mut debug_assertions = None; @@ -1500,6 +1503,7 @@ impl Config { set(&mut config.ninja_in_file, llvm.ninja); llvm_assertions = llvm.assertions; llvm_tests = llvm.tests; + llvm_enzyme = llvm.enzyme; llvm_plugins = llvm.plugins; set(&mut config.llvm_optimize, llvm.optimize); set(&mut config.llvm_thin_lto, llvm.thin_lto); @@ -1565,6 +1569,7 @@ impl Config { check_ci_llvm!(llvm.polly); check_ci_llvm!(llvm.clang); check_ci_llvm!(llvm.build_config); + check_ci_llvm!(llvm.enzyme); check_ci_llvm!(llvm.plugins); } @@ -1658,6 +1663,7 @@ impl Config { config.llvm_assertions = llvm_assertions.unwrap_or(false); config.llvm_tests = llvm_tests.unwrap_or(false); + config.llvm_enzyme = llvm_enzyme.unwrap_or(false); config.llvm_plugins = llvm_plugins.unwrap_or(false); config.rust_optimize = optimize.unwrap_or(RustOptimize::Bool(true)); diff --git a/src/bootstrap/src/lib.rs b/src/bootstrap/src/lib.rs index d7f49a6d11b9c..ac74c632a1b1f 100644 --- a/src/bootstrap/src/lib.rs +++ b/src/bootstrap/src/lib.rs @@ -806,6 +806,10 @@ impl Build { self.out.join(&*target.triple).join("lld") } + fn enzyme_out(&self, target: TargetSelection) -> PathBuf { + self.out.join(&*target.triple).join("enzyme") + } + /// Output directory for all documentation for a target fn doc_out(&self, target: TargetSelection) -> PathBuf { self.out.join(&*target.triple).join("doc") diff --git a/src/test/ui/terminal-width/flag-human.rs b/src/test/ui/terminal-width/flag-human.rs new file mode 100644 index 0000000000000..4b94ebb01fc8e --- /dev/null +++ 
b/src/test/ui/terminal-width/flag-human.rs @@ -0,0 +1,9 @@ +// compile-flags: --diagnostic-width=20 + +// This test checks that `-Z diagnostic-width` affects the human error output by restricting it to an +// arbitrarily low value so that the effect is visible. + +fn main() { + let _: () = 42; + //~^ ERROR mismatched types +} diff --git a/src/test/ui/terminal-width/flag-json.rs b/src/test/ui/terminal-width/flag-json.rs new file mode 100644 index 0000000000000..3add1d7d9301e --- /dev/null +++ b/src/test/ui/terminal-width/flag-json.rs @@ -0,0 +1,9 @@ +// compile-flags: --diagnostic-width=20 --error-format=json + +// This test checks that `-Z diagnostic-width` affects the JSON error output by restricting it to an +// arbitrarily low value so that the effect is visible. + +fn main() { + let _: () = 42; + //~^ ERROR mismatched types +} diff --git a/src/test/ui/terminal-width/flag-json.stderr b/src/test/ui/terminal-width/flag-json.stderr new file mode 100644 index 0000000000000..b21391d1640ef --- /dev/null +++ b/src/test/ui/terminal-width/flag-json.stderr @@ -0,0 +1,40 @@ +{"message":"mismatched types","code":{"code":"E0308","explanation":"Expected type did not match the received type. + +Erroneous code examples: + +```compile_fail,E0308 +fn plus_one(x: i32) -> i32 { + x + 1 +} + +plus_one(\"Not a number\"); +// ^^^^^^^^^^^^^^ expected `i32`, found `&str` + +if \"Not a bool\" { +// ^^^^^^^^^^^^ expected `bool`, found `&str` +} + +let x: f32 = \"Not a float\"; +// --- ^^^^^^^^^^^^^ expected `f32`, found `&str` +// | +// expected due to this +``` + +This error occurs when an expression was used in a place where the compiler +expected an expression of a different type. It can occur in several cases, the +most common being when calling a function and passing an argument which has a +different type than the matching type in the function declaration. 
+"},"level":"error","spans":[{"file_name":"$DIR/flag-json.rs","byte_start":243,"byte_end":245,"line_start":7,"line_end":7,"column_start":17,"column_end":19,"is_primary":true,"text":[{"text":" let _: () = 42;","highlight_start":17,"highlight_end":19}],"label":"expected `()`, found integer","suggested_replacement":null,"suggestion_applicability":null,"expansion":null},{"file_name":"$DIR/flag-json.rs","byte_start":238,"byte_end":240,"line_start":7,"line_end":7,"column_start":12,"column_end":14,"is_primary":false,"text":[{"text":" let _: () = 42;","highlight_start":12,"highlight_end":14}],"label":"expected due to this","suggested_replacement":null,"suggestion_applicability":null,"expansion":null}],"children":[],"rendered":"error[E0308]: mismatched types + --> $DIR/flag-json.rs:7:17 + | +LL | ..._: () = 42; + | -- ^^ expected `()`, found integer + | | + | expected due to this + +"} +{"message":"aborting due to previous error","code":null,"level":"error","spans":[],"children":[],"rendered":"error: aborting due to previous error + +"} +{"message":"For more information about this error, try `rustc --explain E0308`.","code":null,"level":"failure-note","spans":[],"children":[],"rendered":"For more information about this error, try `rustc --explain E0308`. 
+"} diff --git a/src/tools/enzyme b/src/tools/enzyme new file mode 160000 index 0000000000000..01c279c6b4674 --- /dev/null +++ b/src/tools/enzyme @@ -0,0 +1 @@ +Subproject commit 01c279c6b46746172182dc2bf18466b95e67e199 diff --git a/tests/rustdoc-ui/doctest/terminal-width.rs b/tests/rustdoc-ui/doctest/terminal-width.rs new file mode 100644 index 0000000000000..61961d5ec710e --- /dev/null +++ b/tests/rustdoc-ui/doctest/terminal-width.rs @@ -0,0 +1,5 @@ +// compile-flags: -Zunstable-options --diagnostic-width=10 +#![deny(rustdoc::bare_urls)] + +/// This is a long line that contains a http://link.com +pub struct Foo; //~^ ERROR diff --git a/tests/rustdoc-ui/doctest/terminal-width.stderr b/tests/rustdoc-ui/doctest/terminal-width.stderr new file mode 100644 index 0000000000000..fed049d2b37bc --- /dev/null +++ b/tests/rustdoc-ui/doctest/terminal-width.stderr @@ -0,0 +1,15 @@ +error: this URL is not a hyperlink + --> $DIR/diagnostic-width.rs:4:41 + | +LL | ... a http://link.com + | ^^^^^^^^^^^^^^^ help: use an automatic link instead: `` + | +note: the lint level is defined here + --> $DIR/diagnostic-width.rs:2:9 + | +LL | ...ny(rustdoc::bare_url... + | ^^^^^^^^^^^^^^^^^^ + = note: bare URLs are not automatically turned into clickable links + +error: aborting due to previous error + diff --git a/tests/ui/json/autodiff.rs b/tests/ui/json/autodiff.rs new file mode 100644 index 0000000000000..54f94c3765bf6 --- /dev/null +++ b/tests/ui/json/autodiff.rs @@ -0,0 +1,16 @@ +// Check autodiff attribute +// edition:2018 + +extern "C" fn rosenbrock(a: f32, b: f32, x: f32, y: f32) -> f32 { + let (z, w) = (a-x, y-x*x); + + z*z + b*w*w +} + +#[autodiff(rosenbrock, mode = "forward")] +extern "C" { + fn dx_rosenbrock(a: f32, b: f32, x: f32, y: f32, d_x: &mut f32); + fn dy_rosenbrock(a: f32, b: f32, x: f32, y: f32, d_y: &mut f32); +} + +fn main() {}