From 3e4e28aa4021ebd3dc1adc9d2a899d8fcbfe7f11 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 30 Oct 2023 01:28:50 -0400 Subject: [PATCH 01/17] Cleanup Rust-Enzyme history Co-authored-by: Lorenz Schmidt git@lorenzschmidt.com Co-authored-by: William Moses gh@wsmoses.com --- .github/workflows/enzyme-ci.yml | 38 ++ .gitmodules | 3 + Cargo.lock | 1 + Cargo.toml | 1 + README.md | 63 ++- compiler/rustc_ast/src/mut_visit.rs | 2 +- compiler/rustc_codegen_llvm/src/attributes.rs | 3 + compiler/rustc_codegen_llvm/src/back/lto.rs | 3 +- compiler/rustc_codegen_llvm/src/back/write.rs | 306 ++++++++++- compiler/rustc_codegen_llvm/src/base.rs | 12 +- compiler/rustc_codegen_llvm/src/context.rs | 4 + compiler/rustc_codegen_llvm/src/lib.rs | 44 +- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 484 +++++++++++++++++- compiler/rustc_codegen_llvm/src/typetree.rs | 33 ++ .../src/assert_module_sources.rs | 2 +- compiler/rustc_codegen_ssa/src/back/lto.rs | 25 +- .../src/back/symbol_export.rs | 2 +- compiler/rustc_codegen_ssa/src/back/write.rs | 47 +- compiler/rustc_codegen_ssa/src/base.rs | 9 +- .../rustc_codegen_ssa/src/codegen_attrs.rs | 162 +++++- compiler/rustc_codegen_ssa/src/traits/misc.rs | 1 + .../rustc_codegen_ssa/src/traits/write.rs | 12 + compiler/rustc_feature/src/builtin_attrs.rs | 7 + compiler/rustc_interface/src/tests.rs | 1 + .../rustc_llvm/llvm-wrapper/RustWrapper.cpp | 5 + compiler/rustc_middle/src/arena.rs | 1 + .../rustc_middle/src/middle/autodiff_attrs.rs | 94 ++++ compiler/rustc_middle/src/middle/mod.rs | 2 + compiler/rustc_middle/src/middle/typetree.rs | 39 ++ compiler/rustc_middle/src/query/erase.rs | 4 + compiler/rustc_middle/src/query/mod.rs | 10 +- compiler/rustc_middle/src/ty/mod.rs | 2 + compiler/rustc_monomorphize/Cargo.toml | 1 + compiler/rustc_monomorphize/src/collector.rs | 1 + .../rustc_monomorphize/src/partitioning.rs | 4 +- compiler/rustc_passes/src/check_attr.rs | 15 + compiler/rustc_resolve/src/lib.rs | 1 + 
compiler/rustc_session/src/options.rs | 2 + compiler/rustc_span/src/symbol.rs | 2 + config.example.toml | 3 + library/autodiff/Cargo.lock | 314 ++++++++++++ library/autodiff/Cargo.toml | 28 + library/autodiff/examples/array.rs | 23 + library/autodiff/examples/box.rs | 24 + library/autodiff/examples/broken_matvec.rs | 34 ++ library/autodiff/examples/hessian_sin.rs | 28 + library/autodiff/examples/ndarray.rs | 25 + library/autodiff/examples/rosenbrock_fwd.rs | 34 ++ .../autodiff/examples/rosenbrock_fwd_iter.rs | 34 ++ library/autodiff/examples/rosenbrock_rev.rs | 33 ++ library/autodiff/examples/sin.rs | 36 ++ library/autodiff/examples/sqrt.rs | 21 + library/autodiff/examples/struct.rs | 33 ++ library/autodiff/examples/vec.rs | 24 + library/autodiff/examples_broken/biquad.rs | 54 ++ .../autodiff/examples_broken/broken_iter.rs | 20 + .../examples_broken/broken_recursive.rs | 66 +++ .../examples_broken/broken_second_order.rs | 17 + library/autodiff/src/gen.rs | 217 ++++++++ library/autodiff/src/lib.rs | 31 ++ library/autodiff/src/parser.rs | 464 +++++++++++++++++ .../expand/forward_duplicated.expanded.rs | 10 + .../tests/expand/forward_duplicated.rs | 6 + .../forward_duplicated_return.expanded.rs | 15 + .../tests/expand/forward_duplicated_return.rs | 6 + .../expand/reverse_duplicated.expanded.rs | 10 + .../tests/expand/reverse_duplicated.rs | 6 + .../expand/reverse_return_array.expanded.rs | 10 + .../tests/expand/reverse_return_array.rs | 6 + .../expand/reverse_return_mixed.expanded.rs | 17 + .../tests/expand/reverse_return_mixed.rs | 6 + .../tests/ui/active_in_forward_mode.rs | 6 + .../tests/ui/active_in_forward_mode.stderr | 7 + .../tests/ui/activities_inline_and_header.rs | 6 + .../ui/activities_inline_and_header.stderr | 7 + .../autodiff/tests/ui/invalid_indirection.rs | 19 + .../tests/ui/invalid_indirection.stderr | 31 ++ .../tests/ui/invalid_mutability_pairs.rs | 24 + .../tests/ui/invalid_mutability_pairs.stderr | 55 ++ library/autodiff/tests/ui/invalid_return.rs 
| 12 + .../autodiff/tests/ui/invalid_return.stderr | 23 + .../autodiff/tests/ui/invalid_return_type.rs | 16 + .../tests/ui/invalid_return_type.stderr | 31 ++ library/autodiff/tests/ui/no_function_name.rs | 6 + .../autodiff/tests/ui/no_function_name.stderr | 8 + library/autodiff/tests/ui/not_a_function.rs | 6 + .../autodiff/tests/ui/not_a_function.stderr | 7 + library/autodiff/tests/ui/reverse_tangent.rs | 12 + .../autodiff/tests/ui/reverse_tangent.stderr | 23 + library/autodiff/tests/ui/wrong_mode.rs | 6 + library/autodiff/tests/ui/wrong_mode.stderr | 7 + library/core/src/macros/mod.rs | 12 + src/bootstrap/configure.py | 1 + src/bootstrap/src/core/build_steps/compile.rs | 19 + src/bootstrap/src/core/build_steps/llvm.rs | 66 +++ src/bootstrap/src/core/builder.rs | 5 + src/bootstrap/src/core/config/config.rs | 25 +- src/bootstrap/src/lib.rs | 4 + src/test/ui/terminal-width/flag-human.rs | 9 + src/test/ui/terminal-width/flag-json.rs | 9 + src/test/ui/terminal-width/flag-json.stderr | 40 ++ src/tools/enzyme | 1 + tests/rustdoc-ui/doctest/terminal-width.rs | 5 + .../rustdoc-ui/doctest/terminal-width.stderr | 15 + tests/ui/json/autodiff.rs | 16 + 105 files changed, 3600 insertions(+), 42 deletions(-) create mode 100644 .github/workflows/enzyme-ci.yml create mode 100644 compiler/rustc_codegen_llvm/src/typetree.rs create mode 100644 compiler/rustc_middle/src/middle/autodiff_attrs.rs create mode 100644 compiler/rustc_middle/src/middle/typetree.rs create mode 100644 library/autodiff/Cargo.lock create mode 100644 library/autodiff/Cargo.toml create mode 100644 library/autodiff/examples/array.rs create mode 100644 library/autodiff/examples/box.rs create mode 100644 library/autodiff/examples/broken_matvec.rs create mode 100644 library/autodiff/examples/hessian_sin.rs create mode 100644 library/autodiff/examples/ndarray.rs create mode 100644 library/autodiff/examples/rosenbrock_fwd.rs create mode 100644 library/autodiff/examples/rosenbrock_fwd_iter.rs create mode 100644 
library/autodiff/examples/rosenbrock_rev.rs create mode 100644 library/autodiff/examples/sin.rs create mode 100644 library/autodiff/examples/sqrt.rs create mode 100644 library/autodiff/examples/struct.rs create mode 100644 library/autodiff/examples/vec.rs create mode 100644 library/autodiff/examples_broken/biquad.rs create mode 100644 library/autodiff/examples_broken/broken_iter.rs create mode 100644 library/autodiff/examples_broken/broken_recursive.rs create mode 100644 library/autodiff/examples_broken/broken_second_order.rs create mode 100644 library/autodiff/src/gen.rs create mode 100644 library/autodiff/src/lib.rs create mode 100644 library/autodiff/src/parser.rs create mode 100644 library/autodiff/tests/expand/forward_duplicated.expanded.rs create mode 100644 library/autodiff/tests/expand/forward_duplicated.rs create mode 100644 library/autodiff/tests/expand/forward_duplicated_return.expanded.rs create mode 100644 library/autodiff/tests/expand/forward_duplicated_return.rs create mode 100644 library/autodiff/tests/expand/reverse_duplicated.expanded.rs create mode 100644 library/autodiff/tests/expand/reverse_duplicated.rs create mode 100644 library/autodiff/tests/expand/reverse_return_array.expanded.rs create mode 100644 library/autodiff/tests/expand/reverse_return_array.rs create mode 100644 library/autodiff/tests/expand/reverse_return_mixed.expanded.rs create mode 100644 library/autodiff/tests/expand/reverse_return_mixed.rs create mode 100644 library/autodiff/tests/ui/active_in_forward_mode.rs create mode 100644 library/autodiff/tests/ui/active_in_forward_mode.stderr create mode 100644 library/autodiff/tests/ui/activities_inline_and_header.rs create mode 100644 library/autodiff/tests/ui/activities_inline_and_header.stderr create mode 100644 library/autodiff/tests/ui/invalid_indirection.rs create mode 100644 library/autodiff/tests/ui/invalid_indirection.stderr create mode 100644 library/autodiff/tests/ui/invalid_mutability_pairs.rs create mode 100644 
library/autodiff/tests/ui/invalid_mutability_pairs.stderr create mode 100644 library/autodiff/tests/ui/invalid_return.rs create mode 100644 library/autodiff/tests/ui/invalid_return.stderr create mode 100644 library/autodiff/tests/ui/invalid_return_type.rs create mode 100644 library/autodiff/tests/ui/invalid_return_type.stderr create mode 100644 library/autodiff/tests/ui/no_function_name.rs create mode 100644 library/autodiff/tests/ui/no_function_name.stderr create mode 100644 library/autodiff/tests/ui/not_a_function.rs create mode 100644 library/autodiff/tests/ui/not_a_function.stderr create mode 100644 library/autodiff/tests/ui/reverse_tangent.rs create mode 100644 library/autodiff/tests/ui/reverse_tangent.stderr create mode 100644 library/autodiff/tests/ui/wrong_mode.rs create mode 100644 library/autodiff/tests/ui/wrong_mode.stderr create mode 100644 src/test/ui/terminal-width/flag-human.rs create mode 100644 src/test/ui/terminal-width/flag-json.rs create mode 100644 src/test/ui/terminal-width/flag-json.stderr create mode 160000 src/tools/enzyme create mode 100644 tests/rustdoc-ui/doctest/terminal-width.rs create mode 100644 tests/rustdoc-ui/doctest/terminal-width.stderr create mode 100644 tests/ui/json/autodiff.rs diff --git a/.github/workflows/enzyme-ci.yml b/.github/workflows/enzyme-ci.yml new file mode 100644 index 0000000000000..4064d4709a5ed --- /dev/null +++ b/.github/workflows/enzyme-ci.yml @@ -0,0 +1,38 @@ +name: Rust CI + +on: + push: + branches: + - master + pull_request: + branches: + - master + merge_group: + +jobs: + build: + name: Rust Integration CI LLVM ${{ matrix.llvm }} ${{ matrix.build }} ${{ matrix.os }} + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + os: [openstack22] + + timeout-minutes: 600 + steps: + - name: checkout the source code + uses: actions/checkout@v3 + with: + fetch-depth: 2 + - name: build + run: | + mkdir build + cd build + ../configure --enable-llvm-link-shared --enable-llvm-plugins 
--enable-llvm-enzyme --release-channel=nightly --enable-llvm-assertions --enable-clang --enable-lld --enable-option-checking --enable-ninja --disable-docs + ../x.py build --stage 1 library/std library/proc_macro library/test tools/rustdoc + rustup toolchain link enzyme `pwd`/build/`rustup target list --installed`/stage1 + rustup toolchain install nightly # enables -Z unstable-options + - name: test + run: | + cargo +enzyme test --examples diff --git a/.gitmodules b/.gitmodules index f5025097a18dc..7e217cb215dd8 100644 --- a/.gitmodules +++ b/.gitmodules @@ -43,3 +43,6 @@ path = library/backtrace url = https://github.com/rust-lang/backtrace-rs.git shallow = true +[submodule "src/tools/enzyme"] + path = src/tools/enzyme + url = https://github.com/EnzymeAD/Enzyme.git diff --git a/Cargo.lock b/Cargo.lock index 0761268c9d411..06be5561ceec4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4312,6 +4312,7 @@ dependencies = [ "rustc_middle", "rustc_session", "rustc_span", + "rustc_symbol_mangling", "rustc_target", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 9b11ae8744b4f..ab42109434320 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,6 +61,7 @@ exclude = [ "src/tools/x", # stdarch has its own Cargo workspace "library/stdarch", + "library/autodiff", ] [profile.release.package.compiler_builtins] diff --git a/README.md b/README.md index a88ee4b8bf061..2ac2d3b38d679 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,68 @@ -# The Rust Programming Language +# The Rust Programming Language +Enzyme [![Rust Community](https://img.shields.io/badge/Rust_Community%20-Join_us-brightgreen?style=plastic&logo=rust)](https://www.rust-lang.org/community) This is the main source code repository for [Rust]. It contains the compiler, -standard library, and documentation. +standard library, and documentation. It is modified to use Enzyme for AutoDiff. 
+ +Please configure this fork using the following command: + +``` +mkdir build +cd build +../configure --enable-llvm-link-shared --enable-llvm-plugins --enable-llvm-enzyme --release-channel=nightly --enable-llvm-assertions --enable-clang --enable-lld --enable-option-checking --enable-ninja --disable-docs +``` + +Afterwards you can build rustc using: +``` +../x.py build --stage 1 library/std library/proc_macro library/test tools/rustdoc +``` + +Afterwards `rustup toolchain link` will allow you to use it through cargo: +``` +rustup toolchain link enzyme `pwd`/build/`rustup target list --installed`/stage1 +rustup toolchain install nightly # enables -Z unstable-options +``` + +You can then look at examples in the `library/autodiff/examples/*` folder and run them with + +```bash +# rosenbrock forward iteration +cargo +enzyme run --example rosenbrock_fwd_iter --release + +# or all of them +cargo +enzyme test --examples +``` + +## Enzyme Config +To help with debugging, Enzyme can be configured using environment variables. +```bash +export ENZYME_PRINT_TA=1 +export ENZYME_PRINT_AA=1 +export ENZYME_PRINT=1 +export ENZYME_PRINT_MOD=1 +export ENZYME_PRINT_MOD_AFTER=1 +``` +The first three will print TypeAnalysis, ActivityAnalysis and the llvm-ir on a function basis, respectively. +The last two variables will print the whole module directly before and after Enzyme differentiated the functions. + +When experimenting with flags please make sure that EnzymeStrictAliasing=0 +is not changed, since it is required for Enzyme to handle enums correctly. + +## Bug reporting +Bugs are pretty much expected at this point of the development process. +To help us, please minimize the Rust code as far as possible. +This tool might be a helpful aid: https://github.com/Nilstrieb/cargo-minimize +If you have some knowledge of LLVM-IR we also greatly appreciate it if you could help +us by compiling your minimized Rust code to LLVM-IR and reducing it further. 
+ +The only exception to this strategy is error based on "Can not deduce type of X", +where reducing your example will make it harder for us to understand the origin of the bug. +In this case please just try to inline all dependencies into a single crate or even file, +without deleting used code. + + + [Rust]: https://www.rust-lang.org/ diff --git a/compiler/rustc_ast/src/mut_visit.rs b/compiler/rustc_ast/src/mut_visit.rs index 0634ee970ec5e..23e7975edd65b 100644 --- a/compiler/rustc_ast/src/mut_visit.rs +++ b/compiler/rustc_ast/src/mut_visit.rs @@ -381,7 +381,7 @@ pub fn visit_bounds(bounds: &mut GenericBounds, vis: &mut T) { } // No `noop_` prefix because there isn't a corresponding method in `MutVisitor`. -pub fn visit_fn_sig(FnSig { header, decl, span }: &mut FnSig, vis: &mut T) { +pub fn visit_fn_sig(FnSig { header, decl, span, .. }: &mut FnSig, vis: &mut T) { vis.visit_fn_header(header); vis.visit_fn_decl(decl); vis.visit_span(span); diff --git a/compiler/rustc_codegen_llvm/src/attributes.rs b/compiler/rustc_codegen_llvm/src/attributes.rs index b6c01545f308c..a586559016cd5 100644 --- a/compiler/rustc_codegen_llvm/src/attributes.rs +++ b/compiler/rustc_codegen_llvm/src/attributes.rs @@ -285,6 +285,7 @@ pub fn from_fn_attrs<'ll, 'tcx>( instance: ty::Instance<'tcx>, ) { let codegen_fn_attrs = cx.tcx.codegen_fn_attrs(instance.def_id()); + let autodiff_attrs = cx.tcx.autodiff_attrs(instance.def_id()); let mut to_add = SmallVec::<[_; 16]>::new(); @@ -302,6 +303,8 @@ pub fn from_fn_attrs<'ll, 'tcx>( let inline = if codegen_fn_attrs.inline == InlineAttr::None && instance.def.requires_inline(cx.tcx) { InlineAttr::Hint + } else if autodiff_attrs.is_active() { + InlineAttr::Never } else { codegen_fn_attrs.inline }; diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index 8655aeec13dd6..c63870dfe4327 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -273,6 
+273,7 @@ fn fat_lto( info!("pushing cached module {:?}", wp.cgu_name); (buffer, CString::new(wp.cgu_name).unwrap()) })); + for module in modules { match module { FatLtoInput::InMemory(m) => in_memory.push(m), @@ -734,7 +735,7 @@ pub unsafe fn optimize_thin_module( let llcx = llvm::LLVMRustContextCreate(cgcx.fewer_names); let llmod_raw = parse_module(llcx, module_name, thin_module.data(), &diag_handler)? as *const _; let mut module = ModuleCodegen { - module_llvm: ModuleLlvm { llmod_raw, llcx, tm }, + module_llvm: ModuleLlvm { llmod_raw, llcx, tm, typetrees: Default::default() }, name: thin_module.name().to_string(), kind: ModuleKind::Regular, }; diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 9d5204034def0..153b09d867a29 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -10,11 +10,24 @@ use crate::errors::{ WithLlvmError, WriteBytecode, }; use crate::llvm::{self, DiagnosticInfo, PassManager}; +use crate::llvm::{LLVMReplaceAllUsesWith, LLVMVerifyFunction, Value}; use crate::llvm_util; use crate::type_::Type; +use crate::typetree::to_enzyme_typetree; use crate::LlvmCodegenBackend; use crate::ModuleLlvm; +use crate::{base, DiffTypeTree}; use llvm::{ + enzyme_rust_forward_diff, enzyme_rust_reverse_diff, BasicBlock, CreateEnzymeLogic, + CreateTypeAnalysis, EnzymeLogicRef, EnzymeTypeAnalysisRef, LLVMAddFunction, + LLVMAppendBasicBlockInContext, LLVMBuildCall2, LLVMBuildExtractValue, LLVMBuildRet, + LLVMCountParams, LLVMCountStructElementTypes, LLVMCreateBuilderInContext, LLVMDeleteFunction, + LLVMDisposeBuilder, LLVMGetBasicBlockTerminator, LLVMGetElementType, LLVMGetModuleContext, + LLVMGetParams, LLVMGetReturnType, LLVMPositionBuilderAtEnd, LLVMSetValueName2, LLVMTypeOf, + LLVMVoidTypeInContext, LLVMGlobalGetValueType, LLVMGetStringAttributeAtIndex, + LLVMIsStringAttribute, LLVMRemoveStringAttributeAtIndex, LLVMRemoveEnumAttributeAtIndex, 
AttributeKind, + LLVMGetFirstFunction, LLVMGetNextFunction, LLVMGetEnumAttributeAtIndex, LLVMIsEnumAttribute, + LLVMCreateStringAttribute, LLVMRustAddFunctionAttributes, LLVMCreateEnumAttribute, LLVMDumpModule, LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols, }; use rustc_codegen_ssa::back::link::ensure_removed; @@ -24,10 +37,12 @@ use rustc_codegen_ssa::back::write::{ }; use rustc_codegen_ssa::traits::*; use rustc_codegen_ssa::{CompiledModule, ModuleCodegen}; +use rustc_data_structures::fx::FxHashMap; use rustc_data_structures::profiling::SelfProfilerRef; use rustc_data_structures::small_c_str::SmallCStr; use rustc_errors::{FatalError, Handler, Level}; use rustc_fs_util::{link_or_copy, path_to_c_string}; +use rustc_middle::middle::autodiff_attrs::{AutoDiffItem, DiffActivity, DiffMode}; use rustc_middle::ty::TyCtxt; use rustc_session::config::{self, Lto, OutputType, Passes, SplitDwarfKind, SwitchWithOptPath}; use rustc_session::Session; @@ -37,7 +52,7 @@ use rustc_target::spec::{CodeModel, RelocModel, SanitizerSet, SplitDebuginfo}; use crate::llvm::diagnostic::OptimizationDiagnosticKind; use libc::{c_char, c_int, c_uint, c_void, size_t}; -use std::ffi::CString; +use std::ffi::{CStr, CString}; use std::fs; use std::io::{self, Write}; use std::path::{Path, PathBuf}; @@ -513,8 +528,18 @@ pub(crate) unsafe fn llvm_optimize( opt_level: config::OptLevel, opt_stage: llvm::OptStage, ) -> Result<(), FatalError> { - let unroll_loops = - opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; + // Enzyme: + // We want to simplify / optimize functions before AD. + // However, benchmarks show that optimizations increasing the code size + // tend to reduce AD performance. Therefore activate them first, then differentiate the code + // and finally re-optimize the module, now with all optimizations available. + // RIP compile time. 
+ // let unroll_loops = + // opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; + let unroll_loops = false; + let vectorize_slp = false; + let vectorize_loop = false; + let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed(); let pgo_gen_path = get_pgo_gen_path(config); let pgo_use_path = get_pgo_use_path(config); @@ -569,8 +594,8 @@ pub(crate) unsafe fn llvm_optimize( using_thin_buffers, config.merge_functions, unroll_loops, - config.vectorize_slp, - config.vectorize_loop, + vectorize_slp, + vectorize_loop, config.no_builtins, config.emit_lifetime_markers, sanitizer_options.as_ref(), @@ -592,6 +617,255 @@ pub(crate) unsafe fn llvm_optimize( result.into_result().map_err(|()| llvm_err(diag_handler, LlvmError::RunLlvmPasses)) } +fn get_params(fnc: &Value) -> Vec<&Value> { + unsafe { + let param_num = LLVMCountParams(fnc) as usize; + let mut fnc_args: Vec<&Value> = vec![]; + fnc_args.reserve(param_num); + LLVMGetParams(fnc, fnc_args.as_mut_ptr()); + fnc_args.set_len(param_num); + fnc_args + } +} + +// TODO: cleanup +unsafe fn create_wrapper<'a>( + llmod: &'a llvm::Module, + //module: &'a ModuleCodegen, + fnc: &'a Value, + u_type: &Type, + fnc_name: String, +) -> (&'a Value, &'a BasicBlock, Vec<&'a Value>, Vec<&'a Value>, CString) { + //let llmod = module.module_llvm.llmod(); + let context = LLVMGetModuleContext(llmod); + let inner_fnc_name = "inner_".to_string() + &fnc_name; + let c_inner_fnc_name = CString::new(inner_fnc_name.clone()).unwrap(); + LLVMSetValueName2(fnc, c_inner_fnc_name.as_ptr(), inner_fnc_name.len() as usize); + + let c_outer_fnc_name = CString::new(fnc_name).unwrap(); + let outer_fnc: &Value = + LLVMAddFunction(llmod, c_outer_fnc_name.as_ptr(), LLVMGetElementType(u_type) as &Type); + + let entry = "fnc_entry".to_string(); + let c_entry = CString::new(entry).unwrap(); + let basic_block = LLVMAppendBasicBlockInContext(context, outer_fnc, c_entry.as_ptr()); + + let outer_params: 
Vec<&Value> = get_params(outer_fnc); + let inner_params: Vec<&Value> = get_params(fnc); + + (outer_fnc, basic_block, outer_params, inner_params, c_inner_fnc_name) +} + +//pub(crate) fn get_type(t: LLVMTypeRef) -> CString { +// unsafe { CString::from_raw(LLVMPrintTypeToString(t)) } +//} + +// TODO: Don't write a wrapper function, just unwrap the struct inside of the same fnc. +// Might help during debugging, if you have one function less to jump trough +pub(crate) unsafe fn extract_return_type<'a>( + llmod: &'a llvm::Module, + fnc: &'a Value, + u_type: &Type, + fnc_name: String, +) -> &'a Value { + //let llmod = module.module_llvm.llmod(); + let context = llvm::LLVMGetModuleContext(llmod); + //dbg!("Unpacking", fnc_name.clone()); + //dbg!("From: ", f_type, " into ", u_type); + + let inner_param_num = LLVMCountParams(fnc); + let (outer_fnc, outer_bb, mut outer_args, _inner_args, c_inner_fnc_name) = + create_wrapper(llmod, fnc, u_type, fnc_name); + + if inner_param_num as usize != outer_args.len() { + panic!("Args len shouldn't differ. Please report this."); + } + + let builder = LLVMCreateBuilderInContext(context); + LLVMPositionBuilderAtEnd(builder, outer_bb); + let struct_ret = LLVMBuildCall2( + builder, + u_type, + fnc, + outer_args.as_mut_ptr(), + outer_args.len(), + c_inner_fnc_name.as_ptr(), + ); + // We can use an arbitrary name here, since it will be used to store a tmp value. 
+ let inner_grad_name = "foo".to_string(); + let c_inner_grad_name = CString::new(inner_grad_name).unwrap(); + let struct_ret = LLVMBuildExtractValue(builder, struct_ret, 0, c_inner_grad_name.as_ptr()); + let _ret = LLVMBuildRet(builder, struct_ret); + let _terminator = LLVMGetBasicBlockTerminator(outer_bb); + //assert!(LLVMIsNull(terminator)!=0, "no terminator"); + LLVMDisposeBuilder(builder); + + let _fnc_ok = + LLVMVerifyFunction(outer_fnc, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction); + //dbg!(outer_fnc); + //assert!(fnc_ok); + //if let Err(e) = verify_function(outer_fnc) { + // panic!("Creating a wrapper function failed! {}", e); + //} + + outer_fnc +} + +// As unsafe as it can be. +#[allow(unused_variables)] +#[allow(unused)] +pub(crate) unsafe fn enzyme_ad( + llmod: &llvm::Module, + llcx: &llvm::Context, + item: AutoDiffItem, +) -> Result<(), FatalError> { + let autodiff_mode = item.attrs.mode; + let rust_name = item.source; + let rust_name2 = &item.target; + + let args_activity = item.attrs.input_activity.clone(); + let ret_activity: DiffActivity = item.attrs.ret_activity; + + // get target and source function + let name = CString::new(rust_name.to_owned()).unwrap(); + let name2 = CString::new(rust_name2.clone()).unwrap(); + let src_fnc = llvm::LLVMGetNamedFunction(llmod, name.as_c_str().as_ptr()).unwrap(); + let target_fnc = llvm::LLVMGetNamedFunction(llmod, name2.as_ptr()).unwrap(); + + // create enzyme typetrees + let llvm_data_layout = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; + let llvm_data_layout = + std::str::from_utf8(unsafe { CStr::from_ptr(llvm_data_layout) }.to_bytes()) + .expect("got a non-UTF8 data-layout from LLVM"); + + let input_tts = + item.inputs.into_iter().map(|x| to_enzyme_typetree(x, llvm_data_layout, llcx)).collect(); + let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx); + + let opt = 1; + let ret_primary_ret = false; + let diff_primary_ret = false; + let logic_ref: EnzymeLogicRef = 
CreateEnzymeLogic(opt as u8); + let type_analysis: EnzymeTypeAnalysisRef = + CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0); + + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymeStrictAliasing), 0); + + if std::env::var("ENZYME_PRINT_TA").is_ok() { + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintType), 1); + } + if std::env::var("ENZYME_PRINT_AA").is_ok() { + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintActivity), 1); + } + if std::env::var("ENZYME_PRINT_PERF").is_ok() { + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintPerf), 1); + } + if std::env::var("ENZYME_PRINT").is_ok() { + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrint), 1); + } + + let mut res: &Value = match item.attrs.mode { + DiffMode::Forward => enzyme_rust_forward_diff( + logic_ref, + type_analysis, + src_fnc, + args_activity, + ret_activity, + ret_primary_ret, + input_tts, + output_tt, + ), + DiffMode::Reverse => enzyme_rust_reverse_diff( + logic_ref, + type_analysis, + src_fnc, + args_activity, + ret_activity, + ret_primary_ret, + diff_primary_ret, + input_tts, + output_tt, + ), + _ => unreachable!(), + }; + let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(res)); + + let void_type = LLVMVoidTypeInContext(llcx); + if item.attrs.mode == DiffMode::Reverse && f_return_type != void_type { + //dbg!("Reverse Mode sanitizer"); + //dbg!(f_type); + //dbg!(f_return_type); + let num_elem_in_ret_struct = LLVMCountStructElementTypes(f_return_type); + if num_elem_in_ret_struct == 1 { + let u_type = LLVMTypeOf(target_fnc); + res = extract_return_type(llmod, res, u_type, rust_name2.clone()); // TODO: check if name or name2 + } + } + //dbg!(&target_fnc); + LLVMSetValueName2(res, name2.as_ptr(), rust_name2.len()); + LLVMReplaceAllUsesWith(target_fnc, res); + LLVMDeleteFunction(target_fnc); + + Ok(()) +} + +pub(crate) unsafe fn differentiate( + module: &ModuleCodegen, + _cgcx: &CodegenContext, + diff_items: 
Vec, + _typetrees: FxHashMap, + _config: &ModuleConfig, +) -> Result<(), FatalError> { + let llmod = module.module_llvm.llmod(); + let llcx = &module.module_llvm.llcx; + + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymeStrictAliasing), 0); + + if std::env::var("ENZYME_PRINT_MOD").is_ok() { + unsafe {LLVMDumpModule(llmod);} + } + if std::env::var("ENZYME_TT_DEPTH").is_ok() { + let depth = std::env::var("ENZYME_TT_DEPTH").unwrap(); + let depth = depth.parse::().unwrap(); + assert!(depth >= 1); + llvm::EnzymeSetCLInteger(std::ptr::addr_of_mut!(llvm::EnzymeMaxTypeDepth), depth); + } + if std::env::var("ENZYME_TT_WIDTH").is_ok() { + let width = std::env::var("ENZYME_TT_WIDTH").unwrap(); + let width = width.parse::().unwrap(); + assert!(width >= 1); + llvm::EnzymeSetCLInteger(std::ptr::addr_of_mut!(llvm::MaxTypeOffset), width); + } + + for item in diff_items { + let res = enzyme_ad(llmod, llcx, item); + assert!(res.is_ok()); + } + + let mut f = LLVMGetFirstFunction(llmod); + loop { + if let Some(lf) = f { + f = LLVMGetNextFunction(lf); + let myhwattr = "enzyme_hw"; + let attr = LLVMGetStringAttributeAtIndex(lf, c_uint::MAX, myhwattr.as_ptr() as *const c_char, myhwattr.as_bytes().len() as c_uint); + if LLVMIsStringAttribute(attr) { + LLVMRemoveStringAttributeAtIndex(lf, c_uint::MAX, myhwattr.as_ptr() as *const c_char, myhwattr.as_bytes().len() as c_uint); + } else { + LLVMRemoveEnumAttributeAtIndex(lf, c_uint::MAX, AttributeKind::SanitizeHWAddress); + } + + + } else { + break; + } + } + if std::env::var("ENZYME_PRINT_MOD_AFTER").is_ok() { + unsafe {LLVMDumpModule(llmod);} + } + + Ok(()) +} + // Unsafe due to LLVM calls. 
pub(crate) unsafe fn optimize( cgcx: &CodegenContext, @@ -615,6 +889,28 @@ pub(crate) unsafe fn optimize( llvm::LLVMWriteBitcodeToFile(llmod, out.as_ptr()); } + { + let mut f = LLVMGetFirstFunction(llmod); + loop { + if let Some(lf) = f { + f = LLVMGetNextFunction(lf); + let myhwattr = "enzyme_hw"; + let myhwv = ""; + let prevattr = LLVMGetEnumAttributeAtIndex(lf, c_uint::MAX, AttributeKind::SanitizeHWAddress); + if LLVMIsEnumAttribute(prevattr) { + let attr = LLVMCreateStringAttribute(llcx, myhwattr.as_ptr() as *const c_char, myhwattr.as_bytes().len() as c_uint, myhwv.as_ptr() as *const c_char, myhwv.as_bytes().len() as c_uint); + LLVMRustAddFunctionAttributes(lf, c_uint::MAX, &attr, 1); + } else { + let attr = LLVMCreateEnumAttribute(llcx, AttributeKind::SanitizeHWAddress, 0); + LLVMRustAddFunctionAttributes(lf, c_uint::MAX, &attr, 1); + } + + } else { + break; + } + } + } + if let Some(opt_level) = config.opt_level { let opt_stage = match cgcx.lto { Lto::Fat => llvm::OptStage::PreLinkFatLTO, diff --git a/compiler/rustc_codegen_llvm/src/base.rs b/compiler/rustc_codegen_llvm/src/base.rs index b659fd02eecf6..1d9157e6355f4 100644 --- a/compiler/rustc_codegen_llvm/src/base.rs +++ b/compiler/rustc_codegen_llvm/src/base.rs @@ -25,6 +25,7 @@ use rustc_codegen_ssa::base::maybe_create_entry_wrapper; use rustc_codegen_ssa::mono_item::MonoItemExt; use rustc_codegen_ssa::traits::*; use rustc_codegen_ssa::{ModuleCodegen, ModuleKind}; +use rustc_data_structures::fx::FxHashMap; use rustc_data_structures::small_c_str::SmallCStr; use rustc_middle::dep_graph; use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrs; @@ -82,9 +83,10 @@ pub fn compile_codegen_unit(tcx: TyCtxt<'_>, cgu_name: Symbol) -> (ModuleCodegen recorder.record_arg(cgu.size_estimate().to_string()); }); // Instantiate monomorphizations without filling out definitions yet... 
- let llvm_module = ModuleLlvm::new(tcx, cgu_name.as_str()); - { + let mut llvm_module = ModuleLlvm::new(tcx, cgu_name.as_str()); + let typetrees = { let cx = CodegenCx::new(tcx, cgu, &llvm_module); + let mono_items = cx.codegen_unit.items_in_deterministic_order(cx.tcx); for &(mono_item, data) in &mono_items { mono_item.predefine::>(&cx, data.linkage, data.visibility); @@ -132,7 +134,11 @@ pub fn compile_codegen_unit(tcx: TyCtxt<'_>, cgu_name: Symbol) -> (ModuleCodegen if cx.sess().opts.debuginfo != DebugInfo::None { cx.debuginfo_finalize(); } - } + + FxHashMap::default() + }; + + llvm_module.typetrees = typetrees; ModuleCodegen { name: cgu_name.to_string(), diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs index b4b2ab1e1f8a9..2d16649ce17d2 100644 --- a/compiler/rustc_codegen_llvm/src/context.rs +++ b/compiler/rustc_codegen_llvm/src/context.rs @@ -624,6 +624,10 @@ impl<'ll, 'tcx> MiscMethods<'tcx> for CodegenCx<'ll, 'tcx> { None } } + + fn create_autodiff(&self) -> Vec { + return vec![]; + } } impl<'ll> CodegenCx<'ll, '_> { diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index 8a6a5f79b3bb9..011a208eb6389 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -30,6 +30,7 @@ use back::owned_target_machine::OwnedTargetMachine; use back::write::{create_informational_target_machine, create_target_machine}; use errors::ParseTargetMachineConfig; +use llvm::TypeTree; pub use llvm_util::target_features; use rustc_ast::expand::allocator::AllocatorKind; use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; @@ -44,6 +45,8 @@ use rustc_errors::{DiagnosticMessage, ErrorGuaranteed, FatalError, Handler, Subd use rustc_fluent_macro::fluent_messages; use rustc_metadata::EncodedMetadata; use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; +use rustc_middle::middle::autodiff_attrs::AutoDiffItem; +use 
rustc_middle::ty::query::Providers; use rustc_middle::ty::TyCtxt; use rustc_middle::util::Providers; use rustc_session::config::{OptLevel, OutputFilenames, PrintKind, PrintRequest}; @@ -77,6 +80,7 @@ mod debuginfo; mod declare; mod errors; mod intrinsic; +mod typetree; // The following is a workaround that replaces `pub mod llvm;` and that fixes issue 53912. #[path = "llvm/mod.rs"] @@ -172,6 +176,8 @@ impl WriteBackendMethods for LlvmCodegenBackend { type TargetMachineError = crate::errors::LlvmError<'static>; type ThinData = back::lto::ThinData; type ThinBuffer = back::lto::ThinBuffer; + type TypeTree = DiffTypeTree; + fn print_pass_timings(&self) { unsafe { let mut size = 0; @@ -254,6 +260,20 @@ impl WriteBackendMethods for LlvmCodegenBackend { fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer) { (module.name, back::lto::ModuleBuffer::new(module.module_llvm.llmod())) } + /// Generate autodiff rules + fn autodiff( + cgcx: &CodegenContext, + module: &ModuleCodegen, + diff_fncs: Vec, + typetrees: FxHashMap, + config: &ModuleConfig, + ) -> Result<(), FatalError> { + unsafe { back::write::differentiate(module, cgcx, diff_fncs, typetrees, config) } + } + + fn typetrees(module: &mut Self::Module) -> FxHashMap { + module.typetrees.drain().collect() + } } unsafe impl Send for LlvmCodegenBackend {} // Llvm is on a per-thread basis @@ -404,12 +424,20 @@ impl CodegenBackend for LlvmCodegenBackend { } } +#[derive(Clone, Debug)] +pub struct DiffTypeTree { + pub ret_tt: TypeTree, + pub input_tt: Vec, +} + +#[allow(dead_code)] pub struct ModuleLlvm { llcx: &'static mut llvm::Context, llmod_raw: *const llvm::Module, // independent from llcx and llmod_raw, resources get disposed by drop impl tm: OwnedTargetMachine, + typetrees: FxHashMap, } unsafe impl Send for ModuleLlvm {} @@ -420,7 +448,12 @@ impl ModuleLlvm { unsafe { let llcx = llvm::LLVMRustContextCreate(tcx.sess.fewer_names()); let llmod_raw = context::create_module(tcx, llcx, mod_name) as *const _; 
- ModuleLlvm { llmod_raw, llcx, tm: create_target_machine(tcx, mod_name) } + ModuleLlvm { + llmod_raw, + llcx, + tm: create_target_machine(tcx, mod_name), + typetrees: Default::default(), + } } } @@ -428,7 +461,12 @@ impl ModuleLlvm { unsafe { let llcx = llvm::LLVMRustContextCreate(tcx.sess.fewer_names()); let llmod_raw = context::create_module(tcx, llcx, mod_name) as *const _; - ModuleLlvm { llmod_raw, llcx, tm: create_informational_target_machine(tcx.sess) } + ModuleLlvm { + llmod_raw, + llcx, + tm: create_informational_target_machine(tcx.sess), + typetrees: Default::default(), + } } } @@ -449,7 +487,7 @@ impl ModuleLlvm { } }; - Ok(ModuleLlvm { llmod_raw, llcx, tm }) + Ok(ModuleLlvm { llmod_raw, llcx, tm, typetrees: Default::default() }) } } diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index a038b3af03dd6..c5514d5bff823 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1,6 +1,9 @@ #![allow(non_camel_case_types)] #![allow(non_upper_case_globals)] +use rustc_codegen_ssa::coverageinfo::map as coverage_map; +use rustc_middle::middle::autodiff_attrs::DiffActivity; + use super::debuginfo::{ DIArray, DIBasicType, DIBuilder, DICompositeType, DIDerivedType, DIDescriptor, DIEnumerator, DIFile, DIFlags, DIGlobalVariableExpression, DILexicalBlock, DILocation, DINameSpace, @@ -11,6 +14,8 @@ use super::debuginfo::{ use libc::{c_char, c_int, c_uint, size_t}; use libc::{c_ulonglong, c_void}; +use core::fmt; +use std::ffi::{CStr, CString}; use std::marker::PhantomData; use super::RustString; @@ -187,7 +192,7 @@ pub enum AttributeKind { OptimizeNone = 24, ReturnsTwice = 25, ReadNone = 26, - SanitizeHWAddress = 28, + SanitizeHWAddress = 51, WillReturn = 29, StackProtectReq = 30, StackProtectStrong = 31, @@ -819,10 +824,186 @@ pub type SelfProfileBeforePassCallback = unsafe extern "C" fn(*mut c_void, *const c_char, *const c_char); pub type 
SelfProfileAfterPassCallback = unsafe extern "C" fn(*mut c_void); +#[repr(C)] +pub enum LLVMVerifierFailureAction { + LLVMAbortProcessAction, + LLVMPrintMessageAction, + LLVMReturnStatusAction, +} + +pub(crate) unsafe fn enzyme_rust_forward_diff( + logic_ref: EnzymeLogicRef, + type_analysis: EnzymeTypeAnalysisRef, + fnc: &Value, + input_diffactivity: Vec, + ret_diffactivity: DiffActivity, + mut ret_primary_ret: bool, + input_tts: Vec, + output_tt: TypeTree, +) -> &Value { + let ret_activity = cdiffe_from(ret_diffactivity); + assert!(ret_activity != CDIFFE_TYPE::DFT_OUT_DIFF); + let mut input_activity: Vec = vec![]; + for input in input_diffactivity { + let act = cdiffe_from(input); + assert!(act == CDIFFE_TYPE::DFT_CONSTANT || act == CDIFFE_TYPE::DFT_DUP_ARG || act == CDIFFE_TYPE::DFT_DUP_NONEED); + input_activity.push(act); + } + + if ret_activity == CDIFFE_TYPE::DFT_DUP_ARG { + if ret_primary_ret != true { + dbg!("overwriting ret_primary_ret!"); + } + ret_primary_ret = true; + } else if ret_activity == CDIFFE_TYPE::DFT_DUP_NONEED { + if ret_primary_ret != false { + dbg!("overwriting ret_primary_ret!"); + } + ret_primary_ret = false; + } + + let mut args_tree = input_tts.iter().map(|x| x.inner).collect::>(); + //let mut args_tree = vec![TypeTree::new().inner; typetree.input_tt.len()]; + + // We don't support volatile / extern / (global?) values. + // Just because I didn't had time to test them, and it seems less urgent. 
+ let args_uncacheable = vec![0; input_activity.len()]; + + let kv_tmp = IntList { data: std::ptr::null_mut(), size: 0 }; + + let mut known_values = vec![kv_tmp; input_activity.len()]; + + let dummy_type = CFnTypeInfo { + Arguments: args_tree.as_mut_ptr(), + Return: output_tt.inner.clone(), + KnownValues: known_values.as_mut_ptr(), + }; + + EnzymeCreateForwardDiff( + logic_ref, // Logic + std::ptr::null(), + std::ptr::null(), + fnc, + ret_activity, // LLVM function, return type + input_activity.as_ptr(), + input_activity.len(), // constant arguments + type_analysis, // type analysis struct + ret_primary_ret as u8, + CDerivativeMode::DEM_ForwardMode, // return value, dret_used, top_level which was 1 + 1, // free memory + 1, // vector mode width + Option::None, + dummy_type, // additional_arg, type info (return + args) + args_uncacheable.as_ptr(), + args_uncacheable.len(), // uncacheable arguments + std::ptr::null_mut(), // write augmented function to this + ) +} + +pub(crate) unsafe fn enzyme_rust_reverse_diff( + logic_ref: EnzymeLogicRef, + type_analysis: EnzymeTypeAnalysisRef, + fnc: &Value, + input_activity: Vec, + ret_activity: DiffActivity, + mut ret_primary_ret: bool, + diff_primary_ret: bool, + input_tts: Vec, + output_tt: TypeTree, +) -> &Value { + let ret_activity = cdiffe_from(ret_activity); + assert!(ret_activity == CDIFFE_TYPE::DFT_CONSTANT || ret_activity == CDIFFE_TYPE::DFT_OUT_DIFF); + let input_activity: Vec = input_activity.iter().map(|&x| cdiffe_from(x)).collect(); + + if ret_activity == CDIFFE_TYPE::DFT_DUP_ARG { + if ret_primary_ret != true { + dbg!("overwriting ret_primary_ret!"); + } + ret_primary_ret = true; + } else if ret_activity == CDIFFE_TYPE::DFT_DUP_NONEED { + if ret_primary_ret != false { + dbg!("overwriting ret_primary_ret!"); + } + ret_primary_ret = false; + } + + let mut args_tree = input_tts.iter().map(|x| x.inner).collect::>(); + + // We don't support volatile / extern / (global?) values. 
+ // Just because I didn't had time to test them, and it seems less urgent. + let args_uncacheable = vec![0; input_tts.len()]; + let kv_tmp = IntList { data: std::ptr::null_mut(), size: 0 }; + + + let mut known_values = vec![kv_tmp; input_tts.len()]; + + let dummy_type = CFnTypeInfo { + Arguments: args_tree.as_mut_ptr(), + Return: output_tt.inner.clone(), + KnownValues: known_values.as_mut_ptr(), + }; + + EnzymeCreatePrimalAndGradient( + logic_ref, // Logic + std::ptr::null(), + std::ptr::null(), + fnc, + ret_activity, // LLVM function, return type + input_activity.as_ptr(), + input_activity.len(), // constant arguments + type_analysis, // type analysis struct + ret_primary_ret as u8, + diff_primary_ret as u8, //0 + CDerivativeMode::DEM_ReverseModeCombined, // return value, dret_used, top_level which was 1 + 1, // vector mode width + 1, // free memory + Option::None, + 0, // do not force anonymous tape + dummy_type, // additional_arg, type info (return + args) + args_uncacheable.as_ptr(), + args_uncacheable.len(), // uncacheable arguments + std::ptr::null_mut(), // write augmented function to this + 0, + ) +} pub type GetSymbolsCallback = unsafe extern "C" fn(*mut c_void, *const c_char) -> *mut c_void; pub type GetSymbolsErrorCallback = unsafe extern "C" fn(*const c_char) -> *mut c_void; extern "C" { + + // Enzyme + //pub fn LLVMReplaceAllUsesWith(old: &Value, new: &Value); + pub fn GibtsNicht(M: &Module) -> bool; + pub fn LLVMIsStructTy(ty: &Type) -> bool; + pub fn LLVMGetReturnType(T: &Type) -> &Type; + pub fn LLVMDumpModule(M: &Module); + pub fn LLVMCountStructElementTypes(T: &Type) -> c_uint; + pub fn LLVMDeleteFunction(V: &Value); + pub fn LLVMRemoveStringAttributeAtIndex(F : &Value, Idx: c_uint, K: *const c_char, KLen : c_uint); + pub fn LLVMGetStringAttributeAtIndex(F : &Value, Idx: c_uint, K: *const c_char, KLen : c_uint) -> &Attribute; + pub fn LLVMRemoveEnumAttributeAtIndex(F : &Value, Idx: c_uint, K: AttributeKind); + pub fn LLVMGetEnumAttributeAtIndex(F 
: &Value, Idx: c_uint, K: AttributeKind) -> &Attribute; + pub fn LLVMIsEnumAttribute(A : &Attribute) -> bool; + pub fn LLVMCreateEnumAttribute(C : &Context, Kind: AttributeKind, val:u64) -> &Attribute; + pub fn LLVMIsStringAttribute(A : &Attribute) -> bool; + pub fn LLVMVerifyFunction(V: &Value, action: LLVMVerifierFailureAction) -> bool; + pub fn LLVMGetParams(Fnc: &Value, parms: *mut &Value); + pub fn LLVMBuildCall2<'a>( + arg1: &Builder<'a>, + ty: &Type, + func: &Value, + args: *mut &Value, + num_args: size_t, + name: *const c_char, + ) -> &'a Value; + pub fn LLVMGetBasicBlockTerminator(B: &BasicBlock) -> &Value; + pub fn LLVMAddFunction<'a>(M: &Module, Name: *const c_char, Ty: &Type) -> &'a Value; + pub fn LLVMGetFirstFunction(M: &Module) -> Option<&Value>; + pub fn LLVMGetNextFunction(V: &Value) -> Option<&Value>; + pub fn LLVMGetNamedFunction(M: &Module, Name: *const c_char) -> Option<&Value>; + pub fn LLVMGlobalGetValueType(val: &Value) -> &Type; + + pub fn LLVMRustGetFunctionType(fnc: &Value) -> &Type; pub fn LLVMRustInstallFatalErrorHandler(); pub fn LLVMRustDisableSystemDialogsOnCrash(); @@ -2091,6 +2272,8 @@ extern "C" { #[allow(improper_ctypes)] pub fn LLVMRustWriteTypeToString(Type: &Type, s: &RustString); #[allow(improper_ctypes)] + pub fn LLVMRustWriteValueNameToString(value_ref: &Value, s: &RustString); + #[allow(improper_ctypes)] pub fn LLVMRustWriteValueToString(value_ref: &Value, s: &RustString); pub fn LLVMIsAConstantInt(value_ref: &Value) -> Option<&ConstantInt>; @@ -2362,7 +2545,6 @@ extern "C" { remark_file: *const c_char, pgo_available: bool, ); - #[allow(improper_ctypes)] pub fn LLVMRustGetMangledName(V: &Value, out: &RustString); @@ -2382,3 +2564,301 @@ extern "C" { error_callback: GetSymbolsErrorCallback, ) -> *mut c_void; } +// Manuel +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct EnzymeOpaqueTypeAnalysis { + _unused: [u8; 0], +} +pub type EnzymeTypeAnalysisRef = *mut EnzymeOpaqueTypeAnalysis; +#[repr(C)] +#[derive(Debug, Copy, 
Clone)] +pub struct EnzymeOpaqueLogic { + _unused: [u8; 0], +} +pub type EnzymeLogicRef = *mut EnzymeOpaqueLogic; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct EnzymeOpaqueAugmentedReturn { + _unused: [u8; 0], +} +pub type EnzymeAugmentedReturnPtr = *mut EnzymeOpaqueAugmentedReturn; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct IntList { + pub data: *mut i64, + pub size: size_t, +} +#[repr(u32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum CConcreteType { + DT_Anything = 0, + DT_Integer = 1, + DT_Pointer = 2, + DT_Half = 3, + DT_Float = 4, + DT_Double = 5, + DT_Unknown = 6, +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct EnzymeTypeTree { + _unused: [u8; 0], +} +pub type CTypeTreeRef = *mut EnzymeTypeTree; +extern "C" { + fn EnzymeNewTypeTree() -> CTypeTreeRef; +} +extern "C" { + fn EnzymeFreeTypeTree(CTT: CTypeTreeRef); +} +extern "C" { + pub fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8); +} +extern "C" { + pub fn EnzymeSetCLInteger(arg1: *mut ::std::os::raw::c_void, arg2: i64); +} + +extern "C" { + pub static mut MaxIntOffset: c_void; + pub static mut MaxTypeOffset: c_void; + pub static mut EnzymeMaxTypeDepth: c_void; + + pub static mut EnzymePrintPerf: c_void; + pub static mut EnzymePrintActivity: c_void; + pub static mut EnzymePrintType: c_void; + pub static mut EnzymePrint: c_void; + pub static mut EnzymeStrictAliasing: c_void; +} + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct CFnTypeInfo { + #[doc = " Types of arguments, assumed of size len(Arguments)"] + pub Arguments: *mut CTypeTreeRef, + #[doc = " Type of return"] + pub Return: CTypeTreeRef, + #[doc = " The specific constant(s) known to represented by an argument, if constant"] + pub KnownValues: *mut IntList, +} +#[repr(u32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum CDIFFE_TYPE { + DFT_OUT_DIFF = 0, + DFT_DUP_ARG = 1, + DFT_CONSTANT = 2, + DFT_DUP_NONEED = 3, +} + +fn cdiffe_from(act: DiffActivity) -> CDIFFE_TYPE 
{ + return match act { + DiffActivity::None => CDIFFE_TYPE::DFT_CONSTANT, + DiffActivity::Active => CDIFFE_TYPE::DFT_OUT_DIFF, + DiffActivity::Const => CDIFFE_TYPE::DFT_CONSTANT, + DiffActivity::Duplicated => CDIFFE_TYPE::DFT_DUP_ARG, + DiffActivity::DuplicatedNoNeed => CDIFFE_TYPE::DFT_DUP_NONEED, + }; +} + +#[repr(u32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum CDerivativeMode { + DEM_ForwardMode = 0, + DEM_ReverseModePrimal = 1, + DEM_ReverseModeGradient = 2, + DEM_ReverseModeCombined = 3, + DEM_ForwardModeSplit = 4, +} +extern "C" { + fn EnzymeCreatePrimalAndGradient<'a>( + arg1: EnzymeLogicRef, + _builderCtx: *const u8, // &'a Builder<'_>, + _callerCtx: *const u8,// &'a Value, + todiff: &'a Value, + retType: CDIFFE_TYPE, + constant_args: *const CDIFFE_TYPE, + constant_args_size: size_t, + TA: EnzymeTypeAnalysisRef, + returnValue: u8, + dretUsed: u8, + mode: CDerivativeMode, + width: ::std::os::raw::c_uint, + freeMemory: u8, + additionalArg: Option<&Type>, + forceAnonymousTape: u8, + typeInfo: CFnTypeInfo, + _uncacheable_args: *const u8, + uncacheable_args_size: size_t, + augmented: EnzymeAugmentedReturnPtr, + AtomicAdd: u8, + ) -> &'a Value; + //) -> LLVMValueRef; +} +extern "C" { + fn EnzymeCreateForwardDiff<'a>( + arg1: EnzymeLogicRef, + _builderCtx: *const u8,// &'a Builder<'_>, + _callerCtx: *const u8,// &'a Value, + todiff: &'a Value, + retType: CDIFFE_TYPE, + constant_args: *const CDIFFE_TYPE, + constant_args_size: size_t, + TA: EnzymeTypeAnalysisRef, + returnValue: u8, + mode: CDerivativeMode, + freeMemory: u8, + width: ::std::os::raw::c_uint, + additionalArg: Option<&Type>, + typeInfo: CFnTypeInfo, + _uncacheable_args: *const u8, + uncacheable_args_size: size_t, + augmented: EnzymeAugmentedReturnPtr, + ) -> &'a Value; +} +pub type CustomRuleType = ::std::option::Option< + unsafe extern "C" fn( + direction: ::std::os::raw::c_int, + ret: CTypeTreeRef, + args: *mut CTypeTreeRef, + known_values: *mut IntList, + num_args: size_t, + fnc: 
&Value, + ta: *const ::std::os::raw::c_void, + ) -> u8, +>; +extern "C" { + pub fn CreateTypeAnalysis( + Log: EnzymeLogicRef, + customRuleNames: *mut *mut ::std::os::raw::c_char, + customRules: *mut CustomRuleType, + numRules: size_t, + ) -> EnzymeTypeAnalysisRef; +} +extern "C" { + pub fn ClearTypeAnalysis(arg1: EnzymeTypeAnalysisRef); +} +extern "C" { + pub fn FreeTypeAnalysis(arg1: EnzymeTypeAnalysisRef); +} +extern "C" { + pub fn CreateEnzymeLogic(PostOpt: u8) -> EnzymeLogicRef; +} +extern "C" { + pub fn ClearEnzymeLogic(arg1: EnzymeLogicRef); +} +extern "C" { + pub fn FreeEnzymeLogic(arg1: EnzymeLogicRef); +} + +extern "C" { + fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef; + fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef; + fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool; + fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64); + fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef); + fn EnzymeTypeTreeShiftIndiciesEq( + arg1: CTypeTreeRef, + data_layout: *const c_char, + offset: i64, + max_size: i64, + add_offset: u64, + ); + fn EnzymeTypeTreeToStringFree(arg1: *const c_char); + fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char; +} + +pub struct TypeTree { + pub inner: CTypeTreeRef, +} + +impl TypeTree { + pub fn new() -> TypeTree { + let inner = unsafe { EnzymeNewTypeTree() }; + + TypeTree { inner } + } + + #[must_use] + pub fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree { + let inner = unsafe { EnzymeNewTypeTreeCT(t, ctx) }; + + TypeTree { inner } + } + + #[must_use] + pub fn only(self, idx: isize) -> TypeTree { + unsafe { + EnzymeTypeTreeOnlyEq(self.inner, idx as i64); + } + self + } + + #[must_use] + pub fn data0(self) -> TypeTree { + unsafe { + EnzymeTypeTreeData0Eq(self.inner); + } + self + } + + pub fn merge(self, other: Self) -> Self { + unsafe { + EnzymeMergeTypeTree(self.inner, other.inner); + } + drop(other); + + self + } + + #[must_use] + pub fn shift(self, layout: 
&str, offset: isize, max_size: isize, add_offset: usize) -> Self { + let layout = CString::new(layout).unwrap(); + + unsafe { + EnzymeTypeTreeShiftIndiciesEq( + self.inner, + layout.as_ptr(), + offset as i64, + max_size as i64, + add_offset as u64, + ) + } + + self + } +} + +impl Clone for TypeTree { + fn clone(&self) -> Self { + let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) }; + TypeTree { inner } + } +} + +impl fmt::Display for TypeTree { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let ptr = unsafe { EnzymeTypeTreeToString(self.inner) }; + let cstr = unsafe { CStr::from_ptr(ptr) }; + match cstr.to_str() { + Ok(x) => write!(f, "{}", x)?, + Err(err) => write!(f, "could not parse: {}", err)?, + } + + // delete C string pointer + unsafe { EnzymeTypeTreeToStringFree(ptr) } + + Ok(()) + } +} + +impl fmt::Debug for TypeTree { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + ::fmt(self, f) + } +} + +impl Drop for TypeTree { + fn drop(&mut self) { + unsafe { EnzymeFreeTypeTree(self.inner) } + } +} diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs new file mode 100644 index 0000000000000..091ddaa3cf213 --- /dev/null +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -0,0 +1,33 @@ +use crate::llvm; +use rustc_middle::middle::typetree::{Kind, TypeTree}; + +pub fn to_enzyme_typetree( + tree: TypeTree, + llvm_data_layout: &str, + llcx: &llvm::Context, +) -> llvm::TypeTree { + tree.0.iter().fold(llvm::TypeTree::new(), |obj, x| { + let scalar = match x.kind { + Kind::Integer => llvm::CConcreteType::DT_Integer, + Kind::Float => llvm::CConcreteType::DT_Float, + Kind::Double => llvm::CConcreteType::DT_Double, + Kind::Pointer => llvm::CConcreteType::DT_Pointer, + _ => panic!("Unknown kind {:?}", x.kind), + }; + + let tt = llvm::TypeTree::from_type(scalar, llcx).only(-1); + + let tt = if !x.child.0.is_empty() { + let inner_tt = to_enzyme_typetree(x.child.clone(), llvm_data_layout, llcx); + 
tt.merge(inner_tt.only(-1)) + } else { + tt + }; + + if x.offset != -1 { + obj.merge(tt.shift(llvm_data_layout, 0, x.size as isize, x.offset as usize)) + } else { + obj.merge(tt) + } + }) +} diff --git a/compiler/rustc_codegen_ssa/src/assert_module_sources.rs b/compiler/rustc_codegen_ssa/src/assert_module_sources.rs index 16bb7b12bd3c1..a4ba7bfe7d20b 100644 --- a/compiler/rustc_codegen_ssa/src/assert_module_sources.rs +++ b/compiler/rustc_codegen_ssa/src/assert_module_sources.rs @@ -46,7 +46,7 @@ pub fn assert_module_sources(tcx: TyCtxt<'_>, set_reuse: &dyn Fn(&mut CguReuseTr } let available_cgus = - tcx.collect_and_partition_mono_items(()).1.iter().map(|cgu| cgu.name()).collect(); + tcx.collect_and_partition_mono_items(()).2.iter().map(|cgu| cgu.name()).collect(); let mut ams = AssertModuleSource { tcx, diff --git a/compiler/rustc_codegen_ssa/src/back/lto.rs b/compiler/rustc_codegen_ssa/src/back/lto.rs index cb6244050df24..f27b09c8146f3 100644 --- a/compiler/rustc_codegen_ssa/src/back/lto.rs +++ b/compiler/rustc_codegen_ssa/src/back/lto.rs @@ -1,9 +1,11 @@ use super::write::CodegenContext; +use crate::back::write::ModuleConfig; use crate::traits::*; use crate::ModuleCodegen; -use rustc_data_structures::memmap::Mmap; +use rustc_data_structures::{fx::FxHashMap, memmap::Mmap}; use rustc_errors::FatalError; +use rustc_middle::middle::autodiff_attrs::AutoDiffItem; use std::ffi::CString; use std::sync::Arc; @@ -76,6 +78,27 @@ impl LtoModuleCodegen { } } + /// Run autodiff on Fat LTO module + pub unsafe fn autodiff( + self, + cgcx: &CodegenContext, + diff_fncs: Vec, + typetrees: FxHashMap, + config: &ModuleConfig, + ) -> Result, FatalError> { + match &self { + LtoModuleCodegen::Fat { ref module, .. } => { + //let module = module.take().unwrap(); + { + B::autodiff(cgcx, &module, diff_fncs, typetrees, config)?; + } + }, + _ => {}, + } + + Ok(self) + } + /// A "gauge" of how costly it is to optimize this module, used to sort /// biggest modules first. 
pub fn cost(&self) -> u64 { diff --git a/compiler/rustc_codegen_ssa/src/back/symbol_export.rs b/compiler/rustc_codegen_ssa/src/back/symbol_export.rs index 9cd4394108a4a..5fd525dd56e03 100644 --- a/compiler/rustc_codegen_ssa/src/back/symbol_export.rs +++ b/compiler/rustc_codegen_ssa/src/back/symbol_export.rs @@ -317,7 +317,7 @@ fn exported_symbols_provider_local( // external linkage is enough for monomorphization to be linked to. let need_visibility = tcx.sess.target.dynamic_linking && !tcx.sess.target.only_cdylib; - let (_, cgus) = tcx.collect_and_partition_mono_items(()); + let (_, _, cgus) = tcx.collect_and_partition_mono_items(()); for (mono_item, data) in cgus.iter().flat_map(|cgu| cgu.items().iter()) { if data.linkage != Linkage::External { diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index 3d6a212433463..0f77938999d9e 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -24,6 +24,7 @@ use rustc_incremental::{ use rustc_metadata::fs::copy_to_stdout; use rustc_metadata::EncodedMetadata; use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; +use rustc_middle::middle::autodiff_attrs::AutoDiffItem; use rustc_middle::middle::exported_symbols::SymbolExportInfo; use rustc_middle::ty::TyCtxt; use rustc_session::config::{self, CrateType, Lto, OutFileName, OutputFilenames, OutputType}; @@ -117,6 +118,7 @@ pub struct ModuleConfig { pub inline_threshold: Option, pub emit_lifetime_markers: bool, pub llvm_plugins: Vec, + pub enzyme_print_activity: bool, } impl ModuleConfig { @@ -194,6 +196,7 @@ impl ModuleConfig { false ), + enzyme_print_activity: sess.opts.unstable_opts.enzyme_print_activity, sanitizer: if_regular!(sess.opts.unstable_opts.sanitizer, SanitizerSet::empty()), sanitizer_recover: if_regular!( sess.opts.unstable_opts.sanitizer_recover, @@ -385,6 +388,8 @@ impl CodegenContext { fn generate_lto_work( cgcx: &CodegenContext, + autodiff: Vec, + 
typetrees: FxHashMap, needs_fat_lto: Vec>, needs_thin_lto: Vec<(String, B::ThinBuffer)>, import_only_modules: Vec<(SerializedModule, WorkProduct)>, @@ -393,10 +398,14 @@ fn generate_lto_work( if !needs_fat_lto.is_empty() { assert!(needs_thin_lto.is_empty()); - let module = + let mut lto_module = B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise()); + if cgcx.lto == Lto::Fat { + let config = cgcx.config(ModuleKind::Regular); + lto_module = unsafe { lto_module.autodiff(cgcx, autodiff, typetrees, config).unwrap() }; + } // We are adding a single work item, so the cost doesn't matter. - vec![(WorkItem::LTO(module), 0)] + vec![(WorkItem::LTO(lto_module), 0)] } else { assert!(needs_fat_lto.is_empty()); let (lto_modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules) @@ -985,6 +994,8 @@ pub(crate) enum Message { work_product: WorkProduct, }, + AddAutoDiffItems(Vec), + /// The frontend has finished generating everything for all codegen units. /// Sent from the main thread. 
CodegenComplete, @@ -1287,6 +1298,8 @@ fn start_executing_work( let mut needs_link = Vec::new(); let mut needs_fat_lto = Vec::new(); let mut needs_thin_lto = Vec::new(); + let mut autodiff_items = Vec::new(); + let mut typetrees = FxHashMap::::default(); let mut lto_import_only_modules = Vec::new(); let mut started_lto = false; @@ -1393,9 +1406,14 @@ fn start_executing_work( let needs_thin_lto = mem::take(&mut needs_thin_lto); let import_only_modules = mem::take(&mut lto_import_only_modules); - for (work, cost) in - generate_lto_work(&cgcx, needs_fat_lto, needs_thin_lto, import_only_modules) - { + for (work, cost) in generate_lto_work( + &cgcx, + autodiff_items.clone(), + typetrees.clone(), + needs_fat_lto, + needs_thin_lto, + import_only_modules, + ) { let insertion_index = work_items .binary_search_by_key(&cost, |&(_, cost)| cost) .unwrap_or_else(|e| e); @@ -1508,7 +1526,16 @@ fn start_executing_work( } } - Message::CodegenDone { llvm_work_item, cost } => { + Message::CodegenDone { mut llvm_work_item, cost } => { + //// extract build typetrees + match &mut llvm_work_item { + WorkItem::Optimize(module) => { + let tt = B::typetrees(&mut module.module_llvm); + typetrees.extend(tt); + } + _ => {}, + } + // We keep the queue sorted by estimated processing cost, // so that more expensive items are processed earlier. 
This // is good for throughput as it gives the main thread more @@ -1549,6 +1576,10 @@ fn start_executing_work( codegen_state = Aborted; } + Message::AddAutoDiffItems(mut items) => { + autodiff_items.append(&mut items); + } + Message::WorkItem { result, worker_id } => { free_worker(worker_id); @@ -2000,6 +2031,10 @@ impl OngoingCodegen { drop(self.coordinator.sender.send(Box::new(Message::CodegenComplete::))); } + pub fn submit_autodiff_items(&self, items: Vec) { + drop(self.coordinator.sender.send(Box::new(Message::::AddAutoDiffItems(items)))); + } + pub fn check_for_errors(&self, sess: &Session) { self.shared_emitter_main.check(sess, false); } diff --git a/compiler/rustc_codegen_ssa/src/base.rs b/compiler/rustc_codegen_ssa/src/base.rs index 198e5696357af..a8641ba9fbb30 100644 --- a/compiler/rustc_codegen_ssa/src/base.rs +++ b/compiler/rustc_codegen_ssa/src/base.rs @@ -590,7 +590,8 @@ pub fn codegen_crate( // Run the monomorphization collector and partition the collected items into // codegen units. - let codegen_units = tcx.collect_and_partition_mono_items(()).1; + let (_, autodiff_fncs, codegen_units) = tcx.collect_and_partition_mono_items(()); + let autodiff_fncs = autodiff_fncs.to_vec(); // Force all codegen_unit queries so they are already either red or green // when compile_codegen_unit accesses them. We are not able to re-execute @@ -659,6 +660,10 @@ pub fn codegen_crate( ); } + if !autodiff_fncs.is_empty() { + ongoing_codegen.submit_autodiff_items(autodiff_fncs); + } + // For better throughput during parallel processing by LLVM, we used to sort // CGUs largest to smallest. 
This would lead to better thread utilization // by, for example, preventing a large CGU from being processed last and @@ -982,7 +987,7 @@ pub fn provide(providers: &mut Providers) { config::OptLevel::SizeMin => config::OptLevel::Default, }; - let (defids, _) = tcx.collect_and_partition_mono_items(cratenum); + let (defids, _, _) = tcx.collect_and_partition_mono_items(cratenum); let any_for_speed = defids.items().any(|id| { let CodegenFnAttrs { optimize, .. } = tcx.codegen_fn_attrs(*id); diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index 2e0840f2d1bc3..58019ae43129f 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -1,10 +1,11 @@ -use rustc_ast::{ast, attr, MetaItemKind, NestedMetaItem}; +use rustc_ast::{ast, attr, MetaItem, MetaItemKind, NestedMetaItem}; use rustc_attr::{list_contains_name, InlineAttr, InstructionSetAttr, OptimizeAttr}; use rustc_errors::struct_span_err; use rustc_hir as hir; use rustc_hir::def::DefKind; use rustc_hir::def_id::{DefId, LocalDefId, LOCAL_CRATE}; use rustc_hir::{lang_items, weak_lang_items::WEAK_LANG_ITEMS, LangItem}; +use rustc_middle::middle::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs}; use rustc_middle::mir::mono::Linkage; use rustc_middle::query::Providers; @@ -13,6 +14,7 @@ use rustc_session::{lint, parse::feature_err}; use rustc_span::symbol::Ident; use rustc_span::{sym, Span}; use rustc_target::spec::{abi, SanitizerSet}; +use std::str::FromStr; use crate::errors; use crate::target_features::from_target_feature; @@ -697,6 +699,162 @@ fn check_link_name_xor_ordinal( } } +fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { + let attrs = tcx.get_attrs(id, sym::autodiff_into); + + let attrs = attrs + .into_iter() + .filter(|attr| attr.name_or_empty() == sym::autodiff_into) + .collect::>(); + + // 
check for exactly one autodiff attribute on extern block + let attr = match &attrs[..] { + &[] => return AutoDiffAttrs::inactive(), + &[elm] => elm, + x => { + tcx.sess + .struct_span_err(x[1].span, "autodiff attribute can only be applied once") + .span_label(x[1].span, "more than one") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + let list = attr.meta_item_list().unwrap_or_default(); + + // empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions + if list.len() == 0 { + return AutoDiffAttrs { + mode: DiffMode::Source, + ret_activity: DiffActivity::None, + input_activity: Vec::new(), + }; + } + + let mode = match &list[0] { + NestedMetaItem::MetaItem(MetaItem { path: ref p2, kind: MetaItemKind::Word, .. }) => { + p2.segments.first().unwrap().ident + } + _ => { + tcx.sess + .struct_span_err(attr.span, "attribute must contain autodiff mode") + .span_label(attr.span, "empty argument list") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + // parse mode + let mode = match mode.as_str() { + //map(|x| x.as_str()) { + "Forward" => DiffMode::Forward, + "Reverse" => DiffMode::Reverse, + _ => { + tcx.sess + .struct_span_err(attr.span, "mode should be either forward or reverse") + .span_label(attr.span, "invalid mode") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + let ret_symbol = match &list[1] { + NestedMetaItem::MetaItem(MetaItem { path: ref p2, kind: MetaItemKind::Word, .. 
}) => { + p2.segments.first().unwrap().ident + } + _ => { + tcx.sess + .struct_span_err(attr.span, "autodiff attribute must contain the return activity") + .span_label(attr.span, "missing return activity") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + let ret_activity = match DiffActivity::from_str(ret_symbol.as_str()) { + Ok(x) => x, + Err(_) => { + tcx.sess + .struct_span_err(attr.span, "unknown return activity") + .span_label(attr.span, "invalid return activity") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + let mut arg_activities: Vec = vec![]; + for arg in &list[2..] { + let arg_symbol = match arg { + NestedMetaItem::MetaItem(MetaItem { + path: ref p2, kind: MetaItemKind::Word, .. + }) => p2.segments.first().unwrap().ident, + _ => { + tcx.sess + .struct_span_err( + attr.span, + "autodiff attribute must contain the return activity", + ) + .span_label(attr.span, "missing return activity") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + match DiffActivity::from_str(arg_symbol.as_str()) { + Ok(arg_activity) => arg_activities.push(arg_activity), + Err(_) => { + tcx.sess + .struct_span_err(attr.span, "unknown return activity") + .span_label(attr.span, "invalid input activity") + .emit(); + + return AutoDiffAttrs::inactive(); + } + } + } + + if mode == DiffMode::Forward { + if ret_activity == DiffActivity::Active { + tcx.sess + .struct_span_err(attr.span, "Forward Mode is incompatible with Active ret") + .span_label(attr.span, "invalid return activity") + .emit(); + return AutoDiffAttrs::inactive(); + } + if arg_activities.iter().filter(|&x| *x == DiffActivity::Active).count() > 0 { + tcx.sess + .struct_span_err(attr.span, "Forward Mode is incompatible with Active args") + .span_label(attr.span, "invalid input activity") + .emit(); + return AutoDiffAttrs::inactive(); + } + } + + if mode == DiffMode::Reverse { + if ret_activity == DiffActivity::Duplicated + || ret_activity == DiffActivity::DuplicatedNoNeed + { + tcx.sess + 
.struct_span_err( + attr.span, + "Reverse Mode is only compatible with Active, None, or Const ret", + ) + .span_label(attr.span, "invalid return activity") + .emit(); + return AutoDiffAttrs::inactive(); + } + } + + AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities } +} + pub fn provide(providers: &mut Providers) { - *providers = Providers { codegen_fn_attrs, should_inherit_track_caller, ..*providers }; + *providers = + Providers { codegen_fn_attrs, should_inherit_track_caller, autodiff_attrs, ..*providers }; } diff --git a/compiler/rustc_codegen_ssa/src/traits/misc.rs b/compiler/rustc_codegen_ssa/src/traits/misc.rs index 04e2b8796c46a..5f64dd3367661 100644 --- a/compiler/rustc_codegen_ssa/src/traits/misc.rs +++ b/compiler/rustc_codegen_ssa/src/traits/misc.rs @@ -19,4 +19,5 @@ pub trait MiscMethods<'tcx>: BackendTypes { fn apply_target_cpu_attr(&self, llfn: Self::Function); /// Declares the extern "C" main function for the entry point. Returns None if the symbol already exists. 
fn declare_c_main(&self, fn_type: Self::Type) -> Option; + fn create_autodiff(&self) -> Vec; } diff --git a/compiler/rustc_codegen_ssa/src/traits/write.rs b/compiler/rustc_codegen_ssa/src/traits/write.rs index ecf5095d8a335..9c1be89580dc4 100644 --- a/compiler/rustc_codegen_ssa/src/traits/write.rs +++ b/compiler/rustc_codegen_ssa/src/traits/write.rs @@ -2,8 +2,10 @@ use crate::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; use crate::back::write::{CodegenContext, FatLtoInput, ModuleConfig}; use crate::{CompiledModule, ModuleCodegen}; +use rustc_data_structures::fx::FxHashMap; use rustc_errors::{FatalError, Handler}; use rustc_middle::dep_graph::WorkProduct; +use rustc_middle::middle::autodiff_attrs::AutoDiffItem; pub trait WriteBackendMethods: 'static + Sized + Clone { type Module: Send + Sync; @@ -12,6 +14,7 @@ pub trait WriteBackendMethods: 'static + Sized + Clone { type ModuleBuffer: ModuleBufferMethods; type ThinData: Send + Sync; type ThinBuffer: ThinBufferMethods; + type TypeTree: Clone; /// Merge all modules into main_module and returning it fn run_link( @@ -58,6 +61,15 @@ pub trait WriteBackendMethods: 'static + Sized + Clone { ) -> Result; fn prepare_thin(module: ModuleCodegen) -> (String, Self::ThinBuffer); fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer); + /// Generate autodiff rules + fn autodiff( + cgcx: &CodegenContext, + module: &ModuleCodegen, + diff_fncs: Vec, + typetrees: FxHashMap, + config: &ModuleConfig, + ) -> Result<(), FatalError>; + fn typetrees(module: &mut Self::Module) -> FxHashMap; } pub trait ThinBufferMethods: Send + Sync { diff --git a/compiler/rustc_feature/src/builtin_attrs.rs b/compiler/rustc_feature/src/builtin_attrs.rs index e808e4815fe0b..2ed334569995b 100644 --- a/compiler/rustc_feature/src/builtin_attrs.rs +++ b/compiler/rustc_feature/src/builtin_attrs.rs @@ -353,6 +353,13 @@ pub const BUILTIN_ATTRIBUTES: &[BuiltinAttribute] = &[ ungated!(used, Normal, template!(Word, List: 
"compiler|linker"), WarnFollowing, @only_local: true), ungated!(link_ordinal, Normal, template!(List: "ordinal"), ErrorPreceding), + // Autodiff + ungated!( + autodiff_into, Normal, + template!(Word, List: r#""...""#), + DuplicatesOk, + ), + // Limits: ungated!(recursion_limit, CrateLevel, template!(NameValueStr: "N"), FutureWarnFollowing), ungated!(type_length_limit, CrateLevel, template!(NameValueStr: "N"), FutureWarnFollowing), diff --git a/compiler/rustc_interface/src/tests.rs b/compiler/rustc_interface/src/tests.rs index 57ca709267a7e..4439550d8d037 100644 --- a/compiler/rustc_interface/src/tests.rs +++ b/compiler/rustc_interface/src/tests.rs @@ -767,6 +767,7 @@ fn test_unstable_options_tracking_hash() { tracked!(debug_macros, true); tracked!(dep_info_omit_d_target, true); tracked!(dual_proc_macros, true); + tracked!(enzyme_print_activity, false); tracked!(dwarf_version, Some(5)); tracked!(emit_thin_lto, false); tracked!(export_executable_symbols, true); diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 4390486b0deb1..e7db075aefa2f 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -96,6 +96,11 @@ extern "C" char *LLVMRustGetLastError(void) { return Ret; } +extern "C" LLVMTypeRef LLVMRustGetFunctionType(LLVMValueRef Fn) { + auto Ftype = unwrap(Fn)->getFunctionType(); + return wrap(Ftype); +} + extern "C" void LLVMRustSetLastError(const char *Err) { free((void *)LastError); LastError = strdup(Err); diff --git a/compiler/rustc_middle/src/arena.rs b/compiler/rustc_middle/src/arena.rs index 1d573a746b918..acb0a25f087eb 100644 --- a/compiler/rustc_middle/src/arena.rs +++ b/compiler/rustc_middle/src/arena.rs @@ -97,6 +97,7 @@ macro_rules! 
arena_types { [] upvars_mentioned: rustc_data_structures::fx::FxIndexMap, [] object_safety_violations: rustc_middle::traits::ObjectSafetyViolation, [] codegen_unit: rustc_middle::mir::mono::CodegenUnit<'tcx>, + [] autodiff_item: rustc_middle::middle::autodiff_attrs::AutoDiffItem, [decode] attribute: rustc_ast::Attribute, [] name_set: rustc_data_structures::unord::UnordSet, [] ordered_name_set: rustc_data_structures::fx::FxIndexSet, diff --git a/compiler/rustc_middle/src/middle/autodiff_attrs.rs b/compiler/rustc_middle/src/middle/autodiff_attrs.rs new file mode 100644 index 0000000000000..2412df725fe2b --- /dev/null +++ b/compiler/rustc_middle/src/middle/autodiff_attrs.rs @@ -0,0 +1,94 @@ +use crate::middle::typetree::TypeTree; +use std::str::FromStr; + +#[allow(dead_code)] +#[derive(Clone, Copy, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] +pub enum DiffMode { + Inactive, + Source, + Forward, + Reverse, +} + +#[allow(dead_code)] +#[derive(Clone, Copy, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] +pub enum DiffActivity { + None, + Active, + Const, + Duplicated, + DuplicatedNoNeed, +} + +impl FromStr for DiffActivity { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "None" => Ok(DiffActivity::None), + "Active" => Ok(DiffActivity::Active), + "Const" => Ok(DiffActivity::Const), + "Duplicated" => Ok(DiffActivity::Duplicated), + "DuplicatedNoNeed" => Ok(DiffActivity::DuplicatedNoNeed), + _ => Err(()), + } + } +} + +#[allow(dead_code)] +#[derive(Clone, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] +pub struct AutoDiffAttrs { + pub mode: DiffMode, + pub ret_activity: DiffActivity, + pub input_activity: Vec, +} + +impl AutoDiffAttrs { + pub fn inactive() -> Self { + AutoDiffAttrs { + mode: DiffMode::Inactive, + ret_activity: DiffActivity::None, + input_activity: Vec::new(), + } + } + + pub fn is_active(&self) -> bool { + match self.mode { + DiffMode::Inactive => false, + _ => true, + } + } + + pub fn 
is_source(&self) -> bool { + match self.mode { + DiffMode::Source => true, + _ => false, + } + } + pub fn apply_autodiff(&self) -> bool { + match self.mode { + DiffMode::Inactive => false, + DiffMode::Source => false, + _ => true, + } + } + + pub fn into_item( + self, + source: String, + target: String, + inputs: Vec, + output: TypeTree, + ) -> AutoDiffItem { + AutoDiffItem { source, target, inputs, output, attrs: self } + } +} + +#[derive(Clone, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] +pub struct AutoDiffItem { + pub source: String, + pub target: String, + pub attrs: AutoDiffAttrs, + pub inputs: Vec, + pub output: TypeTree, +} diff --git a/compiler/rustc_middle/src/middle/mod.rs b/compiler/rustc_middle/src/middle/mod.rs index 85c5af9ca13cb..43e60c2571cc0 100644 --- a/compiler/rustc_middle/src/middle/mod.rs +++ b/compiler/rustc_middle/src/middle/mod.rs @@ -1,3 +1,4 @@ +pub mod autodiff_attrs; pub mod codegen_fn_attrs; pub mod debugger_visualizer; pub mod dependency_format; @@ -32,6 +33,7 @@ pub mod privacy; pub mod region; pub mod resolve_bound_vars; pub mod stability; +pub mod typetree; pub fn provide(providers: &mut crate::query::Providers) { limits::provide(providers); diff --git a/compiler/rustc_middle/src/middle/typetree.rs b/compiler/rustc_middle/src/middle/typetree.rs new file mode 100644 index 0000000000000..4049d32540bd2 --- /dev/null +++ b/compiler/rustc_middle/src/middle/typetree.rs @@ -0,0 +1,39 @@ +use std::fmt; +#[derive(Clone, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] +pub enum Kind { + Anything, + Integer, + Pointer, + Half, + Float, + Double, + Unknown, +} + +#[derive(Clone, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] +pub struct TypeTree(pub Vec); + +#[derive(Clone, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] +pub struct Type { + pub offset: isize, + pub size: usize, + pub kind: Kind, + pub child: TypeTree, +} + +impl Type { + pub fn add_offset(self, add: isize) -> Self { + 
let offset = match self.offset { + -1 => add, + x => add + x, + }; + + Self { size: self.size, kind: self.kind, child: self.child, offset } + } +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + ::fmt(self, f) + } +} diff --git a/compiler/rustc_middle/src/query/erase.rs b/compiler/rustc_middle/src/query/erase.rs index e20e9d9312c1b..7f33139d67b8d 100644 --- a/compiler/rustc_middle/src/query/erase.rs +++ b/compiler/rustc_middle/src/query/erase.rs @@ -190,6 +190,10 @@ impl EraseType for (&'_ T0, &'_ [T1]) { type Result = [u8; size_of::<(&'static (), &'static [()])>()]; } +impl EraseType for (&'_ T0, &'_ [T1], &'_ [T2]) { + type Result = [u8; size_of::<(&'static (), &'static [()], &'static [()])>()]; +} + macro_rules! trivial { ($($ty:ty),+ $(,)?) => { $( diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs index 062b03e71fdc1..8d85928374b5c 100644 --- a/compiler/rustc_middle/src/query/mod.rs +++ b/compiler/rustc_middle/src/query/mod.rs @@ -10,6 +10,7 @@ use crate::dep_graph; use crate::infer::canonical::{self, Canonical}; use crate::lint::LintExpectation; use crate::metadata::ModChild; +use crate::middle::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem}; use crate::middle::codegen_fn_attrs::CodegenFnAttrs; use crate::middle::debugger_visualizer::DebuggerVisualizerFile; use crate::middle::exported_symbols::{ExportedSymbol, SymbolExportInfo}; @@ -1229,6 +1230,13 @@ rustc_queries! { separate_provide_extern } + /// Returns the autodiff attributes of the item identified by `def_id`. + query autodiff_attrs(def_id: DefId) -> &'tcx AutoDiffAttrs { + desc { |tcx| "computing autodiff attributes of `{}`", tcx.def_path_str(def_id) } + arena_cache + cache_on_disk_if { def_id.is_local() } + } + query asm_target_features(def_id: DefId) -> &'tcx FxIndexSet { desc { |tcx| "computing target features for inline asm of `{}`", tcx.def_path_str(def_id) } } @@ -1878,7 +1886,7 @@ rustc_queries! 
{ separate_provide_extern } - query collect_and_partition_mono_items(_: ()) -> (&'tcx DefIdSet, &'tcx [CodegenUnit<'tcx>]) { + query collect_and_partition_mono_items(_: ()) -> (&'tcx DefIdSet, &'tcx [AutoDiffItem], &'tcx [CodegenUnit<'tcx>]) { eval_always desc { "collect_and_partition_mono_items" } } diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 739d4fa886ec3..31d60e97cded7 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -176,6 +176,8 @@ pub struct ResolverGlobalCtxt { /// Mapping from ident span to path span for paths that don't exist as written, but that /// exist under `std`. For example, wrote `str::from_utf8` instead of `std::str::from_utf8`. pub confused_type_with_std_module: FxHashMap, + /// Mapping of autodiff function IDs + pub autodiff_map: FxHashMap, pub doc_link_resolutions: FxHashMap, pub doc_link_traits_in_scope: FxHashMap>, pub all_macro_rules: FxHashMap>, diff --git a/compiler/rustc_monomorphize/Cargo.toml b/compiler/rustc_monomorphize/Cargo.toml index fe097424e8ad4..b75941e71989a 100644 --- a/compiler/rustc_monomorphize/Cargo.toml +++ b/compiler/rustc_monomorphize/Cargo.toml @@ -18,3 +18,4 @@ rustc_middle = { path = "../rustc_middle" } rustc_session = { path = "../rustc_session" } rustc_span = { path = "../rustc_span" } rustc_target = { path = "../rustc_target" } +rustc_symbol_mangling = { path = "../rustc_symbol_mangling" } diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index 82fee7c8dfe58..baac8d98e8b1b 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ b/compiler/rustc_monomorphize/src/collector.rs @@ -1244,6 +1244,7 @@ impl<'v> RootCollector<'_, 'v> { /// monomorphized copy of the start lang item based on /// the return type of `main`. This is not needed when /// the user writes their own `start` manually. 
+ /// TODO: remove annotations after automatic differentiation pass fn push_extra_entry_roots(&mut self) { let Some((main_def_id, EntryFnType::Main { .. })) = self.entry_fn else { return; diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index 4009e28924068..fa39f35dc334e 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -1282,12 +1282,12 @@ pub fn provide(providers: &mut Providers) { providers.collect_and_partition_mono_items = collect_and_partition_mono_items; providers.is_codegened_item = |tcx, def_id| { - let (all_mono_items, _) = tcx.collect_and_partition_mono_items(()); + let (all_mono_items, _, _) = tcx.collect_and_partition_mono_items(()); all_mono_items.contains(&def_id) }; providers.codegen_unit = |tcx, name| { - let (_, all) = tcx.collect_and_partition_mono_items(()); + let (_, _, all) = tcx.collect_and_partition_mono_items(()); all.iter() .find(|cgu| cgu.name() == name) .unwrap_or_else(|| panic!("failed to find cgu with name {name:?}")) diff --git a/compiler/rustc_passes/src/check_attr.rs b/compiler/rustc_passes/src/check_attr.rs index a8a27e761cb3f..19aa8a308b3a2 100644 --- a/compiler/rustc_passes/src/check_attr.rs +++ b/compiler/rustc_passes/src/check_attr.rs @@ -233,6 +233,7 @@ impl CheckAttrVisitor<'_> { self.check_generic_attr(hir_id, attr, target, Target::Fn); self.check_proc_macro(hir_id, target, ProcMacroKind::Derive) } + sym::autodiff_into => self.check_autodiff(hir_id, attr, span, target), _ => {} } @@ -2394,6 +2395,20 @@ impl CheckAttrVisitor<'_> { self.abort.set(true); } } + + /// Checks if `#[autodiff]` is applied to an item other than a foreign module. 
+ fn check_autodiff(&self, _hir_id: HirId, _attr: &Attribute, _span: Span, _target: Target) { + //match target { + // Target::ForeignMod => {} + // _ => { + // self.tcx + // .sess + // .struct_span_err(attr.span, "attribute should be applied to an `extern` block") + // .span_label(span, "not an `extern` block") + // .emit(); + // } + //} + } } impl<'tcx> Visitor<'tcx> for CheckAttrVisitor<'tcx> { diff --git a/compiler/rustc_resolve/src/lib.rs b/compiler/rustc_resolve/src/lib.rs index 501747df5c908..58e6d82595e7a 100644 --- a/compiler/rustc_resolve/src/lib.rs +++ b/compiler/rustc_resolve/src/lib.rs @@ -1522,6 +1522,7 @@ impl<'a, 'tcx> Resolver<'a, 'tcx> { trait_impls: self.trait_impls, proc_macros, confused_type_with_std_module, + autodiff_map: Default::default(), doc_link_resolutions: self.doc_link_resolutions, doc_link_traits_in_scope: self.doc_link_traits_in_scope, all_macro_rules: self.all_macro_rules, diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs index 30c8b9d67002c..e6401d2fbfbab 100644 --- a/compiler/rustc_session/src/options.rs +++ b/compiler/rustc_session/src/options.rs @@ -1537,6 +1537,8 @@ options! { "enables LTO for dylib crate type"), emit_stack_sizes: bool = (false, parse_bool, [UNTRACKED], "emit a section containing stack size metadata (default: no)"), + enzyme_print_activity: bool = (false, parse_bool, [TRACKED], + "print type trees for functions passed to enzyme"), emit_thin_lto: bool = (true, parse_bool, [TRACKED], "emit the bc module with thin LTO info (default: yes)"), export_executable_symbols: bool = (false, parse_bool, [TRACKED], diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 3f99d2a4b1ffb..a461745d9162c 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -437,6 +437,7 @@ symbols! 
{ attributes, augmented_assignments, auto_traits, + autodiff_into, automatically_derived, avx, avx512_target_feature, @@ -1022,6 +1023,7 @@ symbols! { miri, misc, mmx_reg, + mode, modifiers, module, module_path, diff --git a/config.example.toml b/config.example.toml index 66fa91d4bad15..6050848cb3a05 100644 --- a/config.example.toml +++ b/config.example.toml @@ -142,6 +142,9 @@ change-id = 116998 # Whether or not to specify `-DLLVM_TEMPORARILY_ALLOW_OLD_TOOLCHAIN=YES` #allow-old-toolchain = false +# Whether to build enzyme +#enzyme = false + # Whether to include the Polly optimizer. #polly = false diff --git a/library/autodiff/Cargo.lock b/library/autodiff/Cargo.lock new file mode 100644 index 0000000000000..b11b872e7dbd9 --- /dev/null +++ b/library/autodiff/Cargo.lock @@ -0,0 +1,314 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "autodiff" +version = "0.1.0" +dependencies = [ + "macrotest", + "ndarray", + "proc-macro-error", + "proc-macro2", + "quote", + "syn 1.0.109", + "trybuild", +] + +[[package]] +name = "basic-toml" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24c12265665aebaa236af9bbe266681bcc9c5666192119e3d8335cf083aca26f" +dependencies = [ + "serde", +] + +[[package]] +name = "diff" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" + +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + +[[package]] +name = "itoa" +version = "1.0.9" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" + +[[package]] +name = "macrotest" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7489ae0986ce45414b7b3122c2e316661343ecf396b206e3e15f07c846616f10" +dependencies = [ + "diff", + "glob", + "prettyplease", + "serde", + "serde_json", + "syn 1.0.109", + "toml", +] + +[[package]] +name = "matrixmultiply" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", +] + +[[package]] +name = "num-complex" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" + +[[package]] +name = "prettyplease" +version = "0.1.25" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c8646e95016a7a6c4adea95bafa8a16baab64b583356217f2c85db4a39d9a86" +dependencies = [ + "proc-macro2", + "syn 1.0.109", +] + +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn 1.0.109", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + +[[package]] +name = "proc-macro2" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "ryu" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" + +[[package]] +name = "serde" +version = "1.0.190" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91d3c334ca1ee894a2c6f6ad698fe8c435b76d504b13d436f0685d648d6d96f7" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.190" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "67c5609f394e5c2bd7fc51efda478004ea80ef42fee983d5c67a65e34f32c0e3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.38", +] + +[[package]] +name = "serde_json" +version = "1.0.107" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "termcolor" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6093bad37da69aab9d123a8091e4be0aa4a03e4d601ec641c327398315f62b64" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "toml" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4f7f0dd8d50a853a531c426359045b1998f04219d88799810762cd4ad314234" +dependencies = [ + "serde", +] + +[[package]] +name = "trybuild" +version = "1.0.85" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "196a58260a906cedb9bf6d8034b6379d0c11f552416960452f267402ceeddff1" +dependencies = [ + "basic-toml", + "glob", + "once_cell", + "serde", + "serde_derive", + "serde_json", + "termcolor", +] + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = 
"version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-util" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" +dependencies = [ + "winapi", +] + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" diff --git a/library/autodiff/Cargo.toml b/library/autodiff/Cargo.toml new file mode 100644 index 0000000000000..cbbff8d375e3d --- /dev/null +++ b/library/autodiff/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "autodiff" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + + +[profile.release] +lto = "fat" + +[profile.dev] +lto = "fat" + +[lib] +name = "autodiff" +proc-macro = true + +[dependencies] +quote = "1.0" +proc-macro2 = "1" +proc-macro-error = "1" +syn = { version = "1", features = ["extra-traits", "full", "visit", "visit-mut"]} + +[dev-dependencies] +macrotest = "1" +trybuild = "1" +ndarray = "0.15" diff --git a/library/autodiff/examples/array.rs b/library/autodiff/examples/array.rs new file mode 100644 index 0000000000000..60c6b63fd84cb --- /dev/null +++ 
b/library/autodiff/examples/array.rs @@ -0,0 +1,23 @@ +use autodiff::autodiff; + +#[autodiff(d_array, Reverse, Active, Duplicated)] +fn array(arr: &[[[f32; 2]; 2]; 2]) -> f32 { + arr[0][0][0] * arr[1][1][1] +} + +fn main() { + let arr = [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]; + let mut d_arr = [[[0.0; 2]; 2]; 2]; + + d_array(&arr, &mut d_arr, 1.0); + + dbg!(&d_arr); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/box.rs b/library/autodiff/examples/box.rs new file mode 100644 index 0000000000000..5d4f114830bf4 --- /dev/null +++ b/library/autodiff/examples/box.rs @@ -0,0 +1,24 @@ +use autodiff::autodiff; + +#[autodiff(cos_box, Reverse, Active, Duplicated)] +fn sin(x: &Box) -> f32 { + f32::sin(**x) +} + +fn main() { + let x = Box::::new(3.14); + let mut df_dx = Box::::new(0.0); + cos_box(&x, &mut df_dx, 1.0); + + dbg!(&df_dx); + + assert!(*df_dx == f32::cos(*x)); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/broken_matvec.rs b/library/autodiff/examples/broken_matvec.rs new file mode 100644 index 0000000000000..0c4b2cfe6e927 --- /dev/null +++ b/library/autodiff/examples/broken_matvec.rs @@ -0,0 +1,34 @@ +use autodiff::autodiff; + +type Matrix = Vec>; +type Vector = Vec; + +#[autodiff(d_matvec, Forward, Const)] +fn matvec(#[dup] mat: &Matrix, vec: &Vector, #[dup] out: &mut Vector) { + for i in 0..mat.len() - 1 { + for j in 0..mat[0].len() - 1 { + out[i] += mat[i][j] * vec[j]; + } + } +} + +fn main() { + let mat = vec![vec![1.0, 1.0], vec![1.0, 1.0]]; + let mut d_mat = vec![vec![0.0, 0.0], vec![0.0, 0.0]]; + let inp = vec![1.0, 1.0]; + let mut out = vec![0.0, 0.0]; + let mut out_tang = vec![0.0, 1.0]; + + //matvec(&mat, &inp, &mut out); + d_matvec(&mat, &mut d_mat, &inp, &mut out, &mut out_tang); + + dbg!(&out); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git 
a/library/autodiff/examples/hessian_sin.rs b/library/autodiff/examples/hessian_sin.rs new file mode 100644 index 0000000000000..6b1e776476fd2 --- /dev/null +++ b/library/autodiff/examples/hessian_sin.rs @@ -0,0 +1,28 @@ +use autodiff::autodiff; + +fn sin(x: &Vec, y: &mut f32) { + *y = x.into_iter().map(|x| f32::sin(*x)).sum() +} + +#[autodiff(sin, Reverse, Const, Duplicated, Duplicated)] +fn jac(x: &Vec, d_x: &mut Vec, y: &mut f32, y_t: &f32); + +#[autodiff(jac, Forward, Const, Duplicated, Const, Const, Const)] +fn hessian(x: &Vec, y_x: &Vec, d_x: &mut Vec, y: &mut f32, y_t: &f32); + +fn main() { + let inp = vec![3.1415 / 2., 1.0, 0.5]; + let mut d_inp = vec![0.0, 0.0, 0.0]; + let mut y = 0.0; + let tang = vec![1.0, 0.0, 0.0]; + hessian(&inp, &tang, &mut d_inp, &mut y, &1.0); + dbg!(&d_inp); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/ndarray.rs b/library/autodiff/examples/ndarray.rs new file mode 100644 index 0000000000000..34402c43cb3e6 --- /dev/null +++ b/library/autodiff/examples/ndarray.rs @@ -0,0 +1,25 @@ +use autodiff::autodiff; + +use ndarray::Array1; + +#[autodiff(d_collect, Reverse, Active)] +fn collect(#[dup] x: &Array1) -> f32 { + x[0] +} + +fn main() { + let a = Array1::zeros(19); + let mut d_a = Array1::zeros(19); + + d_collect(&a, &mut d_a, 1.0); + + dbg!(&d_a); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/rosenbrock_fwd.rs b/library/autodiff/examples/rosenbrock_fwd.rs new file mode 100644 index 0000000000000..a3ab7a47578d0 --- /dev/null +++ b/library/autodiff/examples/rosenbrock_fwd.rs @@ -0,0 +1,34 @@ +use autodiff::autodiff; + +#[autodiff(d_rosenbrock, Forward, DuplicatedNoNeed)] +fn rosenbrock(#[dup] x: &[f64; 2]) -> f64 { + let mut res = 0.0; + for i in 0..(x.len() - 1) { + let a = x[i + 1] - x[i] * x[i]; + let b = x[i] - 1.0; + res += 100.0 * a * a + b * b; + } + res +} + +fn main() { + let x = 
[3.14, 2.4]; + let output = rosenbrock(&x); + println!("{output}"); + let df_dx = d_rosenbrock(&x, &[1.0, 0.0]); + let df_dy = d_rosenbrock(&x, &[0.0, 1.0]); + + dbg!(&df_dx, &df_dy); + + // https://www.wolframalpha.com/input?i2d=true&i=x%3D3.14%3B+y%3D2.4%3B+D%5Brosenbrock+function%5C%2840%29x%5C%2844%29+y%5C%2841%29+%2Cy%5D + assert!((df_dx - 9373.54).abs() < 0.1); + assert!((df_dy - (-1491.92)).abs() < 0.1); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/rosenbrock_fwd_iter.rs b/library/autodiff/examples/rosenbrock_fwd_iter.rs new file mode 100644 index 0000000000000..1648014392f19 --- /dev/null +++ b/library/autodiff/examples/rosenbrock_fwd_iter.rs @@ -0,0 +1,34 @@ +use autodiff::autodiff; + +#[autodiff(d_rosenbrock, Forward, DuplicatedNoNeed)] +fn rosenbrock(#[dup] x: &[f64; 2]) -> f64 { + (0..x.len() - 1) + .map(|i| { + let (a, b) = (x[i + 1] - x[i] * x[i], x[i] - 1.0); + 100.0 * a * a + b * b + }) + .sum() +} + +fn main() { + let x = [3.14f64, 2.4]; + let output = rosenbrock(&x); + println!("{output}"); + + let df_dx = d_rosenbrock(&x, &[1.0, 0.0]); + let df_dy = d_rosenbrock(&x, &[0.0, 1.0]); + + dbg!(&df_dx, &df_dy); + + // https://www.wolframalpha.com/input?i2d=true&i=x%3D3.14%3B+y%3D2.4%3B+D%5Brosenbrock+function%5C%2840%29x%5C%2844%29+y%5C%2841%29+%2Cy%5D + assert!((df_dx - 9373.54).abs() < 0.1); + assert!((df_dy - (-1491.92)).abs() < 0.1); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/rosenbrock_rev.rs b/library/autodiff/examples/rosenbrock_rev.rs new file mode 100644 index 0000000000000..b4ce00b5afe9d --- /dev/null +++ b/library/autodiff/examples/rosenbrock_rev.rs @@ -0,0 +1,33 @@ +use autodiff::autodiff; + +#[autodiff(d_rosenbrock, Reverse, Active)] +fn rosenbrock(#[dup] x: &[f64; 2]) -> f64 { + let mut res = 0.0; + for i in 0..(x.len() - 1) { + let a = x[i + 1] - x[i] * x[i]; + let b = x[i] - 1.0; + res += 
100.0 * a * a + b * b; + } + res +} + +fn main() { + let x = [3.14, 2.4]; + let output = rosenbrock(&x); + println!("{output}"); + + let mut df_dx = [0.0f64; 2]; + d_rosenbrock(&x, &mut df_dx, 1.0); + + // https://www.wolframalpha.com/input?i2d=true&i=x%3D3.14%3B+y%3D2.4%3B+D%5Brosenbrock+function%5C%2840%29x%5C%2844%29+y%5C%2841%29+%2Cy%5D + assert!((df_dx[0] - 9373.54).abs() < 0.01); + assert!((df_dx[1] - (-1491.92)).abs() < 0.01); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/sin.rs b/library/autodiff/examples/sin.rs new file mode 100644 index 0000000000000..1655b1e7ecd09 --- /dev/null +++ b/library/autodiff/examples/sin.rs @@ -0,0 +1,36 @@ +use autodiff::autodiff; + +#[autodiff(cos_inplace, Reverse, Const)] +fn sin_inplace(#[dup] x: &f32, #[dup] y: &mut f32) { + *y = x.sin(); +} + + +fn main() { + // Here we can use ==, even though we work on f32. + // Enzyme will recognize the sin function and replace it with llvm's cos function (see below). + // Calling f32::cos directly will also result in calling llvm's cos function. 
+ let a = 3.1415; + let mut da = 0.0; + let mut y = 0.0; + cos_inplace(&a, &mut da, &mut y, &mut 1.0); + + dbg!(&a, &da, &y); + assert!(da - f32::cos(a) == 0.0); +} + +// Just for curious readers, this is the (inner) function that Enzyme does generate: +// define internal { float } @diffe_ZN3sin3sin17h18f17f71fe94e58fE(float %0, float %1) unnamed_addr #35 { +// %3 = call fast float @llvm.cos.f32(float %0) +// %4 = fmul fast float %1, %3 +// %5 = insertvalue { float } undef, float %4, 0 +// ret { float } %5 +// } + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/sqrt.rs b/library/autodiff/examples/sqrt.rs new file mode 100644 index 0000000000000..d15c6f5ec2051 --- /dev/null +++ b/library/autodiff/examples/sqrt.rs @@ -0,0 +1,21 @@ +use autodiff::autodiff; + +#[autodiff(d_sqrt, Reverse, Active)] +fn sqrt(#[active] a: f32, #[dup] b: &f32, c: &f32, #[active] d: f32) -> f32 { + a * (b * b + c*c*d*d).sqrt() +} + +fn main() { + let mut d_b = 0.0; + + let (d_a, d_d) = d_sqrt(1.0, &1.0, &mut d_b, &1.0, 1.0, 1.0); + dbg!(d_a, d_b, d_d); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/struct.rs b/library/autodiff/examples/struct.rs new file mode 100644 index 0000000000000..1235307fdbcbf --- /dev/null +++ b/library/autodiff/examples/struct.rs @@ -0,0 +1,33 @@ +use autodiff::autodiff; + +use std::io; + +// Will be represented as {f32, i16, i16} when passed by reference +// will be represented as i64 if passed by value +struct Foo { + c1: i16, + a: f32, + c2: i16, +} + +#[autodiff(cos, Reverse, Active, Duplicated)] +fn sin(x: &Foo) -> f32 { + assert!(x.c1 < x.c2); + f32::sin(x.a) +} + +fn main() { + let mut s = String::new(); + println!("Please enter a value for c1"); + io::stdin().read_line(&mut s).unwrap(); + let c2 = s.trim_end().parse::().unwrap(); + dbg!(c2); + + let foo = Foo { c1: 4, a: 3.14, c2 }; + let mut df_dfoo = Foo { c1: 4, a: 0.0, 
c2 }; + + dbg!(df_dfoo.a); + dbg!(cos(&foo, &mut df_dfoo, 1.0)); + dbg!(df_dfoo.a); + dbg!(f32::cos(foo.a)); +} diff --git a/library/autodiff/examples/vec.rs b/library/autodiff/examples/vec.rs new file mode 100644 index 0000000000000..e82618fac4dac --- /dev/null +++ b/library/autodiff/examples/vec.rs @@ -0,0 +1,24 @@ +use autodiff::autodiff; + +#[autodiff(d_sum, Forward, Duplicated)] +fn sum(#[dup] x: &Vec>) -> f32 { + x.into_iter().map(|x| x.into_iter().map(|x| x.sqrt())).flatten().sum() +} + +fn main() { + let a = vec![vec![1.0, 2.0, 4.0, 8.0]]; + //let mut b = vec![vec![0.0, 0.0, 0.0, 0.0]]; + let b = vec![vec![1.0, 0.0, 0.0, 0.0]]; + + dbg!(&d_sum(&a, &b)); + + dbg!(&b); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples_broken/biquad.rs b/library/autodiff/examples_broken/biquad.rs new file mode 100644 index 0000000000000..7689b1cd1fc51 --- /dev/null +++ b/library/autodiff/examples_broken/biquad.rs @@ -0,0 +1,54 @@ +use autodiff::autodiff; + +#[derive(Debug)] +struct Biquad { + coeffs: [[f32; 5]; N], +} + +impl Biquad { + pub fn new() -> Self { + Biquad { coeffs: [[0.0; 5]; N] } + } + + pub fn process(&self, samples: &[f32], target: &[f32]) -> f32 { + // do some horrible inefficient biquad filtering + let mut samples = samples.to_vec(); + let mut samples_out = vec![0.0; samples.len()]; + + for coeff_set in self.coeffs { + for idx in 0..samples.len() { + samples_out[idx] = coeff_set[0] * samples[idx]; + + if idx > 0 { + samples_out[idx] += coeff_set[1] * samples[idx - 1] - + coeff_set[3] * samples_out[idx - 1]; + } + if idx > 1 { + samples_out[idx] += coeff_set[2] * samples[idx - 2] - + coeff_set[4] * samples_out[idx - 2]; + } + } + + (samples, samples_out) = (samples_out, samples); + } + + samples_out.into_iter().zip(target.into_iter()).map(|(a, b)| a - b).sum() + } + + #[autodiff(Self::process, Reverse, Active)] + pub fn deriv(#[dup] &self, params: &mut Self, samples: &[f32], target: &[f32], 
ret_adj: f32); +} + +fn main() { + let biquad = Biquad::<10>::new(); + let mut dbiquad = Biquad::<10>::new(); + + // create ramp and pulse train + let signal = (0..1024).map(|x| (x as f32) / 1024.0).collect::>(); + let target = (0..1024).map(|x| if x % 2 == 0 { 0.0 } else { 1.0 }).collect::>(); + + dbg!(&biquad.process(&signal, &target)); + biquad.deriv(&mut dbiquad, &signal, &target, 1.0); + + dbg!(&dbiquad); +} diff --git a/library/autodiff/examples_broken/broken_iter.rs b/library/autodiff/examples_broken/broken_iter.rs new file mode 100644 index 0000000000000..16d205f7373c8 --- /dev/null +++ b/library/autodiff/examples_broken/broken_iter.rs @@ -0,0 +1,20 @@ +#![feature(bench_black_box)] +use autodiff::autodiff; +use std::ptr; + +#[autodiff(sin_vec, Reverse, Active)] +fn cos_vec(#[dup] x: &Vec) -> f32 { + // uses enum internally and breaks + let res = x.into_iter().collect::>(); + + *res[0] +} + +fn main() { + let x = vec![1.0, 1.0, 1.0]; + let mut d_x = vec![0.0; 3]; + + sin_vec(&x, &mut d_x, 1.0); + + dbg!(&d_x, &x); +} diff --git a/library/autodiff/examples_broken/broken_recursive.rs b/library/autodiff/examples_broken/broken_recursive.rs new file mode 100644 index 0000000000000..a1f3ff25eb511 --- /dev/null +++ b/library/autodiff/examples_broken/broken_recursive.rs @@ -0,0 +1,66 @@ +#![feature(bench_black_box)] +use autodiff::autodiff; + +// TODO: As seen by the bloated code generated for the iterative version, +// we definitely have to disable unroll, slpvec, loop-vec before AD. +// We also should check if we have other opts that Julia, C++, Fortran etc. don't have +// and which could make our input code more "complex". +// We then however have to start doing whole-module opt after AD to re-include them, +// instead of just using enzyme to optimize the generated function. 
+ +#[autodiff(d_power_recursive, Forward, DuplicatedNoNeed)] +fn power_recursive(#[dup] a: f64, n: i32) -> f64 { + if n == 0 { + return 1.0; + } + return a * power_recursive(a, n - 1); +} + +#[autodiff(d_power_iterative, Reverse, DuplicatedNoNeed)] +fn power_iterative(#[active] a: f64, n: i32) -> f64 { + let mut res = 1.0; + for _ in 0..n { + res *= a; + } + res +} + +fn main() { + // d/dx x^n = n * x^(n-1) + let n = 4; + let nf = n as f64; + let a = 1.337; + assert!(power_recursive(a, n) == power_iterative(a, n)); + let dpr = d_power_recursive(a, 1.0, n); + let dpi = d_power_iterative(a, n, 1.0); + let control = nf * a.powi(n - 1); + dbg!(dpr); + dbg!(dpi); + dbg!(control); + assert!(dpr == control); + assert!(dpi == control); +} + +// Again, for the curious. We can find n * x^(n-1) nicely in the LLVM-IR +// +// define internal double @fwddiffe_ZN9recursive15power_recursive17h789de751cfc6154dE(double %0, double %1, i32 %2) unnamed_addr #8 { +// => if (n == 0) goto 5: and return 0. Correct, since for n==0 we have 0 * x ^ (0-1) = 0 +// => if (n != 0) goto 7: +// %4 = icmp eq i32 %2, 0 +// br i1 %4, label %5, label %7 +// +// 5: ; preds = %7, %3 +// %6 = phi fast double [ %14, %7 ], [ 0.000000e+00, %3 ] +// ret double %6 +// +// 7: ; preds = %3 +// => reduce n by 1, +// %8 = add i32 %2, -1 +// %9 = call { double, double } @fwddiffe_ZN9recursive15power_recursive17h789de751cfc6154dE.1229(double %0, double %1, i32 %8) +// %10 = extractvalue { double, double } %9, 0 +// %11 = extractvalue { double, double } %9, 1 +// %12 = fmul fast double %11, %0 +// %13 = fmul fast double %1, %10 +// %14 = fadd fast double %12, %13 +// br label %5 +// } diff --git a/library/autodiff/examples_broken/broken_second_order.rs b/library/autodiff/examples_broken/broken_second_order.rs new file mode 100644 index 0000000000000..8b427d7dae36a --- /dev/null +++ b/library/autodiff/examples_broken/broken_second_order.rs @@ -0,0 +1,17 @@ +#![feature(bench_black_box)] +use autodiff::autodiff; + +fn 
sin(x: &f32) -> f32 { + f32::sin(*x) +} + +#[autodiff(sin, Reverse, Active, Active)] +fn cos(x: &f32, adj: f32) -> f32; + +//#[autodiff(cos, Reverse, Active, Active, Const)] +//fn neg_sin(x: &f32, adj: f32, adj_sec: f32) -> f32; + +fn main() { + dbg!(&cos(&1.0, 1.0)); + //dbg!(&neg_sin(&1.0, 1.0, 1.0)); +} diff --git a/library/autodiff/src/gen.rs b/library/autodiff/src/gen.rs new file mode 100644 index 0000000000000..68aae56ea3311 --- /dev/null +++ b/library/autodiff/src/gen.rs @@ -0,0 +1,217 @@ +use crate::parser::{is_ref_mut, PrimalSig}; +use crate::parser::{Activity, DiffItem, Mode}; +use proc_macro2::TokenStream; +use proc_macro_error::abort; +use quote::{format_ident, quote}; +use syn::{parse_quote, FnArg, Ident, Pat, ReturnType, Type}; + +pub(crate) fn generate_header(item: &DiffItem) -> TokenStream { + let mode = match item.header.mode { + Mode::Forward => format_ident!("Forward"), + Mode::Reverse => format_ident!("Reverse"), + }; + let ret_act = item.header.ret_act.to_ident(); + let param_act = item.params.iter().map(|x| x.to_ident()); + + quote!(#[autodiff_into(#mode, #ret_act, #( #param_act, )*)]) +} + +pub(crate) fn primal_fnc(item: &mut DiffItem) -> TokenStream { + // construct body of primal if not given + let body = item.block.clone().map(|x| quote!(#x)).unwrap_or_else(|| { + let header_fnc = &item.header.name; + //let primal_wrapper = format_ident!("primal_{}", item.primal.ident); + //item.primal.ident = primal_wrapper.clone(); + let inputs = item.primal.inputs.iter().map(|x| only_ident(x)).collect::>(); + + quote!({ + #header_fnc(#(#inputs,)*) + }) + }); + + let sig = &item.primal; + let PrimalSig { ident, inputs, output } = sig; + + let ident = + if item.block.is_some() { ident.clone() } else { format_ident!("primal_{}", ident) }; + + let sig = quote!(fn #ident(#(#inputs,)*) #output); + + quote!( + #[autodiff_into] + #sig + #body + ) +} + +fn only_ident(arg: &FnArg) -> Ident { + match arg { + FnArg::Receiver(_) => format_ident!("self"), + 
FnArg::Typed(t) => match &*t.pat { + Pat::Ident(ident) => ident.ident.clone(), + _ => panic!(""), + }, + } +} + +fn only_type(arg: &FnArg) -> Type { + match arg { + FnArg::Receiver(_) => parse_quote!(Self), + FnArg::Typed(t) => match &*t.ty { + Type::Reference(t) => *t.elem.clone(), + x => x.clone(), + }, + } +} + +fn as_ref_mut(arg: &FnArg, name: &str, mutable: bool) -> FnArg { + match arg { + FnArg::Receiver(_) => { + let name = format_ident!("{}_self", name); + if mutable { parse_quote!(#name: &mut Self) } else { parse_quote!(#name: &Self) } + } + FnArg::Typed(t) => { + let inner = match &*t.ty { + Type::Reference(t) => &t.elem, + _ => panic!(""), // should not be reachable, as we checked mutability before + }; + + let pat_name = match &*t.pat { + Pat::Ident(x) => &x.ident, + _ => panic!(""), + }; + + let name = format_ident!("{}_{}", name, pat_name); + if mutable { parse_quote!(#name: &mut #inner) } else { parse_quote!(#name: &#inner) } + } + } +} + +pub(crate) fn adjoint_fnc(item: &DiffItem) -> TokenStream { + let mut res_inputs: Vec = Vec::new(); + let mut add_inputs: Vec = Vec::new(); + let out_type = match &item.primal.output { + ReturnType::Type(_, x) => Some(*x.clone()), + _ => None, + }; + + let mut outputs = if item.header.ret_act == Activity::Duplicated { + vec![out_type.clone().unwrap()] + } else { + vec![] + }; + + let PrimalSig { ident, inputs, .. 
} = &item.primal; + + for (input, activity) in inputs.iter().zip(item.params.iter()) { + res_inputs.push(input.clone()); + + match (item.header.mode, activity, is_ref_mut(&input)) { + (Mode::Forward, Activity::Duplicated|Activity::DuplicatedNoNeed, Some(true)) => { + res_inputs.push(as_ref_mut(&input, "grad", true)); + add_inputs.push(as_ref_mut(&input, "grad", true)); + } + (Mode::Forward, Activity::Duplicated|Activity::DuplicatedNoNeed, Some(false)) => { + res_inputs.push(as_ref_mut(&input, "dual", false)); + add_inputs.push(as_ref_mut(&input, "dual", false)); + out_type.clone().map(|x| outputs.push(x)); + } + (Mode::Forward, Activity::Duplicated, None) => outputs.push(only_type(&input)), + (Mode::Reverse, Activity::Duplicated, Some(false)) => { + res_inputs.push(as_ref_mut(&input, "grad", true)); + add_inputs.push(as_ref_mut(&input, "grad", true)); + } + (Mode::Reverse, Activity::Duplicated | Activity::DuplicatedNoNeed, Some(true)) => { + res_inputs.push(as_ref_mut(&input, "grad", false)); + add_inputs.push(as_ref_mut(&input, "grad", false)); + } + (Mode::Reverse, Activity::Active, None) => outputs.push(only_type(&input)), + _ => {} + } + } + + match (item.header.mode, item.header.ret_act) { + (Mode::Reverse, Activity::Active) => { + let t: FnArg = match &item.primal.output { + ReturnType::Type(_, ty) => parse_quote!(tang_y: #ty), + _ => panic!(""), + }; + res_inputs.push(t.clone()); + add_inputs.push(t); + } + _ => {} + } + + // for adjoint function -> take header if primal + // -> take ident of primal function + let adjoint_ident = if item.block.is_some() { + if let Some(ident) = item.header.name.get_ident() { + ident.clone() + } else { + abort!( + item.header.name, + "not a function name"; + help = "`#[autodiff]` function name should be a single word instead of path" + ); + } + } else { + item.primal.ident.clone() + }; + + let output = match outputs.len() { + 0 => quote!(), + 1 => { + let output = outputs.first().unwrap(); + + quote!(-> #output) + } + _ => 
quote!(-> (#(#outputs,)*)), + }; + + let sig = quote!(fn #adjoint_ident(#(#res_inputs,)*) #output); + let inputs = inputs + .iter() + .map(|x| match x { + FnArg::Typed(ty) => { + let pat = &ty.pat; + quote!(#pat) + } + FnArg::Receiver(_) => quote!(self), + }) + .collect::>(); + let add_inputs = add_inputs + .iter() + .map(|x| match x { + FnArg::Typed(ty) => { + let pat = &ty.pat; + quote!(#pat) + } + FnArg::Receiver(_) => quote!(self), + }) + .collect::>(); + + let call_ident = match item.block.is_some() { + false => { + let ident = format_ident!("primal_{}", ident); + if item.header.name.segments.first().unwrap().ident == "Self" { + quote!(Self::#ident) + } else { + quote!(#ident) + } + } + true => quote!(#ident), + }; + + let body = quote!({ + std::hint::black_box((#call_ident(#(#inputs,)*), #(#add_inputs,)*)); + + std::hint::black_box(unsafe { std::mem::zeroed() }) + }); + let header = generate_header(&item); + + quote!( + #header + #sig + #body + ) +} diff --git a/library/autodiff/src/lib.rs b/library/autodiff/src/lib.rs new file mode 100644 index 0000000000000..b1d265fa9c59b --- /dev/null +++ b/library/autodiff/src/lib.rs @@ -0,0 +1,31 @@ +use proc_macro::TokenStream; +use proc_macro_error::proc_macro_error; +use quote::quote; + +mod gen; +mod parser; + +#[proc_macro_attribute] +#[proc_macro_error] +pub fn autodiff(args: TokenStream, input: TokenStream) -> TokenStream { + let mut params = parser::parse(args.into(), input.clone().into()); + let (primal, adjoint) = (gen::primal_fnc(&mut params), gen::adjoint_fnc(¶ms)); + + let res = quote!( + #primal + #adjoint + ); + + res.into() +} + +#[test] +pub fn expanding() { + macrotest::expand("tests/expand/*.rs"); +} + +#[test] +fn ui() { + let t = trybuild::TestCases::new(); + t.compile_fail("tests/ui/*.rs"); +} diff --git a/library/autodiff/src/parser.rs b/library/autodiff/src/parser.rs new file mode 100644 index 0000000000000..d11eea24d5015 --- /dev/null +++ b/library/autodiff/src/parser.rs @@ -0,0 +1,464 @@ +use 
proc_macro2::TokenStream; +use proc_macro_error::abort; +use quote::{format_ident, quote}; +use syn::{ + parse::Parser, parse_quote, punctuated::Punctuated, Attribute, Block, FnArg, ForeignItemFn, + Ident, Item, Path, ReturnType, Signature, Token, Type, +}; + +#[derive(Debug)] +pub struct PrimalSig { + pub(crate) ident: Ident, + pub(crate) inputs: Vec, + pub(crate) output: ReturnType, +} + +#[derive(Debug)] +pub struct DiffItem { + pub(crate) header: Header, + pub(crate) params: Vec, + pub(crate) primal: PrimalSig, + pub(crate) block: Option>, +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub(crate) enum Mode { + Forward, + Reverse, +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub(crate) enum Activity { + Const, + Active, + Duplicated, + DuplicatedNoNeed, +} + +impl Activity { + fn from_header(name: Option<&Ident>) -> Activity { + if name.is_none() { + return Activity::Const; + } + + match name.unwrap().to_string().as_str() { + "Const" => Activity::Const, + "Active" => Activity::Active, + "Duplicated" => Activity::Duplicated, + "DuplicatedNoNeed" => Activity::DuplicatedNoNeed, + _ => { + abort!( + name, + "unknown activity"; + help = "`#[autodiff]` should use activities (Const|Active|Duplicated|DuplicatedNoNeed)" + ); + } + } + } + + fn from_inline(name: Attribute) -> Activity { + let name = name.path.segments.first().unwrap(); + match name.ident.to_string().as_str() { + "const" => Activity::Const, + "active" => Activity::Active, + "dup" => Activity::Duplicated, + "dup_noneed" => Activity::DuplicatedNoNeed, + _ => { + abort!( + name, + "unknown activity"; + help = "`#[autodiff]` should use activities (const|active|dup|dup_noneed)" + ); + } + } + } + + pub(crate) fn to_ident(&self) -> Ident { + format_ident!( + "{}", + match self { + Activity::Const => "Const", + Activity::Active => "Active", + Activity::Duplicated => "Duplicated", + Activity::DuplicatedNoNeed => "DuplicatedNoNeed", + } + ) + } +} + +#[derive(Debug)] +pub(crate) struct Header { + pub name: 
Path, + pub mode: Mode, + pub ret_act: Activity, +} + +impl Header { + fn from_params(name: &Path, mode: Option<&Ident>, ret_activity: Option<&Ident>) -> Self { + // parse mode and return activity + let mode = mode + .map(|x| match x.to_string().as_str() { + "forward" | "Forward" => Mode::Forward, + "reverse" | "Reverse" => Mode::Reverse, + _ => { + abort!( + mode, + "should be forward or reverse"; + help = "`#[autodiff]` modes should be either forward or reverse" + ); + } + }) + .unwrap_or(Mode::Forward); + let ret_act = Activity::from_header(ret_activity); + + // check for invalid mode and return activity combinations + match (mode, ret_act) { + (Mode::Forward, Activity::Active) => abort!( + ret_activity, + "active return for forward mode"; + help = "`#[autodiff]` return should be Const, Duplicated or DuplicatedNoNeed in forward mode" + ), + (Mode::Reverse, Activity::Duplicated | Activity::DuplicatedNoNeed) => abort!( + ret_activity, + "duplicated return for reverse mode"; + help = "`#[autodiff]` return should be Const or Active in reverse mode" + ), + + _ => {} + } + + Header { name: name.clone(), mode, ret_act } + } + + fn parse(args: TokenStream) -> (Header, Vec) { + let args_parsed: Vec<_> = + match Punctuated::::parse_terminated.parse(args.clone().into()) { + Ok(x) => x.into_iter().collect(), + Err(_) => abort!( + args, + "duplicated return for reverse mode"; + help = "`#[autodiff]` return should be Const or Active in reverse mode" + ), + }; + + match &args_parsed[..] { + [name] => (Self::from_params(&name, None, None), vec![]), + [name, mode] => { + (Self::from_params(&name, Some(&mode.get_ident().unwrap()), None), vec![]) + } + [name, mode, ret_act, rem @ ..] 
=> { + let params = Self::from_params( + &name, + Some(&mode.get_ident().unwrap()), + Some(&ret_act.get_ident().unwrap()), + ); + let rem = rem.into_iter() + .map(|x| x.get_ident().unwrap()) + .map(|x| Activity::from_header(Some(x))) + .map(|x| match (params.mode, x) { + (Mode::Forward, Activity::Active) => { + abort!( + args, + "active argument in forward mode"; + help = "`#[autodiff]` forward mode should be either Const, Duplicated" + ); + }, + (_, x) => x, + }) + .collect(); + + (params, rem) + } + _ => { + abort!( + args, + "please specify the autodiff function"; + help = "`#[autodiff]` needs a function name for primal or adjoint" + ); + } + } + } +} + +pub(crate) fn is_ref_mut(t: &FnArg) -> Option { + match t { + FnArg::Receiver(pat) => Some(pat.mutability.is_some()), + FnArg::Typed(pat) => match &*pat.ty { + Type::Reference(t) => Some(t.mutability.is_some()), + _ => None, + }, + } +} + +fn is_scalar(t: &Type) -> bool { + let t_f32: Type = parse_quote!(f32); + let t_f64: Type = parse_quote!(f64); + t == &t_f32 || t == &t_f64 +} + +fn ret_arg(arg: &FnArg) -> Type { + match arg { + FnArg::Receiver(_) => parse_quote!(Self), + FnArg::Typed(t) => match &*t.ty { + Type::Reference(t) => *t.elem.clone(), + x => x.clone(), + }, + } +} + +pub(crate) fn reduce_params( + mut sig: Signature, + header_acts: Vec, + is_adjoint: bool, + header: &Header, +) -> (PrimalSig, Vec) { + let mut args = Vec::new(); + let mut ret = Vec::new(); + let mut acts = Vec::new(); + let mut last_arg: Option = None; + + let mut arg_it = sig.inputs.iter_mut(); + let mut header_acts_it = header_acts.iter(); + + while let Some(arg) = arg_it.next() { + // Compare current with last argument when parsing duplicated rules. 
This only + // happens when we parse the signature of adjoint/augmented primal function + if let Some(prev_arg) = last_arg.take() { + match (header.mode, is_ref_mut(&prev_arg), is_ref_mut(&arg)) { + (Mode::Forward, Some(false), Some(true) | None) => abort!( + arg, + "should be an immutable reference"; + help = "`#[autodiff]` input parameter should duplicate tangent into second parameter for forward mode" + ), + (Mode::Forward, Some(true), Some(false) | None) => abort!( + arg, + "should be a mutable reference"; + help = "`#[autodiff]` output parameter should duplicate derivative into second parameter for forward mode" + ), + (Mode::Reverse, Some(false), Some(false) | None) => abort!( + arg, + "should be a mutable reference"; + help = "`#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode" + ), + (Mode::Reverse, Some(true), Some(true) | None) => abort!( + arg, + "should be an immutable reference"; + help = "`#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode" + ), + _ => {} + } + + continue; + } + + // parse current attribute macro + let attrs: Vec<_> = match arg { + FnArg::Typed(pat) => pat.attrs.drain(..).collect(), + FnArg::Receiver(pat) => pat.attrs.drain(..).collect(), + }; + let attr = attrs.first(); + let act: Activity = match (header_acts.is_empty(), attr) { + (false, None) => header_acts_it.next().map(|x| *x).unwrap_or(Activity::Const), + (true, Some(x)) => Activity::from_inline(x.clone()), + (true, None) => Activity::Const, + _ => { + abort!( + arg, + "inline activity"; + help = "`#[autodiff]` should have activities either specified in header or as inline attributes" + ); + } + }; + + // compare indirection with activity + match (header.mode, is_ref_mut(&arg), act) { + (Mode::Forward, None, Activity::Duplicated) => abort!( + arg, + "type not behind reference"; + help = "`#[autodiff]` duplicated types should be behind a reference" + ), + (Mode::Forward, Some(false), 
Activity::DuplicatedNoNeed) => abort!( + arg, + "should be mutable reference"; + help = "`#[autodiff]` parameter should be output for DuplicatedNoNeed activity" + ), + (Mode::Reverse, Some(_), Activity::Active) => abort!( + arg, + "type behind reference"; + help = "`#[autodiff]` active parameter should be concrete in reverse mode" + ), + (Mode::Reverse, None, Activity::Duplicated | Activity::DuplicatedNoNeed) => abort!( + arg, + "type not behind reference"; + help = "`#[autodiff]` duplicated parameters should be behind reference in reverse mode" + ), + (Mode::Reverse, Some(false), Activity::DuplicatedNoNeed) => abort!( + arg, + "use duplicated instead"; + help = "`#[autodiff]` input parameter cannot be declared as duplicatednoneed" + ), + (Mode::Forward, Some(false), Activity::Duplicated) + if header.ret_act != Activity::Const => + { + ret.push(ret_arg(&arg)) + } + (Mode::Reverse, None, Activity::Active) => ret.push(ret_arg(&arg)), + (Mode::Forward, Some(_), Activity::Duplicated | Activity::DuplicatedNoNeed) + | (Mode::Reverse, _, Activity::Duplicated | Activity::DuplicatedNoNeed) + if is_adjoint => + { + last_arg = Some(arg.clone()) + } + _ => {} + } + + args.push(arg.clone()); + acts.push(act); + } + + // if we have adjoint signature and are in forward mode + // if duplicated -> return type * (n + 1) times + // if duplicated_no_need -> return type * n times + // if const -> no return + + // if we have adjoint signature and are in reverse mode + // if active -> input type * n times + // construct return type based on mode + let ret = if is_adjoint { + let ret_typs = match &sig.output { + ReturnType::Type(_, ref x) => match &**x { + Type::Tuple(x) => x.elems.iter().cloned().collect(), + x => vec![x.clone()], + }, + ReturnType::Default => vec![], + }; + + match (header.mode, header.ret_act) { + (Mode::Forward, Activity::Duplicated) => { + let expected = ret_typs[0].clone(); + let list = vec![expected.clone(); ret.len() + 1]; + + if list != ret_typs { + let ret = 
quote!((#(#list,)*)); + abort!( + sig.output, + "invalid output"; + help = format!("`#[autodiff]` expected {}", ret) + ); + } + + parse_quote!(-> #expected) + } + (Mode::Forward, Activity::DuplicatedNoNeed) => { + let expected = ret_typs[0].clone(); + let list = vec![expected.clone(); ret.len()]; + + if list != ret_typs { + let ret = quote!((#(#list,)*)); + abort!( + sig.output, + "invalid output"; + help = format!("`#[autodiff]` expected {}", ret) + ); + } + + parse_quote!(-> #expected) + } + (Mode::Reverse, Activity::Active) => { + // tangent of output is last in parameter list + let ret_typ = match (args.pop(), acts.pop()) { + (Some(x), Some(y)) => { + let x = ret_arg(&x); + if !is_scalar(&x) { + abort!( + x, + "output tangent not a floating point"; + help = "`#[autodiff]` the output tangent should be a floating point" + ); + } else if y != Activity::Const { + abort!( + x, + "output tangent not const"; + help = "`#[autodiff]` the last parameter of an adjoint with active return should be a constant tangent" + ); + } else { + parse_quote!(-> #x) + } + } + (None, None) => abort!( + sig, + "missing output tangent parameter"; + help = "`#[autodiff]` the last parameter of an adjoint with active return should exist" + ), + _ => unreachable!(), + }; + + // check that the return tuple conforms to return types + if ret_typs != ret { + let ret = quote!((#(#ret,)*)); + abort!( + sig.output, + "invalid output"; + help = format!("`#[autodiff]` expected {}", ret) + ) + } + + ret_typ + } + (_, Activity::Const) if ret.len() > 0 => { + abort!( + ret[0], + "constant return but more than one return"; + help = "`#[autodiff]` adjoint should have a return type when active" + ) + } + _ => ReturnType::Default, + } + } else { + if header.ret_act != Activity::Const && sig.output == ReturnType::Default { + abort!( + sig, + "no return type"; + help = "`#[autodiff]` non-const return activity but no return type" + ) + } + + sig.output.clone() + }; + + let sig = if is_adjoint { + // header 
is used for calling if we are adjoint + format_ident!("{}", sig.ident) + } else { + sig.ident.clone() + }; + + (PrimalSig { ident: sig, inputs: args, output: ret }, acts) +} + +pub(crate) fn parse(args: TokenStream, input: TokenStream) -> DiffItem { + // first parse function + let (_attrs, _, sig, block) = match syn::parse2::(input) { + Ok(Item::Fn(item)) => (item.attrs, item.vis, item.sig, Some(item.block)), + Ok(Item::Verbatim(x)) => match syn::parse2::(x) { + Ok(item) => (item.attrs, item.vis, item.sig, None), + Err(err) => panic!("Could not parse item {}", err), + }, + Ok(item) => { + abort!( + item, + "item is not a function"; + help = "`#[autodiff]` can only be used on primal or adjoint functions" + ) + } + Err(err) => panic!("Could not parse item: {}", err), + }; + + // then parse attributes + let (header, param_attrs) = Header::parse(args); + + // reduce parameters to primal parameter set + let (primal, params) = reduce_params(sig, param_attrs, !block.is_some(), &header); + + DiffItem { header, primal, params, block } +} diff --git a/library/autodiff/tests/expand/forward_duplicated.expanded.rs b/library/autodiff/tests/expand/forward_duplicated.expanded.rs new file mode 100644 index 0000000000000..bf3890154ab8e --- /dev/null +++ b/library/autodiff/tests/expand/forward_duplicated.expanded.rs @@ -0,0 +1,10 @@ +use autodiff::autodiff; +#[autodiff_into] +fn square(a: &Vec, b: &mut f32) { + *b = a.into_iter().map(f32::square).sum(); +} +#[autodiff_into(Forward, Const, Duplicated, Duplicated)] +fn d_square(a: &Vec, dual_a: &Vec, b: &mut f32, grad_b: &mut f32) { + std::hint::black_box((square(a, b), dual_a, grad_b)); + std::hint::black_box(unsafe { std::mem::zeroed() }) +} diff --git a/library/autodiff/tests/expand/forward_duplicated.rs b/library/autodiff/tests/expand/forward_duplicated.rs new file mode 100644 index 0000000000000..9a0bfc6c13a47 --- /dev/null +++ b/library/autodiff/tests/expand/forward_duplicated.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + 
+#[autodiff(d_square, Forward, Const)] +fn square(#[dup] a: &Vec, #[dup] b: &mut f32) { + *b = a.into_iter().map(f32::square).sum(); +} diff --git a/library/autodiff/tests/expand/forward_duplicated_return.expanded.rs b/library/autodiff/tests/expand/forward_duplicated_return.expanded.rs new file mode 100644 index 0000000000000..a3754de7ab70b --- /dev/null +++ b/library/autodiff/tests/expand/forward_duplicated_return.expanded.rs @@ -0,0 +1,15 @@ +use autodiff::autodiff; +#[autodiff_into] +fn square2(a: &Vec, b: &Vec) -> f32 { + a.into_iter().map(f32::square).sum() +} +#[autodiff_into(Forward, Duplicated, Duplicated, Duplicated)] +fn d_square2( + a: &Vec, + dual_a: &Vec, + b: &Vec, + dual_b: &Vec, +) -> (f32, f32, f32) { + std::hint::black_box((square2(a, b), dual_a, dual_b)); + std::hint::black_box(unsafe { std::mem::zeroed() }) +} diff --git a/library/autodiff/tests/expand/forward_duplicated_return.rs b/library/autodiff/tests/expand/forward_duplicated_return.rs new file mode 100644 index 0000000000000..3397e5309ea96 --- /dev/null +++ b/library/autodiff/tests/expand/forward_duplicated_return.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_square2, Forward, Duplicated)] +fn square2(#[dup] a: &Vec, #[dup] b: &Vec) -> f32 { + a.into_iter().map(f32::square).sum() +} diff --git a/library/autodiff/tests/expand/reverse_duplicated.expanded.rs b/library/autodiff/tests/expand/reverse_duplicated.expanded.rs new file mode 100644 index 0000000000000..60c0d7f2f696b --- /dev/null +++ b/library/autodiff/tests/expand/reverse_duplicated.expanded.rs @@ -0,0 +1,10 @@ +use autodiff::autodiff; +#[autodiff_into] +fn square(a: &Vec, b: &mut f32) { + *b = a.into_iter().map(f32::square).sum(); +} +#[autodiff_into(Reverse, Const, Duplicated, Duplicated)] +fn d_square(a: &Vec, grad_a: &mut Vec, b: &mut f32, grad_b: &f32) { + std::hint::black_box((square(a, b), grad_a, grad_b)); + std::hint::black_box(unsafe { std::mem::zeroed() }) +} diff --git 
a/library/autodiff/tests/expand/reverse_duplicated.rs b/library/autodiff/tests/expand/reverse_duplicated.rs new file mode 100644 index 0000000000000..107a708bec848 --- /dev/null +++ b/library/autodiff/tests/expand/reverse_duplicated.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_square, Reverse, Const)] +fn square(#[dup] a: &Vec, #[dup] b: &mut f32) { + *b = a.into_iter().map(f32::square).sum(); +} diff --git a/library/autodiff/tests/expand/reverse_return_array.expanded.rs b/library/autodiff/tests/expand/reverse_return_array.expanded.rs new file mode 100644 index 0000000000000..5b784157fea7b --- /dev/null +++ b/library/autodiff/tests/expand/reverse_return_array.expanded.rs @@ -0,0 +1,10 @@ +use autodiff::autodiff; +#[autodiff_into] +fn array(arr: &[[[f32; 2]; 2]; 2]) -> f32 { + arr[0][0][0] * arr[1][1][1] +} +#[autodiff_into(Reverse, Active, Duplicated)] +fn d_array(arr: &[[[f32; 2]; 2]; 2], grad_arr: &mut [[[f32; 2]; 2]; 2], tang_y: f32) { + std::hint::black_box((array(arr), grad_arr, tang_y)); + std::hint::black_box(unsafe { std::mem::zeroed() }) +} diff --git a/library/autodiff/tests/expand/reverse_return_array.rs b/library/autodiff/tests/expand/reverse_return_array.rs new file mode 100644 index 0000000000000..da080a6b3a860 --- /dev/null +++ b/library/autodiff/tests/expand/reverse_return_array.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_array, Reverse, Active)] +fn array(#[dup] arr: &[[[f32; 2]; 2]; 2]) -> f32 { + arr[0][0][0] * arr[1][1][1] +} diff --git a/library/autodiff/tests/expand/reverse_return_mixed.expanded.rs b/library/autodiff/tests/expand/reverse_return_mixed.expanded.rs new file mode 100644 index 0000000000000..f49864fb7e9b9 --- /dev/null +++ b/library/autodiff/tests/expand/reverse_return_mixed.expanded.rs @@ -0,0 +1,17 @@ +use autodiff::autodiff; +#[autodiff_into] +fn sqrt(a: f32, b: &f32, c: &f32, d: f32) -> f32 { + a * (b * b + c * c * d * d).sqrt() +} +#[autodiff_into(Reverse, Active, Active, Duplicated, Const, 
Active)] +fn d_sqrt( + a: f32, + b: &f32, + grad_b: &mut f32, + c: &f32, + d: f32, + tang_y: f32, +) -> (f32, f32) { + std::hint::black_box((sqrt(a, b, c, d), grad_b, tang_y)); + std::hint::black_box(unsafe { std::mem::zeroed() }) +} diff --git a/library/autodiff/tests/expand/reverse_return_mixed.rs b/library/autodiff/tests/expand/reverse_return_mixed.rs new file mode 100644 index 0000000000000..3260c3560d523 --- /dev/null +++ b/library/autodiff/tests/expand/reverse_return_mixed.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_sqrt, Reverse, Active)] +fn sqrt(#[active] a: f32, #[dup] b: &f32, c: &f32, #[active] d: f32) -> f32 { + a * (b * b + c*c*d*d).sqrt() +} diff --git a/library/autodiff/tests/ui/active_in_forward_mode.rs b/library/autodiff/tests/ui/active_in_forward_mode.rs new file mode 100644 index 0000000000000..10366b1b422b8 --- /dev/null +++ b/library/autodiff/tests/ui/active_in_forward_mode.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Forward, DuplicatedNoNeed, Active)] +fn sin(x: f32) -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/active_in_forward_mode.stderr b/library/autodiff/tests/ui/active_in_forward_mode.stderr new file mode 100644 index 0000000000000..cd413564068ae --- /dev/null +++ b/library/autodiff/tests/ui/active_in_forward_mode.stderr @@ -0,0 +1,7 @@ +error: active argument in forward mode + --> tests/ui/active_in_forward_mode.rs:3:12 + | +3 | #[autodiff(d_sin, Forward, DuplicatedNoNeed, Active)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` forward mode should be either Const, Duplicated diff --git a/library/autodiff/tests/ui/activities_inline_and_header.rs b/library/autodiff/tests/ui/activities_inline_and_header.rs new file mode 100644 index 0000000000000..1ecf37ec60a8f --- /dev/null +++ b/library/autodiff/tests/ui/activities_inline_and_header.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Reverse, Active, Active)] +fn sin(#[active] x: f32) -> f32; + 
+fn main() {} diff --git a/library/autodiff/tests/ui/activities_inline_and_header.stderr b/library/autodiff/tests/ui/activities_inline_and_header.stderr new file mode 100644 index 0000000000000..b4d50d02a26a4 --- /dev/null +++ b/library/autodiff/tests/ui/activities_inline_and_header.stderr @@ -0,0 +1,7 @@ +error: inline activity + --> tests/ui/activities_inline_and_header.rs:4:18 + | +4 | fn sin(#[active] x: f32) -> f32; + | ^^^^^^ + | + = help: `#[autodiff]` should have activities either specified in header or as inline attributes diff --git a/library/autodiff/tests/ui/invalid_indirection.rs b/library/autodiff/tests/ui/invalid_indirection.rs new file mode 100644 index 0000000000000..627a7cb0fc6f9 --- /dev/null +++ b/library/autodiff/tests/ui/invalid_indirection.rs @@ -0,0 +1,19 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Reverse, Const)] +fn duplicated_without_reference(#[dup] x: f32) { +} + +#[autodiff(d_sin, Reverse, Const)] +fn active_with_reference(#[active] x: &f32) { +} + +#[autodiff(d_sin, Forward, Const)] +fn duplicated_forward(#[dup] x: f32) { +} + +#[autodiff(d_sin, Forward, Const)] +fn duplicated_no_need_forward(#[dup_noneed] x: &f32) { +} + +fn main() {} diff --git a/library/autodiff/tests/ui/invalid_indirection.stderr b/library/autodiff/tests/ui/invalid_indirection.stderr new file mode 100644 index 0000000000000..cb27c542018e5 --- /dev/null +++ b/library/autodiff/tests/ui/invalid_indirection.stderr @@ -0,0 +1,31 @@ +error: type not behind reference + --> tests/ui/invalid_indirection.rs:4:40 + | +4 | fn duplicated_without_reference(#[dup] x: f32) { + | ^^^^^^ + | + = help: `#[autodiff]` duplicated parameters should be behind reference in reverse mode + +error: type behind reference + --> tests/ui/invalid_indirection.rs:8:36 + | +8 | fn active_with_reference(#[active] x: &f32) { + | ^^^^^^^ + | + = help: `#[autodiff]` active parameter should be concrete in reverse mode + +error: type not behind reference + --> tests/ui/invalid_indirection.rs:12:30 
+ | +12 | fn duplicated_forward(#[dup] x: f32) { + | ^^^^^^ + | + = help: `#[autodiff]` duplicated types should be behind a reference + +error: should be mutable reference + --> tests/ui/invalid_indirection.rs:16:45 + | +16 | fn duplicated_no_need_forward(#[dup_noneed] x: &f32) { + | ^^^^^^^ + | + = help: `#[autodiff]` parameter should be output for DuplicatedNoNeed activity diff --git a/library/autodiff/tests/ui/invalid_mutability_pairs.rs b/library/autodiff/tests/ui/invalid_mutability_pairs.rs new file mode 100644 index 0000000000000..708ecc597a5be --- /dev/null +++ b/library/autodiff/tests/ui/invalid_mutability_pairs.rs @@ -0,0 +1,24 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Forward, Duplicated)] +fn fwd_output_no_reference(#[dup] x: &mut f32, y: f32) -> f32; + +#[autodiff(d_sin, Forward, Duplicated)] +fn output_immutable(#[dup] x: &mut f32, y: &f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn rev_input_no_reference(#[dup] x: &f32, y: f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn rev_output_no_reference(#[dup] x: &mut f32, y: f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn input_immutable(#[dup] x: &f32, y: &f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn output_mutable(#[dup] x: &mut f32, y: &mut f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn dupnoneed_input(#[dup_noneed] x: &f32, y: &f32) -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/invalid_mutability_pairs.stderr b/library/autodiff/tests/ui/invalid_mutability_pairs.stderr new file mode 100644 index 0000000000000..37af0c2ad52ee --- /dev/null +++ b/library/autodiff/tests/ui/invalid_mutability_pairs.stderr @@ -0,0 +1,55 @@ +error: should be a mutable reference + --> tests/ui/invalid_mutability_pairs.rs:4:48 + | +4 | fn fwd_output_no_reference(#[dup] x: &mut f32, y: f32) -> f32; + | ^^^^^^ + | + = help: `#[autodiff]` output parameter should duplicate derivative into second parameter for forward mode + +error: should be a mutable reference + --> 
tests/ui/invalid_mutability_pairs.rs:7:41 + | +7 | fn output_immutable(#[dup] x: &mut f32, y: &f32) -> f32; + | ^^^^^^^ + | + = help: `#[autodiff]` output parameter should duplicate derivative into second parameter for forward mode + +error: should be a mutable reference + --> tests/ui/invalid_mutability_pairs.rs:10:43 + | +10 | fn rev_input_no_reference(#[dup] x: &f32, y: f32) -> f32; + | ^^^^^^ + | + = help: `#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode + +error: should be an immutable reference + --> tests/ui/invalid_mutability_pairs.rs:13:48 + | +13 | fn rev_output_no_reference(#[dup] x: &mut f32, y: f32) -> f32; + | ^^^^^^ + | + = help: `#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode + +error: should be a mutable reference + --> tests/ui/invalid_mutability_pairs.rs:16:36 + | +16 | fn input_immutable(#[dup] x: &f32, y: &f32) -> f32; + | ^^^^^^^ + | + = help: `#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode + +error: should be an immutable reference + --> tests/ui/invalid_mutability_pairs.rs:19:39 + | +19 | fn output_mutable(#[dup] x: &mut f32, y: &mut f32) -> f32; + | ^^^^^^^^^^^ + | + = help: `#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode + +error: use duplicated instead + --> tests/ui/invalid_mutability_pairs.rs:22:34 + | +22 | fn dupnoneed_input(#[dup_noneed] x: &f32, y: &f32) -> f32; + | ^^^^^^^ + | + = help: `#[autodiff]` input parameter cannot be declared as duplicatednoneed diff --git a/library/autodiff/tests/ui/invalid_return.rs b/library/autodiff/tests/ui/invalid_return.rs new file mode 100644 index 0000000000000..b3c8bce1166bf --- /dev/null +++ b/library/autodiff/tests/ui/invalid_return.rs @@ -0,0 +1,12 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Forward, Active)] +fn sin1(x: f32) -> f32; + +#[autodiff(d_sin, Reverse, Duplicated)] +fn sin2(x: f32) -> 
f32; + +#[autodiff(d_sin, Reverse, DuplicatedNoNeed)] +fn sin3(x: f32) -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/invalid_return.stderr b/library/autodiff/tests/ui/invalid_return.stderr new file mode 100644 index 0000000000000..4ddaccdba0f72 --- /dev/null +++ b/library/autodiff/tests/ui/invalid_return.stderr @@ -0,0 +1,23 @@ +error: active return for forward mode + --> tests/ui/invalid_return.rs:3:28 + | +3 | #[autodiff(d_sin, Forward, Active)] + | ^^^^^^ + | + = help: `#[autodiff]` return should be Const, Duplicated or DuplicatedNoNeed in forward mode + +error: duplicated return for reverse mode + --> tests/ui/invalid_return.rs:6:28 + | +6 | #[autodiff(d_sin, Reverse, Duplicated)] + | ^^^^^^^^^^ + | + = help: `#[autodiff]` return should be Const or Active in reverse mode + +error: duplicated return for reverse mode + --> tests/ui/invalid_return.rs:9:28 + | +9 | #[autodiff(d_sin, Reverse, DuplicatedNoNeed)] + | ^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` return should be Const or Active in reverse mode diff --git a/library/autodiff/tests/ui/invalid_return_type.rs b/library/autodiff/tests/ui/invalid_return_type.rs new file mode 100644 index 0000000000000..7b91ccd2d650a --- /dev/null +++ b/library/autodiff/tests/ui/invalid_return_type.rs @@ -0,0 +1,16 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Reverse, Active)] +fn active_but_no_return(#[active] x: f32) { +} + +#[autodiff(d_sin, Reverse, Active)] +fn invalid_primal_value(#[active] x: f32, #[active] y: Vec, #[active] z: Tensor, y_tang: f32) -> (i32, f32); + +#[autodiff(d_sin, Forward, Duplicated)] +fn invalid_forward_return(#[dup] x: &f32, tx: &f32, #[dup] y: &Vec, ty: &Vec, #[dup] z: &Tensor, tz: &Tensor) -> (f32, f32, f32); + +#[autodiff(d_sin, Forward, DuplicatedNoNeed)] +fn invalid_forward_return(#[dup] x: &f32, tx: &f32, #[dup] y: &Vec, ty: &Vec, #[dup] z: &Tensor, tz: &Tensor) -> (f32, f32); + +fn main() {} diff --git a/library/autodiff/tests/ui/invalid_return_type.stderr 
b/library/autodiff/tests/ui/invalid_return_type.stderr new file mode 100644 index 0000000000000..90e5e47a2a33d --- /dev/null +++ b/library/autodiff/tests/ui/invalid_return_type.stderr @@ -0,0 +1,31 @@ +error: no return type + --> tests/ui/invalid_return_type.rs:4:1 + | +4 | fn active_but_no_return(#[active] x: f32) { + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` non-const return activity but no return type + +error: invalid output + --> tests/ui/invalid_return_type.rs:8:100 + | +8 | fn invalid_primal_value(#[active] x: f32, #[active] y: Vec, #[active] z: Tensor, y_tang: f32) -> (i32, f32); + | ^^^^^^^^^^^^^ + | + = help: `#[autodiff]` expected (f32, Vec < f32 >, Tensor,) + +error: invalid output + --> tests/ui/invalid_return_type.rs:11:121 + | +11 | fn invalid_forward_return(#[dup] x: &f32, tx: &f32, #[dup] y: &Vec, ty: &Vec, #[dup] z: &Tensor, tz: &Tensor) -> (f32, f32, f32); + | ^^^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` expected (f32, f32, f32, f32,) + +error: invalid output + --> tests/ui/invalid_return_type.rs:14:121 + | +14 | fn invalid_forward_return(#[dup] x: &f32, tx: &f32, #[dup] y: &Vec, ty: &Vec, #[dup] z: &Tensor, tz: &Tensor) -> (f32, f32); + | ^^^^^^^^^^^^^ + | + = help: `#[autodiff]` expected (f32, f32, f32,) diff --git a/library/autodiff/tests/ui/no_function_name.rs b/library/autodiff/tests/ui/no_function_name.rs new file mode 100644 index 0000000000000..8222ca4aaf37d --- /dev/null +++ b/library/autodiff/tests/ui/no_function_name.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff] +fn sin(x: f32) -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/no_function_name.stderr b/library/autodiff/tests/ui/no_function_name.stderr new file mode 100644 index 0000000000000..e98add3164c9f --- /dev/null +++ b/library/autodiff/tests/ui/no_function_name.stderr @@ -0,0 +1,8 @@ +error: please specify the autodiff function + --> tests/ui/no_function_name.rs:3:1 + | +3 | #[autodiff] + | ^^^^^^^^^^^ + | + = help: 
`#[autodiff]` needs a function name for primal or adjoint + = note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/library/autodiff/tests/ui/not_a_function.rs b/library/autodiff/tests/ui/not_a_function.rs new file mode 100644 index 0000000000000..0a3c11725a086 --- /dev/null +++ b/library/autodiff/tests/ui/not_a_function.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff] +struct NotAFunction; + +fn main() {} diff --git a/library/autodiff/tests/ui/not_a_function.stderr b/library/autodiff/tests/ui/not_a_function.stderr new file mode 100644 index 0000000000000..c681841532a5e --- /dev/null +++ b/library/autodiff/tests/ui/not_a_function.stderr @@ -0,0 +1,7 @@ +error: item is not a function + --> tests/ui/not_a_function.rs:4:1 + | +4 | struct NotAFunction; + | ^^^^^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` can only be used on primal or adjoint functions diff --git a/library/autodiff/tests/ui/reverse_tangent.rs b/library/autodiff/tests/ui/reverse_tangent.rs new file mode 100644 index 0000000000000..603f7fd1789ce --- /dev/null +++ b/library/autodiff/tests/ui/reverse_tangent.rs @@ -0,0 +1,12 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Reverse, Active)] +fn invalid_output_tangent_type(#[active] x: f32, y_tang: i32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn active_output_tangent(#[active] x: f32, #[active] y_tang: f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn tangent_missing() -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/reverse_tangent.stderr b/library/autodiff/tests/ui/reverse_tangent.stderr new file mode 100644 index 0000000000000..a7b4b6e3d97d6 --- /dev/null +++ b/library/autodiff/tests/ui/reverse_tangent.stderr @@ -0,0 +1,23 @@ +error: output tangent not a floating point + --> tests/ui/reverse_tangent.rs:4:58 + | +4 | fn invalid_output_tangent_type(#[active] x: f32, y_tang: i32) -> f32; + | ^^^ + | + = help: `#[autodiff]` the output 
tangent should be a floating point + +error: output tangent not const + --> tests/ui/reverse_tangent.rs:7:62 + | +7 | fn active_output_tangent(#[active] x: f32, #[active] y_tang: f32) -> f32; + | ^^^ + | + = help: `#[autodiff]` the last parameter of an adjoint with active return should be a constant tangent + +error: missing output tangent parameter + --> tests/ui/reverse_tangent.rs:10:1 + | +10 | fn tangent_missing() -> f32; + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` the last parameter of an adjoint with active return should exist diff --git a/library/autodiff/tests/ui/wrong_mode.rs b/library/autodiff/tests/ui/wrong_mode.rs new file mode 100644 index 0000000000000..1b500711de109 --- /dev/null +++ b/library/autodiff/tests/ui/wrong_mode.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, WrongMode)] +fn sin(x: f32) -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/wrong_mode.stderr b/library/autodiff/tests/ui/wrong_mode.stderr new file mode 100644 index 0000000000000..ca18d81abb306 --- /dev/null +++ b/library/autodiff/tests/ui/wrong_mode.stderr @@ -0,0 +1,7 @@ +error: should be forward or reverse + --> tests/ui/wrong_mode.rs:3:19 + | +3 | #[autodiff(d_sin, WrongMode)] + | ^^^^^^^^^ + | + = help: `#[autodiff]` modes should be either forward or reverse diff --git a/library/core/src/macros/mod.rs b/library/core/src/macros/mod.rs index 125a6f57bfbaa..658b640ae4ce3 100644 --- a/library/core/src/macros/mod.rs +++ b/library/core/src/macros/mod.rs @@ -1416,6 +1416,18 @@ pub(crate) mod builtin { }; } + /// Differentiate function + ///#[unstable( + /// feature = "autodiff", + /// issue = "29598", + /// reason = "autodiff is not stable enough" + ///)] + ///#[rustc_builtin_macro] + ///#[macro_export] + ///pub macro autodiff($item:item) { + /// /* compiler built-in */ + ///} + /// Parses a file as an expression or an item according to the context. 
/// /// **Warning**: For multi-file Rust projects, the `include!` macro is probably not what you diff --git a/src/bootstrap/configure.py b/src/bootstrap/configure.py index bfef3e672407d..6369a1a557a8f 100755 --- a/src/bootstrap/configure.py +++ b/src/bootstrap/configure.py @@ -70,6 +70,7 @@ def v(*args): # channel, etc. o("optimize-llvm", "llvm.optimize", "build optimized LLVM") o("llvm-assertions", "llvm.assertions", "build LLVM with assertions") +o("llvm-enzyme", "llvm.enzyme", "build LLVM with Enzyme") o("llvm-plugins", "llvm.plugins", "build LLVM with plugin interface") o("debug-assertions", "rust.debug-assertions", "build with debugging assertions") o("debug-assertions-std", "rust.debug-assertions-std", "build the standard library with debugging assertions") diff --git a/src/bootstrap/src/core/build_steps/compile.rs b/src/bootstrap/src/core/build_steps/compile.rs index 441931e415cc6..7a53c4caffe6d 100644 --- a/src/bootstrap/src/core/build_steps/compile.rs +++ b/src/bootstrap/src/core/build_steps/compile.rs @@ -1539,6 +1539,7 @@ pub struct Assemble { pub target_compiler: Compiler, } +#[allow(unreachable_code)] impl Step for Assemble { type Output = Compiler; const ONLY_HOSTS: bool = true; @@ -1599,6 +1600,24 @@ impl Step for Assemble { return target_compiler; } + // Build enzyme + let enzyme_install = if builder.config.llvm_enzyme { + Some(builder.ensure(llvm::Enzyme { target: build_compiler.host })) + } else { + None + }; + + if let Some(enzyme_install) = enzyme_install { + let src_lib = enzyme_install.join("build/Enzyme/LLVMEnzyme-16.so"); + + let libdir = builder.sysroot_libdir(build_compiler, build_compiler.host); + let target_libdir = builder.sysroot_libdir(target_compiler, target_compiler.host); + let dst_lib = libdir.join("libLLVMEnzyme-16.so"); + let target_dst_lib = target_libdir.join("libLLVMEnzyme-16.so"); + builder.copy(&src_lib, &dst_lib); + builder.copy(&src_lib, &target_dst_lib); + } + // Build the libraries for this compiler to link to (i.e., 
the libraries // it uses at runtime). NOTE: Crates the target compiler compiles don't // link to these. (FIXME: Is that correct? It seems to be correct most diff --git a/src/bootstrap/src/core/build_steps/llvm.rs b/src/bootstrap/src/core/build_steps/llvm.rs index 24351118a5aa1..11e377be92e24 100644 --- a/src/bootstrap/src/core/build_steps/llvm.rs +++ b/src/bootstrap/src/core/build_steps/llvm.rs @@ -802,6 +802,72 @@ fn get_var(var_base: &str, host: &str, target: &str) -> Option { .or_else(|| env::var_os(var_base)) } +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub struct Enzyme { + pub target: TargetSelection, +} + +impl Step for Enzyme { + type Output = PathBuf; + const ONLY_HOSTS: bool = true; + + fn should_run(run: ShouldRun<'_>) -> ShouldRun<'_> { + run.path("src/tools/enzyme/enzyme") + } + + fn make_run(run: RunConfig<'_>) { + run.builder.ensure(Enzyme { target: run.target }); + } + + /// Compile Enzyme for `target`. + fn run(self, builder: &Builder<'_>) -> PathBuf { + if builder.config.dry_run() { + let out_dir = builder.enzyme_out(self.target); + return out_dir; + } + let target = self.target; + + let LlvmResult { llvm_config, .. } = builder.ensure(Llvm { target: self.target }); + + let out_dir = builder.enzyme_out(target); + let done_stamp = out_dir.join("enzyme-finished-building"); + if done_stamp.exists() { + return out_dir; + } + + builder.info(&format!("Building Enzyme for {}", target)); + let _time = helpers::timeit(&builder); + t!(fs::create_dir_all(&out_dir)); + + builder.update_submodule(&Path::new("src").join("tools").join("enzyme")); + let mut cfg = cmake::Config::new(builder.src.join("src/tools/enzyme/enzyme/")); + // TODO: Find a nicer way to use Enzyme Debug builds + //cfg.profile("Debug"); + //cfg.define("CMAKE_BUILD_TYPE", "Debug"); + configure_cmake(builder, target, &mut cfg, true, LdFlags::default(), &[]); + + // Re-use the same flags as llvm to control the level of debug information + // generated for Enzyme.
+ let profile = match (builder.config.llvm_optimize, builder.config.llvm_release_debuginfo) { + (false, _) => "Debug", + (true, false) => "Release", + (true, true) => "RelWithDebInfo", + }; + + cfg.out_dir(&out_dir) + .profile(profile) + .env("LLVM_CONFIG_REAL", &llvm_config) + .define("LLVM_ENABLE_ASSERTIONS", "ON") + .define("ENZYME_EXTERNAL_SHARED_LIB", "OFF") + .define("LLVM_DIR", builder.llvm_out(target)); + + cfg.build(); + + t!(File::create(&done_stamp)); + out_dir + } +} + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] pub struct Lld { pub target: TargetSelection, diff --git a/src/bootstrap/src/core/builder.rs b/src/bootstrap/src/core/builder.rs index 90e09d12a9d50..65c94281e00d4 100644 --- a/src/bootstrap/src/core/builder.rs +++ b/src/bootstrap/src/core/builder.rs @@ -1371,6 +1371,11 @@ impl<'a> Builder<'a> { } } + // TODO: adjust -14 ending for Enzyme + // https://rust-lang.zulipchat.com/#narrow/stream/182449-t-compiler.2Fhelp/topic/.E2.9C.94.20link.20new.20library.20into.20stage1.2Frustc + rustflags.arg("-l"); + rustflags.arg("LLVMEnzyme-16"); + let use_new_symbol_mangling = match self.config.rust_new_symbol_mangling { Some(setting) => { // If an explicit setting is given, use that diff --git a/src/bootstrap/src/core/config/config.rs b/src/bootstrap/src/core/config/config.rs index 5b5334b0a5572..a49835b93b662 100644 --- a/src/bootstrap/src/core/config/config.rs +++ b/src/bootstrap/src/core/config/config.rs @@ -173,6 +173,8 @@ pub struct Config { // llvm codegen options pub llvm_assertions: bool, pub llvm_tests: bool, + pub llvm_enzyme: bool, + pub llvm_enzyme_build: Option, pub llvm_plugins: bool, pub llvm_optimize: bool, pub llvm_thin_lto: bool, @@ -676,24 +678,24 @@ macro_rules! 
define_config { A: serde::de::MapAccess<'de>, { $(let mut $field: Option<$field_ty> = None;)* - while let Some(key) = - match serde::de::MapAccess::next_key::(&mut map) { - Ok(val) => val, - Err(err) => { - return Err(err); + while let Some(key) = + match serde::de::MapAccess::next_key::(&mut map) { + Ok(val) => val, + Err(err) => { + return Err(err); + } } - } { match &*key { $($field_key => { if $field.is_some() { return Err(::duplicate_field( - $field_key, - )); + $field_key, + )); } $field = match serde::de::MapAccess::next_value::<$field_ty>( &mut map, - ) { + ) { Ok(val) => Some(val), Err(err) => { return Err(err); @@ -823,6 +825,7 @@ define_config! { release_debuginfo: Option = "release-debuginfo", assertions: Option = "assertions", tests: Option = "tests", + enzyme: Option = "enzyme", plugins: Option = "plugins", ccache: Option = "ccache", static_libstdcpp: Option = "static-libstdcpp", @@ -1356,6 +1359,7 @@ impl Config { // we'll infer default values for them later let mut llvm_assertions = None; let mut llvm_tests = None; + let mut llvm_enzyme = None; let mut llvm_plugins = None; let mut debug = None; let mut debug_assertions = None; @@ -1500,6 +1504,7 @@ impl Config { set(&mut config.ninja_in_file, llvm.ninja); llvm_assertions = llvm.assertions; llvm_tests = llvm.tests; + llvm_enzyme = llvm.enzyme; llvm_plugins = llvm.plugins; set(&mut config.llvm_optimize, llvm.optimize); set(&mut config.llvm_thin_lto, llvm.thin_lto); @@ -1565,6 +1570,7 @@ impl Config { check_ci_llvm!(llvm.polly); check_ci_llvm!(llvm.clang); check_ci_llvm!(llvm.build_config); + check_ci_llvm!(llvm.enzyme); check_ci_llvm!(llvm.plugins); } @@ -1658,6 +1664,7 @@ impl Config { config.llvm_assertions = llvm_assertions.unwrap_or(false); config.llvm_tests = llvm_tests.unwrap_or(false); + config.llvm_enzyme = llvm_enzyme.unwrap_or(false); config.llvm_plugins = llvm_plugins.unwrap_or(false); config.rust_optimize = optimize.unwrap_or(RustOptimize::Bool(true)); diff --git a/src/bootstrap/src/lib.rs 
b/src/bootstrap/src/lib.rs index d7f49a6d11b9c..ac74c632a1b1f 100644 --- a/src/bootstrap/src/lib.rs +++ b/src/bootstrap/src/lib.rs @@ -806,6 +806,10 @@ impl Build { self.out.join(&*target.triple).join("lld") } + fn enzyme_out(&self, target: TargetSelection) -> PathBuf { + self.out.join(&*target.triple).join("enzyme") + } + /// Output directory for all documentation for a target fn doc_out(&self, target: TargetSelection) -> PathBuf { self.out.join(&*target.triple).join("doc") diff --git a/src/test/ui/terminal-width/flag-human.rs b/src/test/ui/terminal-width/flag-human.rs new file mode 100644 index 0000000000000..4b94ebb01fc8e --- /dev/null +++ b/src/test/ui/terminal-width/flag-human.rs @@ -0,0 +1,9 @@ +// compile-flags: --diagnostic-width=20 + +// This test checks that `-Z diagnostic-width` affects the human error output by restricting it to an +// arbitrarily low value so that the effect is visible. + +fn main() { + let _: () = 42; + //~^ ERROR mismatched types +} diff --git a/src/test/ui/terminal-width/flag-json.rs b/src/test/ui/terminal-width/flag-json.rs new file mode 100644 index 0000000000000..3add1d7d9301e --- /dev/null +++ b/src/test/ui/terminal-width/flag-json.rs @@ -0,0 +1,9 @@ +// compile-flags: --diagnostic-width=20 --error-format=json + +// This test checks that `-Z diagnostic-width` affects the JSON error output by restricting it to an +// arbitrarily low value so that the effect is visible. + +fn main() { + let _: () = 42; + //~^ ERROR mismatched types +} diff --git a/src/test/ui/terminal-width/flag-json.stderr b/src/test/ui/terminal-width/flag-json.stderr new file mode 100644 index 0000000000000..b21391d1640ef --- /dev/null +++ b/src/test/ui/terminal-width/flag-json.stderr @@ -0,0 +1,40 @@ +{"message":"mismatched types","code":{"code":"E0308","explanation":"Expected type did not match the received type.
+ +Erroneous code examples: + +```compile_fail,E0308 +fn plus_one(x: i32) -> i32 { + x + 1 +} + +plus_one(\"Not a number\"); +// ^^^^^^^^^^^^^^ expected `i32`, found `&str` + +if \"Not a bool\" { +// ^^^^^^^^^^^^ expected `bool`, found `&str` +} + +let x: f32 = \"Not a float\"; +// --- ^^^^^^^^^^^^^ expected `f32`, found `&str` +// | +// expected due to this +``` + +This error occurs when an expression was used in a place where the compiler +expected an expression of a different type. It can occur in several cases, the +most common being when calling a function and passing an argument which has a +different type than the matching type in the function declaration. +"},"level":"error","spans":[{"file_name":"$DIR/flag-json.rs","byte_start":243,"byte_end":245,"line_start":7,"line_end":7,"column_start":17,"column_end":19,"is_primary":true,"text":[{"text":" let _: () = 42;","highlight_start":17,"highlight_end":19}],"label":"expected `()`, found integer","suggested_replacement":null,"suggestion_applicability":null,"expansion":null},{"file_name":"$DIR/flag-json.rs","byte_start":238,"byte_end":240,"line_start":7,"line_end":7,"column_start":12,"column_end":14,"is_primary":false,"text":[{"text":" let _: () = 42;","highlight_start":12,"highlight_end":14}],"label":"expected due to this","suggested_replacement":null,"suggestion_applicability":null,"expansion":null}],"children":[],"rendered":"error[E0308]: mismatched types + --> $DIR/flag-json.rs:7:17 + | +LL | ..._: () = 42; + | -- ^^ expected `()`, found integer + | | + | expected due to this + +"} +{"message":"aborting due to previous error","code":null,"level":"error","spans":[],"children":[],"rendered":"error: aborting due to previous error + +"} +{"message":"For more information about this error, try `rustc --explain E0308`.","code":null,"level":"failure-note","spans":[],"children":[],"rendered":"For more information about this error, try `rustc --explain E0308`. 
+"} diff --git a/src/tools/enzyme b/src/tools/enzyme new file mode 160000 index 0000000000000..86fc287c5a396 --- /dev/null +++ b/src/tools/enzyme @@ -0,0 +1 @@ +Subproject commit 86fc287c5a39632364af2c48bc3efb5ef1f6652d diff --git a/tests/rustdoc-ui/doctest/terminal-width.rs b/tests/rustdoc-ui/doctest/terminal-width.rs new file mode 100644 index 0000000000000..61961d5ec710e --- /dev/null +++ b/tests/rustdoc-ui/doctest/terminal-width.rs @@ -0,0 +1,5 @@ +// compile-flags: -Zunstable-options --diagnostic-width=10 +#![deny(rustdoc::bare_urls)] + +/// This is a long line that contains a http://link.com +pub struct Foo; //~^ ERROR diff --git a/tests/rustdoc-ui/doctest/terminal-width.stderr b/tests/rustdoc-ui/doctest/terminal-width.stderr new file mode 100644 index 0000000000000..fed049d2b37bc --- /dev/null +++ b/tests/rustdoc-ui/doctest/terminal-width.stderr @@ -0,0 +1,15 @@ +error: this URL is not a hyperlink + --> $DIR/diagnostic-width.rs:4:41 + | +LL | ... a http://link.com + | ^^^^^^^^^^^^^^^ help: use an automatic link instead: `` + | +note: the lint level is defined here + --> $DIR/diagnostic-width.rs:2:9 + | +LL | ...ny(rustdoc::bare_url... 
+ | ^^^^^^^^^^^^^^^^^^ + = note: bare URLs are not automatically turned into clickable links + +error: aborting due to previous error + diff --git a/tests/ui/json/autodiff.rs b/tests/ui/json/autodiff.rs new file mode 100644 index 0000000000000..54f94c3765bf6 --- /dev/null +++ b/tests/ui/json/autodiff.rs @@ -0,0 +1,16 @@ +// Check autodiff attribute +// edition:2018 + +extern "C" fn rosenbrock(a: f32, b: f32, x: f32, y: f32) -> f32 { + let (z, w) = (a-x, y-x*x); + + z*z + b*w*w +} + +#[autodiff(rosenbrock, mode = "forward")] +extern "C" { + fn dx_rosenbrock(a: f32, b: f32, x: f32, y: f32, d_x: &mut f32); + fn dy_rosenbrock(a: f32, b: f32, x: f32, y: f32, d_y: &mut f32); +} + +fn main() {} From 88a7f3e839748da704286c202aae757933daa918 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 30 Oct 2023 02:03:15 -0400 Subject: [PATCH 02/17] remove Cargo.lock and further rebase fixes --- .../rustc_codegen_ssa/src/codegen_attrs.rs | 31 +- .../rustc_monomorphize/src/partitioning.rs | 3 +- library/autodiff/Cargo.lock | 314 ------------------ 3 files changed, 21 insertions(+), 327 deletions(-) delete mode 100644 library/autodiff/Cargo.lock diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index 58019ae43129f..01aad2790b407 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -708,12 +708,13 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { .collect::>(); // check for exactly one autodiff attribute on extern block + let msg_once = "autodiff attribute can only be applied once"; let attr = match &attrs[..] 
{ &[] => return AutoDiffAttrs::inactive(), &[elm] => elm, x => { tcx.sess - .struct_span_err(x[1].span, "autodiff attribute can only be applied once") + .struct_span_err(x[1].span, msg_once) .span_label(x[1].span, "more than one") .emit(); @@ -732,13 +733,14 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { }; } + let msg_ad_mode = "autodiff attribute must contain autodiff mode"; let mode = match &list[0] { NestedMetaItem::MetaItem(MetaItem { path: ref p2, kind: MetaItemKind::Word, .. }) => { p2.segments.first().unwrap().ident } _ => { tcx.sess - .struct_span_err(attr.span, "attribute must contain autodiff mode") + .struct_span_err(attr.span, msg_ad_mode) .span_label(attr.span, "empty argument list") .emit(); @@ -747,13 +749,14 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { }; // parse mode + let msg_mode = "mode should be either forward or reverse"; let mode = match mode.as_str() { //map(|x| x.as_str()) { "Forward" => DiffMode::Forward, "Reverse" => DiffMode::Reverse, _ => { tcx.sess - .struct_span_err(attr.span, "mode should be either forward or reverse") + .struct_span_err(attr.span, msg_mode) .span_label(attr.span, "invalid mode") .emit(); @@ -761,13 +764,14 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { } }; + let msg_ret_activity = "autodiff attribute must contain the return activity"; let ret_symbol = match &list[1] { NestedMetaItem::MetaItem(MetaItem { path: ref p2, kind: MetaItemKind::Word, .. 
}) => { p2.segments.first().unwrap().ident } _ => { tcx.sess - .struct_span_err(attr.span, "autodiff attribute must contain the return activity") + .struct_span_err(attr.span, msg_ret_activity) .span_label(attr.span, "missing return activity") .emit(); @@ -775,11 +779,12 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { } }; + let msg_unknown_ret_activity = "unknown return activity"; let ret_activity = match DiffActivity::from_str(ret_symbol.as_str()) { Ok(x) => x, Err(_) => { tcx.sess - .struct_span_err(attr.span, "unknown return activity") + .struct_span_err(attr.span, msg_unknown_ret_activity) .span_label(attr.span, "invalid return activity") .emit(); @@ -787,6 +792,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { } }; + let msg_arg_activity = "autodiff attribute must contain the return activity"; let mut arg_activities: Vec = vec![]; for arg in &list[2..] { let arg_symbol = match arg { @@ -796,8 +802,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { _ => { tcx.sess .struct_span_err( - attr.span, - "autodiff attribute must contain the return activity", + attr.span, msg_arg_activity, ) .span_label(attr.span, "missing return activity") .emit(); @@ -810,7 +815,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { Ok(arg_activity) => arg_activities.push(arg_activity), Err(_) => { tcx.sess - .struct_span_err(attr.span, "unknown return activity") + .struct_span_err(attr.span, msg_unknown_ret_activity) .span_label(attr.span, "invalid input activity") .emit(); @@ -819,17 +824,20 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { } } + let msg_fwd_incompatible_ret = "Forward Mode is incompatible with Active ret"; + let msg_fwd_incompatible_arg = "Forward Mode is incompatible with Active ret"; + let msg_rev_incompatible_arg = "Reverse Mode is only compatible with Active, None, or Const ret"; if mode == DiffMode::Forward { if ret_activity == DiffActivity::Active { tcx.sess - 
.struct_span_err(attr.span, "Forward Mode is incompatible with Active ret") + .struct_span_err(attr.span, msg_fwd_incompatible_ret) .span_label(attr.span, "invalid return activity") .emit(); return AutoDiffAttrs::inactive(); } if arg_activities.iter().filter(|&x| *x == DiffActivity::Active).count() > 0 { tcx.sess - .struct_span_err(attr.span, "Forward Mode is incompatible with Active args") + .struct_span_err(attr.span, msg_fwd_incompatible_arg) .span_label(attr.span, "invalid input activity") .emit(); return AutoDiffAttrs::inactive(); @@ -842,8 +850,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { { tcx.sess .struct_span_err( - attr.span, - "Reverse Mode is only compatible with Active, None, or Const ret", + attr.span, msg_rev_incompatible_arg, ) .span_label(attr.span, "invalid return activity") .emit(); diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index fa39f35dc334e..55e63f5cf9c9a 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -103,6 +103,7 @@ use rustc_data_structures::sync; use rustc_hir::def::DefKind; use rustc_hir::def_id::{DefId, DefIdSet, LOCAL_CRATE}; use rustc_hir::definitions::DefPathDataName; +use rustc_middle::middle::autodiff_attrs::AutoDiffItem; use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrFlags; use rustc_middle::middle::exported_symbols::{SymbolExportInfo, SymbolExportLevel}; use rustc_middle::mir::mono::{ @@ -1078,7 +1079,7 @@ where } } -fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[CodegenUnit<'_>]) { +fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[AutoDiffItem], &[CodegenUnit<'_>]) { let collection_mode = match tcx.sess.opts.unstable_opts.print_mono_items { Some(ref s) => { let mode = s.to_lowercase(); diff --git a/library/autodiff/Cargo.lock b/library/autodiff/Cargo.lock deleted file mode 100644 index 
b11b872e7dbd9..0000000000000 --- a/library/autodiff/Cargo.lock +++ /dev/null @@ -1,314 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 3 - -[[package]] -name = "autocfg" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" - -[[package]] -name = "autodiff" -version = "0.1.0" -dependencies = [ - "macrotest", - "ndarray", - "proc-macro-error", - "proc-macro2", - "quote", - "syn 1.0.109", - "trybuild", -] - -[[package]] -name = "basic-toml" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24c12265665aebaa236af9bbe266681bcc9c5666192119e3d8335cf083aca26f" -dependencies = [ - "serde", -] - -[[package]] -name = "diff" -version = "0.1.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" - -[[package]] -name = "glob" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" - -[[package]] -name = "itoa" -version = "1.0.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" - -[[package]] -name = "macrotest" -version = "1.0.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7489ae0986ce45414b7b3122c2e316661343ecf396b206e3e15f07c846616f10" -dependencies = [ - "diff", - "glob", - "prettyplease", - "serde", - "serde_json", - "syn 1.0.109", - "toml", -] - -[[package]] -name = "matrixmultiply" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2" -dependencies = [ - "autocfg", - "rawpointer", -] - -[[package]] -name = "ndarray" 
-version = "0.15.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" -dependencies = [ - "matrixmultiply", - "num-complex", - "num-integer", - "num-traits", - "rawpointer", -] - -[[package]] -name = "num-complex" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" -dependencies = [ - "num-traits", -] - -[[package]] -name = "num-integer" -version = "0.1.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" -dependencies = [ - "autocfg", - "num-traits", -] - -[[package]] -name = "num-traits" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" -dependencies = [ - "autocfg", -] - -[[package]] -name = "once_cell" -version = "1.18.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" - -[[package]] -name = "prettyplease" -version = "0.1.25" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8646e95016a7a6c4adea95bafa8a16baab64b583356217f2c85db4a39d9a86" -dependencies = [ - "proc-macro2", - "syn 1.0.109", -] - -[[package]] -name = "proc-macro-error" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" -dependencies = [ - "proc-macro-error-attr", - "proc-macro2", - "quote", - "syn 1.0.109", - "version_check", -] - -[[package]] -name = "proc-macro-error-attr" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" -dependencies = 
[ - "proc-macro2", - "quote", - "version_check", -] - -[[package]] -name = "proc-macro2" -version = "1.0.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "quote" -version = "1.0.33" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "rawpointer" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" - -[[package]] -name = "ryu" -version = "1.0.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" - -[[package]] -name = "serde" -version = "1.0.190" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91d3c334ca1ee894a2c6f6ad698fe8c435b76d504b13d436f0685d648d6d96f7" -dependencies = [ - "serde_derive", -] - -[[package]] -name = "serde_derive" -version = "1.0.190" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67c5609f394e5c2bd7fc51efda478004ea80ef42fee983d5c67a65e34f32c0e3" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.38", -] - -[[package]] -name = "serde_json" -version = "1.0.107" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" -dependencies = [ - "itoa", - "ryu", - "serde", -] - -[[package]] -name = "syn" -version = "1.0.109" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "syn" -version = "2.0.38" -source 
= "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "termcolor" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6093bad37da69aab9d123a8091e4be0aa4a03e4d601ec641c327398315f62b64" -dependencies = [ - "winapi-util", -] - -[[package]] -name = "toml" -version = "0.5.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4f7f0dd8d50a853a531c426359045b1998f04219d88799810762cd4ad314234" -dependencies = [ - "serde", -] - -[[package]] -name = "trybuild" -version = "1.0.85" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "196a58260a906cedb9bf6d8034b6379d0c11f552416960452f267402ceeddff1" -dependencies = [ - "basic-toml", - "glob", - "once_cell", - "serde", - "serde_derive", - "serde_json", - "termcolor", -] - -[[package]] -name = "unicode-ident" -version = "1.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" - -[[package]] -name = "version_check" -version = "0.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" - -[[package]] -name = "winapi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" -dependencies = [ - "winapi-i686-pc-windows-gnu", - "winapi-x86_64-pc-windows-gnu", -] - -[[package]] -name = "winapi-i686-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" - -[[package]] -name = "winapi-util" -version = "0.1.6" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" -dependencies = [ - "winapi", -] - -[[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" From 6eec58db95c621fae06b18c7dbcea603af1e4819 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 30 Oct 2023 02:58:38 -0400 Subject: [PATCH 03/17] continue rebasing --- compiler/rustc_codegen_llvm/src/back/write.rs | 2 +- .../src/coverageinfo/mapgen.rs | 2 +- compiler/rustc_codegen_llvm/src/lib.rs | 3 +- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 1 - .../rustc_monomorphize/src/partitioning.rs | 207 +++++++++++++++++- 5 files changed, 209 insertions(+), 6 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 153b09d867a29..9e9cfede9da12 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -16,7 +16,7 @@ use crate::type_::Type; use crate::typetree::to_enzyme_typetree; use crate::LlvmCodegenBackend; use crate::ModuleLlvm; -use crate::{base, DiffTypeTree}; +use crate::DiffTypeTree; use llvm::{ enzyme_rust_forward_diff, enzyme_rust_reverse_diff, BasicBlock, CreateEnzymeLogic, CreateTypeAnalysis, EnzymeLogicRef, EnzymeTypeAnalysisRef, LLVMAddFunction, diff --git a/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs b/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs index 274e0aeaaba4f..13b5f02ab6e7c 100644 --- a/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs +++ b/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs @@ -409,7 +409,7 @@ fn add_unused_functions(cx: &CodegenCx<'_, '_>) { /// All items participating in code generation together with (instrumented) /// items inlined into them. 
fn codegenned_and_inlined_items(tcx: TyCtxt<'_>) -> DefIdSet { - let (items, cgus) = tcx.collect_and_partition_mono_items(()); + let (items, _, cgus) = tcx.collect_and_partition_mono_items(()); let mut visited = DefIdSet::default(); let mut result = items.clone(); diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index 011a208eb6389..200cc4528883b 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -40,13 +40,12 @@ use rustc_codegen_ssa::back::write::{ use rustc_codegen_ssa::traits::*; use rustc_codegen_ssa::ModuleCodegen; use rustc_codegen_ssa::{CodegenResults, CompiledModule}; -use rustc_data_structures::fx::FxIndexMap; +use rustc_data_structures::fx::{FxHashMap, FxIndexMap}; use rustc_errors::{DiagnosticMessage, ErrorGuaranteed, FatalError, Handler, SubdiagnosticMessage}; use rustc_fluent_macro::fluent_messages; use rustc_metadata::EncodedMetadata; use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; use rustc_middle::middle::autodiff_attrs::AutoDiffItem; -use rustc_middle::ty::query::Providers; use rustc_middle::ty::TyCtxt; use rustc_middle::util::Providers; use rustc_session::config::{OptLevel, OutputFilenames, PrintKind, PrintRequest}; diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index c5514d5bff823..651e89bd52d6a 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1,7 +1,6 @@ #![allow(non_camel_case_types)] #![allow(non_upper_case_globals)] -use rustc_codegen_ssa::coverageinfo::map as coverage_map; use rustc_middle::middle::autodiff_attrs::DiffActivity; use super::debuginfo::{ diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index 55e63f5cf9c9a..1b3c99f8af5e3 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -92,6 +92,11 
@@ //! source-level module, functions from the same module will be available for //! inlining, even when they are not marked `#[inline]`. +// Manuel, fixing rebase +use rustc_symbol_mangling::symbol_name_for_instance_in_crate; +//use crate::ty::ParamEnv; +use rustc_middle::ty::ParamEnv; + use std::cmp; use std::collections::hash_map::Entry; use std::fs::{self, File}; @@ -1138,6 +1143,50 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Au }) .collect(); + + let autodiff_items = items + .iter() + .filter_map(|item| match *item { + MonoItem::Fn(ref instance) => Some((item, instance)), + _ => None, + }) + .filter_map(|(item, instance)| { + let target_id = instance.def_id(); + let target_attrs = tcx.autodiff_attrs(target_id); + if !target_attrs.apply_autodiff() { + return None; + } + + let target_symbol = + symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE); + let range = usage_map.index.get(&item).unwrap(); + + let source = usage_map.targets[range.clone()] + .into_iter() + .filter_map(|item| match *item { + MonoItem::Fn(ref instance_s) => { + let source_id = instance_s.def_id(); + + if tcx.autodiff_attrs(source_id).is_active() { + return Some(instance_s); + } + + None + } + _ => None, + }) + .next(); + + source.map(|inst| { + let (inputs, output) = fnc_typetrees(inst.ty(tcx, ParamEnv::empty()), tcx); + let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE); + + target_attrs.clone().into_item(symb, target_symbol, inputs, output) + }) + }); + + let autodiff_items = tcx.arena.alloc_from_iter(autodiff_items); + // Output monomorphization stats per def_id if let SwitchWithOptPath::Enabled(ref path) = tcx.sess.opts.unstable_opts.dump_mono_stats { if let Err(err) = @@ -1198,7 +1247,163 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Au } } - (tcx.arena.alloc(mono_items), codegen_units) + (tcx.arena.alloc(mono_items), autodiff_items, codegen_units) +} +use 
rustc_middle::ty::{self, Adt, ParamEnvAnd, Ty}; +use rustc_target::abi::FieldsShape; +use std::iter; + +pub fn typetree_empty() -> TypeTree { + TypeTree(vec![]) +} + +pub fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize) -> TypeTree { + if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() { + if ty.is_fn_ptr() { + unimplemented!("what to do whith fn ptr?"); + } + + let inner_ty = ty.builtin_deref(true).unwrap().ty; + let child = typetree_from_ty(inner_ty, tcx, depth + 1); + + let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child }; + //println!("{:depth$} add indirection {:?}", "", tt); + + return TypeTree(vec![tt]); + } + + if ty.is_scalar() { + assert!(!ty.is_any_ptr()); + + let (kind, size) = if ty.is_integral() { + (Kind::Integer, 8) + } else { + assert!(ty.is_floating_point()); + match ty { + x if x == tcx.types.f32 => (Kind::Float, 4), + x if x == tcx.types.f64 => (Kind::Double, 8), + _ => panic!("floatTy scalar that is neither f32 nor f64"), + } + }; + + return TypeTree(vec![Type { offset: -1, child: typetree_empty(), kind, size }]); + } + + let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: ty }; + + let layout = tcx.layout_of(param_env_and); + assert!(layout.is_ok()); + + let layout = layout.unwrap().layout; + let fields = layout.fields(); + let max_size = layout.size(); + + if ty.is_adt() { + let adt_def = ty.ty_adt_def().unwrap(); + let substs = match ty.kind() { + Adt(_, subst_ref) => subst_ref, + _ => panic!(""), + }; + + if adt_def.is_struct() { + let (offsets, _memory_index) = match fields { + FieldsShape::Arbitrary { offsets: o, memory_index: m } => (o, m), + _ => panic!(""), + }; + //println!("{:depth$} combine fields", ""); + + let fields = adt_def.all_fields(); + let fields = fields + .into_iter() + .zip(offsets.into_iter()) + .filter_map(|(field, offset)| { + let field_ty: Ty<'_> = field.ty(tcx, substs); + let field_ty: Ty<'_> = + tcx.normalize_erasing_regions(ParamEnv::empty(), field_ty); + + if 
field_ty.is_phantom_data() { + return None; + } + + let mut child = typetree_from_ty(field_ty, tcx, depth + 1).0; + + for c in &mut child { + if c.offset == -1 { + c.offset = offset.bytes() as isize + } else { + c.offset += offset.bytes() as isize; + } + } + + //inner_tt.offset = offset; + + //println!("{:depth$} -> {:?}", "", child); + + Some(child) + }) + .flatten() + .collect::>(); + + let ret_tt = TypeTree(fields); + //println!("{:depth$} into {:?}", "", ret_tt); + return ret_tt; + } else { + unimplemented!("adt that isn't a struct"); + } + } + + if ty.is_array() { + let (stride, count) = match fields { + FieldsShape::Array { stride: s, count: c } => (s, c), + _ => panic!(""), + }; + let byte_stride = stride.bytes_usize(); + let byte_max_size = max_size.bytes_usize(); + + assert!(byte_stride * *count as usize == byte_max_size); + assert!(*count > 0); // return empty TT for empty? + let sub_ty = ty.builtin_index().unwrap(); + let subtt = typetree_from_ty(sub_ty, tcx, depth + 1); + + // calculate size of subtree + let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: sub_ty }; + let size = tcx.layout_of(param_env_and).unwrap().size.bytes() as usize; + let tt = TypeTree( + iter::repeat(subtt) + .take(*count as usize) + .enumerate() + .map(|(idx, x)| x.0.into_iter().map(move |x| x.add_offset((idx * size) as isize))) + .flatten() + .collect(), + ); + + //println!("{:depth$} repeated array into {:?}", "", tt); + + return tt; + } + + if ty.is_slice() { + let sub_ty = ty.builtin_index().unwrap(); + let subtt = typetree_from_ty(sub_ty, tcx, depth + 1); + + return subtt; + } + + //println!("Warning: create empty typetree for {}", ty); + typetree_empty() +} + +pub fn fnc_typetrees<'tcx>(fn_ty: Ty<'tcx>, tcx: TyCtxt<'tcx>) -> (Vec, TypeTree) { + let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx); + + // TODO: verify. 
+ let x: ty::FnSig<'_> = fnc_binder.skip_binder(); + + let inputs = x.inputs().into_iter().map(|x| typetree_from_ty(*x, tcx, 0)).collect(); + + let output = typetree_from_ty(x.output(), tcx, 0); + + (inputs, output) } /// Outputs stats about instantiation counts and estimated size, per `MonoItem`'s From a22d21de767cea8ec6a811944fed5c489508fd19 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Thu, 2 Nov 2023 16:16:10 -0400 Subject: [PATCH 04/17] cleanup use statements --- compiler/rustc_monomorphize/src/partitioning.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index 1b3c99f8af5e3..6b9251ce63e4a 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -94,8 +94,8 @@ // Manuel, fixing rebase use rustc_symbol_mangling::symbol_name_for_instance_in_crate; -//use crate::ty::ParamEnv; -use rustc_middle::ty::ParamEnv; +use rustc_middle::middle::typetree::{Kind, Type, TypeTree}; +use rustc_target::abi::FieldsShape; use std::cmp; use std::collections::hash_map::Entry; @@ -117,7 +117,7 @@ use rustc_middle::mir::mono::{ }; use rustc_middle::query::Providers; use rustc_middle::ty::print::{characteristic_def_id_of_type, with_no_trimmed_paths}; -use rustc_middle::ty::{self, visit::TypeVisitableExt, InstanceDef, TyCtxt}; +use rustc_middle::ty::{self, visit::TypeVisitableExt, InstanceDef, TyCtxt, ParamEnv, ParamEnvAnd, Adt, Ty}; use rustc_session::config::{DumpMonoStatsFormat, SwitchWithOptPath}; use rustc_session::CodegenUnits; use rustc_span::symbol::Symbol; @@ -1249,9 +1249,9 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Au (tcx.arena.alloc(mono_items), autodiff_items, codegen_units) } -use rustc_middle::ty::{self, Adt, ParamEnvAnd, Ty}; -use rustc_target::abi::FieldsShape; -use std::iter; +//use rustc_middle::ty::{self, Adt, ParamEnvAnd, Ty}; 
+//use rustc_target::abi::FieldsShape; +//use std::iter; pub fn typetree_empty() -> TypeTree { TypeTree(vec![]) @@ -1369,7 +1369,7 @@ pub fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize) -> TypeTr let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: sub_ty }; let size = tcx.layout_of(param_env_and).unwrap().size.bytes() as usize; let tt = TypeTree( - iter::repeat(subtt) + std::iter::repeat(subtt) .take(*count as usize) .enumerate() .map(|(idx, x)| x.0.into_iter().map(move |x| x.add_offset((idx * size) as isize))) From b66413ae134b901db28ca80dd3e1ed4e221cf403 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Thu, 2 Nov 2023 20:18:19 -0400 Subject: [PATCH 05/17] fixing rebase issues, updating enzyme submodule --- .../rustc_monomorphize/src/partitioning.rs | 5 +++-- src/bootstrap/src/core/build_steps/compile.rs | 6 +++--- src/bootstrap/src/core/config/config.rs | 19 +++++++++---------- src/tools/enzyme | 2 +- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index 6b9251ce63e4a..606fbd9d6b145 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -1159,9 +1159,10 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Au let target_symbol = symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE); - let range = usage_map.index.get(&item).unwrap(); + //let range = usage_map.used_map.get(&item).unwrap(); + //TODO: check if last and next line are correct after rebasing - let source = usage_map.targets[range.clone()] + let source = usage_map.get_user_items(*item) .into_iter() .filter_map(|item| match *item { MonoItem::Fn(ref instance_s) => { diff --git a/src/bootstrap/src/core/build_steps/compile.rs b/src/bootstrap/src/core/build_steps/compile.rs index 7a53c4caffe6d..775b86628be79 100644 --- 
a/src/bootstrap/src/core/build_steps/compile.rs +++ b/src/bootstrap/src/core/build_steps/compile.rs @@ -1608,12 +1608,12 @@ impl Step for Assemble { }; if let Some(enzyme_install) = enzyme_install { - let src_lib = enzyme_install.join("build/Enzyme/LLVMEnzyme-16.so"); + let src_lib = enzyme_install.join("build/Enzyme/LLVMEnzyme-17.so"); let libdir = builder.sysroot_libdir(build_compiler, build_compiler.host); let target_libdir = builder.sysroot_libdir(target_compiler, target_compiler.host); - let dst_lib = libdir.join("libLLVMEnzyme-16.so"); - let target_dst_lib = target_libdir.join("libLLVMEnzyme-16.so"); + let dst_lib = libdir.join("libLLVMEnzyme-17.so"); + let target_dst_lib = target_libdir.join("libLLVMEnzyme-17.so"); builder.copy(&src_lib, &dst_lib); builder.copy(&src_lib, &target_dst_lib); } diff --git a/src/bootstrap/src/core/config/config.rs b/src/bootstrap/src/core/config/config.rs index a49835b93b662..7034be88aed48 100644 --- a/src/bootstrap/src/core/config/config.rs +++ b/src/bootstrap/src/core/config/config.rs @@ -174,7 +174,6 @@ pub struct Config { pub llvm_assertions: bool, pub llvm_tests: bool, pub llvm_enzyme: bool, - pub llvm_enzyme_build: Option, pub llvm_plugins: bool, pub llvm_optimize: bool, pub llvm_thin_lto: bool, @@ -678,24 +677,24 @@ macro_rules! 
define_config { A: serde::de::MapAccess<'de>, { $(let mut $field: Option<$field_ty> = None;)* - while let Some(key) = - match serde::de::MapAccess::next_key::(&mut map) { - Ok(val) => val, - Err(err) => { - return Err(err); - } + while let Some(key) = + match serde::de::MapAccess::next_key::(&mut map) { + Ok(val) => val, + Err(err) => { + return Err(err); } + } { match &*key { $($field_key => { if $field.is_some() { return Err(::duplicate_field( - $field_key, - )); + $field_key, + )); } $field = match serde::de::MapAccess::next_value::<$field_ty>( &mut map, - ) { + ) { Ok(val) => Some(val), Err(err) => { return Err(err); diff --git a/src/tools/enzyme b/src/tools/enzyme index 86fc287c5a396..01c279c6b4674 160000 --- a/src/tools/enzyme +++ b/src/tools/enzyme @@ -1 +1 @@ -Subproject commit 86fc287c5a39632364af2c48bc3efb5ef1f6652d +Subproject commit 01c279c6b46746172182dc2bf18466b95e67e199 From fd83e749819fb9350d77ee17ee28280457841960 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Thu, 2 Nov 2023 20:41:01 -0400 Subject: [PATCH 06/17] TODO: make buildname generic --- src/bootstrap/src/core/builder.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bootstrap/src/core/builder.rs b/src/bootstrap/src/core/builder.rs index 65c94281e00d4..b2fa044cd0da3 100644 --- a/src/bootstrap/src/core/builder.rs +++ b/src/bootstrap/src/core/builder.rs @@ -1374,7 +1374,7 @@ impl<'a> Builder<'a> { // TODO: adjust -14 ending for Enzyme // https://rust-lang.zulipchat.com/#narrow/stream/182449-t-compiler.2Fhelp/topic/.E2.9C.94.20link.20new.20library.20into.20stage1.2Frustc rustflags.arg("-l"); - rustflags.arg("LLVMEnzyme-16"); + rustflags.arg("LLVMEnzyme-17"); let use_new_symbol_mangling = match self.config.rust_new_symbol_mangling { Some(setting) => { From 7c01c3d9cbf2e83a3d4282b7323ec9f4d3fe1e50 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Thu, 2 Nov 2023 22:58:54 -0400 Subject: [PATCH 07/17] cleanups --- compiler/rustc_ast/src/mut_visit.rs | 2 +- 
compiler/rustc_codegen_llvm/src/back/lto.rs | 1 - compiler/rustc_codegen_llvm/src/back/write.rs | 25 +------------------ compiler/rustc_codegen_ssa/src/back/lto.rs | 1 - compiler/rustc_codegen_ssa/src/back/write.rs | 12 ++++----- compiler/rustc_interface/src/tests.rs | 1 - .../rustc_monomorphize/src/partitioning.rs | 21 +--------------- compiler/rustc_session/src/options.rs | 2 -- 8 files changed, 8 insertions(+), 57 deletions(-) diff --git a/compiler/rustc_ast/src/mut_visit.rs b/compiler/rustc_ast/src/mut_visit.rs index 23e7975edd65b..0634ee970ec5e 100644 --- a/compiler/rustc_ast/src/mut_visit.rs +++ b/compiler/rustc_ast/src/mut_visit.rs @@ -381,7 +381,7 @@ pub fn visit_bounds(bounds: &mut GenericBounds, vis: &mut T) { } // No `noop_` prefix because there isn't a corresponding method in `MutVisitor`. -pub fn visit_fn_sig(FnSig { header, decl, span, .. }: &mut FnSig, vis: &mut T) { +pub fn visit_fn_sig(FnSig { header, decl, span }: &mut FnSig, vis: &mut T) { vis.visit_fn_header(header); vis.visit_fn_decl(decl); vis.visit_span(span); diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index c63870dfe4327..a1b546b1922f2 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -273,7 +273,6 @@ fn fat_lto( info!("pushing cached module {:?}", wp.cgu_name); (buffer, CString::new(wp.cgu_name).unwrap()) })); - for module in modules { match module { FatLtoInput::InMemory(m) => in_memory.push(m), diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 9e9cfede9da12..f54c8eeb156af 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -628,15 +628,13 @@ fn get_params(fnc: &Value) -> Vec<&Value> { } } -// TODO: cleanup +// TODO: Here we could start adding length checks for the shaddow args. 
unsafe fn create_wrapper<'a>( llmod: &'a llvm::Module, - //module: &'a ModuleCodegen, fnc: &'a Value, u_type: &Type, fnc_name: String, ) -> (&'a Value, &'a BasicBlock, Vec<&'a Value>, Vec<&'a Value>, CString) { - //let llmod = module.module_llvm.llmod(); let context = LLVMGetModuleContext(llmod); let inner_fnc_name = "inner_".to_string() + &fnc_name; let c_inner_fnc_name = CString::new(inner_fnc_name.clone()).unwrap(); @@ -656,22 +654,13 @@ unsafe fn create_wrapper<'a>( (outer_fnc, basic_block, outer_params, inner_params, c_inner_fnc_name) } -//pub(crate) fn get_type(t: LLVMTypeRef) -> CString { -// unsafe { CString::from_raw(LLVMPrintTypeToString(t)) } -//} - -// TODO: Don't write a wrapper function, just unwrap the struct inside of the same fnc. -// Might help during debugging, if you have one function less to jump trough pub(crate) unsafe fn extract_return_type<'a>( llmod: &'a llvm::Module, fnc: &'a Value, u_type: &Type, fnc_name: String, ) -> &'a Value { - //let llmod = module.module_llvm.llmod(); let context = llvm::LLVMGetModuleContext(llmod); - //dbg!("Unpacking", fnc_name.clone()); - //dbg!("From: ", f_type, " into ", u_type); let inner_param_num = LLVMCountParams(fnc); let (outer_fnc, outer_bb, mut outer_args, _inner_args, c_inner_fnc_name) = @@ -697,17 +686,9 @@ pub(crate) unsafe fn extract_return_type<'a>( let struct_ret = LLVMBuildExtractValue(builder, struct_ret, 0, c_inner_grad_name.as_ptr()); let _ret = LLVMBuildRet(builder, struct_ret); let _terminator = LLVMGetBasicBlockTerminator(outer_bb); - //assert!(LLVMIsNull(terminator)!=0, "no terminator"); LLVMDisposeBuilder(builder); - let _fnc_ok = LLVMVerifyFunction(outer_fnc, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction); - //dbg!(outer_fnc); - //assert!(fnc_ok); - //if let Err(e) = verify_function(outer_fnc) { - // panic!("Creating a wrapper function failed! 
{}", e); - //} - outer_fnc } @@ -792,16 +773,12 @@ pub(crate) unsafe fn enzyme_ad( let void_type = LLVMVoidTypeInContext(llcx); if item.attrs.mode == DiffMode::Reverse && f_return_type != void_type { - //dbg!("Reverse Mode sanitizer"); - //dbg!(f_type); - //dbg!(f_return_type); let num_elem_in_ret_struct = LLVMCountStructElementTypes(f_return_type); if num_elem_in_ret_struct == 1 { let u_type = LLVMTypeOf(target_fnc); res = extract_return_type(llmod, res, u_type, rust_name2.clone()); // TODO: check if name or name2 } } - //dbg!(&target_fnc); LLVMSetValueName2(res, name2.as_ptr(), rust_name2.len()); LLVMReplaceAllUsesWith(target_fnc, res); LLVMDeleteFunction(target_fnc); diff --git a/compiler/rustc_codegen_ssa/src/back/lto.rs b/compiler/rustc_codegen_ssa/src/back/lto.rs index f27b09c8146f3..d7bcc92abc667 100644 --- a/compiler/rustc_codegen_ssa/src/back/lto.rs +++ b/compiler/rustc_codegen_ssa/src/back/lto.rs @@ -88,7 +88,6 @@ impl LtoModuleCodegen { ) -> Result, FatalError> { match &self { LtoModuleCodegen::Fat { ref module, .. 
} => { - //let module = module.take().unwrap(); { B::autodiff(cgcx, &module, diff_fncs, typetrees, config)?; } diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index 0f77938999d9e..ad4bad943e524 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -118,7 +118,6 @@ pub struct ModuleConfig { pub inline_threshold: Option, pub emit_lifetime_markers: bool, pub llvm_plugins: Vec, - pub enzyme_print_activity: bool, } impl ModuleConfig { @@ -196,7 +195,6 @@ impl ModuleConfig { false ), - enzyme_print_activity: sess.opts.unstable_opts.enzyme_print_activity, sanitizer: if_regular!(sess.opts.unstable_opts.sanitizer, SanitizerSet::empty()), sanitizer_recover: if_regular!( sess.opts.unstable_opts.sanitizer_recover, @@ -398,19 +396,19 @@ fn generate_lto_work( if !needs_fat_lto.is_empty() { assert!(needs_thin_lto.is_empty()); - let mut lto_module = + let mut module = B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise()); if cgcx.lto == Lto::Fat { let config = cgcx.config(ModuleKind::Regular); - lto_module = unsafe { lto_module.autodiff(cgcx, autodiff, typetrees, config).unwrap() }; + module = unsafe { module.autodiff(cgcx, autodiff, typetrees, config).unwrap() }; } // We are adding a single work item, so the cost doesn't matter. 
- vec![(WorkItem::LTO(lto_module), 0)] + vec![(WorkItem::LTO(module), 0)] } else { assert!(needs_fat_lto.is_empty()); - let (lto_modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules) + let (modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules) .unwrap_or_else(|e| e.raise()); - lto_modules + modules .into_iter() .map(|module| { let cost = module.cost(); diff --git a/compiler/rustc_interface/src/tests.rs b/compiler/rustc_interface/src/tests.rs index 4439550d8d037..57ca709267a7e 100644 --- a/compiler/rustc_interface/src/tests.rs +++ b/compiler/rustc_interface/src/tests.rs @@ -767,7 +767,6 @@ fn test_unstable_options_tracking_hash() { tracked!(debug_macros, true); tracked!(dep_info_omit_d_target, true); tracked!(dual_proc_macros, true); - tracked!(enzyme_print_activity, false); tracked!(dwarf_version, Some(5)); tracked!(emit_thin_lto, false); tracked!(export_executable_symbols, true); diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index 606fbd9d6b145..811160c3a72e2 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -92,7 +92,6 @@ //! source-level module, functions from the same module will be available for //! inlining, even when they are not marked `#[inline]`. 
-// Manuel, fixing rebase use rustc_symbol_mangling::symbol_name_for_instance_in_crate; use rustc_middle::middle::typetree::{Kind, Type, TypeTree}; use rustc_target::abi::FieldsShape; @@ -1143,7 +1142,6 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Au }) .collect(); - let autodiff_items = items .iter() .filter_map(|item| match *item { @@ -1156,6 +1154,7 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Au if !target_attrs.apply_autodiff() { return None; } + println!("target_id: {:?}", target_id); let target_symbol = symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE); @@ -1250,9 +1249,6 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Au (tcx.arena.alloc(mono_items), autodiff_items, codegen_units) } -//use rustc_middle::ty::{self, Adt, ParamEnvAnd, Ty}; -//use rustc_target::abi::FieldsShape; -//use std::iter; pub fn typetree_empty() -> TypeTree { TypeTree(vec![]) @@ -1263,19 +1259,14 @@ pub fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize) -> TypeTr if ty.is_fn_ptr() { unimplemented!("what to do whith fn ptr?"); } - let inner_ty = ty.builtin_deref(true).unwrap().ty; let child = typetree_from_ty(inner_ty, tcx, depth + 1); - let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child }; - //println!("{:depth$} add indirection {:?}", "", tt); - return TypeTree(vec![tt]); } if ty.is_scalar() { assert!(!ty.is_any_ptr()); - let (kind, size) = if ty.is_integral() { (Kind::Integer, 8) } else { @@ -1286,7 +1277,6 @@ pub fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize) -> TypeTr _ => panic!("floatTy scalar that is neither f32 nor f64"), } }; - return TypeTree(vec![Type { offset: -1, child: typetree_empty(), kind, size }]); } @@ -1311,7 +1301,6 @@ pub fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize) -> TypeTr FieldsShape::Arbitrary { offsets: o, memory_index: m } => (o, m), _ => panic!(""), }; - 
//println!("{:depth$} combine fields", ""); let fields = adt_def.all_fields(); let fields = fields @@ -1336,17 +1325,12 @@ pub fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize) -> TypeTr } } - //inner_tt.offset = offset; - - //println!("{:depth$} -> {:?}", "", child); - Some(child) }) .flatten() .collect::>(); let ret_tt = TypeTree(fields); - //println!("{:depth$} into {:?}", "", ret_tt); return ret_tt; } else { unimplemented!("adt that isn't a struct"); @@ -1378,8 +1362,6 @@ pub fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize) -> TypeTr .collect(), ); - //println!("{:depth$} repeated array into {:?}", "", tt); - return tt; } @@ -1390,7 +1372,6 @@ pub fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize) -> TypeTr return subtt; } - //println!("Warning: create empty typetree for {}", ty); typetree_empty() } diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs index e6401d2fbfbab..30c8b9d67002c 100644 --- a/compiler/rustc_session/src/options.rs +++ b/compiler/rustc_session/src/options.rs @@ -1537,8 +1537,6 @@ options! 
{ "enables LTO for dylib crate type"), emit_stack_sizes: bool = (false, parse_bool, [UNTRACKED], "emit a section containing stack size metadata (default: no)"), - enzyme_print_activity: bool = (false, parse_bool, [TRACKED], - "print type trees for functions passed to enzyme"), emit_thin_lto: bool = (true, parse_bool, [TRACKED], "emit the bc module with thin LTO info (default: yes)"), export_executable_symbols: bool = (false, parse_bool, [TRACKED], From 7e6d48c164e1c0ce8c9dccd8a2921d0ed1733dc9 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 3 Nov 2023 00:33:11 -0400 Subject: [PATCH 08/17] fix monomorphization issue --- compiler/rustc_monomorphize/src/collector.rs | 2 +- .../rustc_monomorphize/src/partitioning.rs | 21 +++++++++++++++++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index baac8d98e8b1b..db3d3fc88ac05 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ b/compiler/rustc_monomorphize/src/collector.rs @@ -204,7 +204,7 @@ pub enum MonoItemCollectionMode { pub struct UsageMap<'tcx> { // Maps every mono item to the mono items used by it. - used_map: FxHashMap, Vec>>, + pub used_map: FxHashMap, Vec>>, // Maps every mono item to the mono items that use it. 
user_map: FxHashMap, Vec>>, diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index 811160c3a72e2..b8e631f54586b 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -1154,30 +1154,47 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Au if !target_attrs.apply_autodiff() { return None; } - println!("target_id: {:?}", target_id); + //println!("target_id: {:?}", target_id); let target_symbol = symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE); //let range = usage_map.used_map.get(&item).unwrap(); //TODO: check if last and next line are correct after rebasing - let source = usage_map.get_user_items(*item) + println!("target_symbol: {:?}", target_symbol); + println!("target_attrs: {:?}", target_attrs); + println!("target_id: {:?}", target_id); + //print item + println!("item: {:?}", item); + let source = usage_map.used_map.get(&item).unwrap() .into_iter() .filter_map(|item| match *item { MonoItem::Fn(ref instance_s) => { let source_id = instance_s.def_id(); + println!("source_id_inner: {:?}", source_id); + println!("instance_s: {:?}", instance_s); if tcx.autodiff_attrs(source_id).is_active() { + println!("source_id is active"); return Some(instance_s); } + //target_symbol: "_ZN14rosenbrock_rev12d_rosenbrock17h3352c4f00c3082daE" + //target_attrs: AutoDiffAttrs { mode: Reverse, ret_activity: Active, input_activity: [Duplicated] } + //target_id: DefId(0:8 ~ rosenbrock_rev[2708]::d_rosenbrock) + //item: Fn(Instance { def: Item(DefId(0:8 ~ rosenbrock_rev[2708]::d_rosenbrock)), args: [] }) + //source_id_inner: DefId(0:4 ~ rosenbrock_rev[2708]::main) + //instance_s: Instance { def: Item(DefId(0:4 ~ rosenbrock_rev[2708]::main)), args: [] } + None } _ => None, }) .next(); + println!("source: {:?}", source); source.map(|inst| { + println!("source_id: {:?}", inst.def_id()); let (inputs, output) = 
fnc_typetrees(inst.ty(tcx, ParamEnv::empty()), tcx); let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE); From de6199598ca13f2335a1544209fdeaefa31d070d Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 3 Nov 2023 01:02:24 -0400 Subject: [PATCH 09/17] use default tt constructor --- compiler/rustc_codegen_llvm/src/base.rs | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/base.rs b/compiler/rustc_codegen_llvm/src/base.rs index 1d9157e6355f4..b659fd02eecf6 100644 --- a/compiler/rustc_codegen_llvm/src/base.rs +++ b/compiler/rustc_codegen_llvm/src/base.rs @@ -25,7 +25,6 @@ use rustc_codegen_ssa::base::maybe_create_entry_wrapper; use rustc_codegen_ssa::mono_item::MonoItemExt; use rustc_codegen_ssa::traits::*; use rustc_codegen_ssa::{ModuleCodegen, ModuleKind}; -use rustc_data_structures::fx::FxHashMap; use rustc_data_structures::small_c_str::SmallCStr; use rustc_middle::dep_graph; use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrs; @@ -83,10 +82,9 @@ pub fn compile_codegen_unit(tcx: TyCtxt<'_>, cgu_name: Symbol) -> (ModuleCodegen recorder.record_arg(cgu.size_estimate().to_string()); }); // Instantiate monomorphizations without filling out definitions yet... 
- let mut llvm_module = ModuleLlvm::new(tcx, cgu_name.as_str()); - let typetrees = { + let llvm_module = ModuleLlvm::new(tcx, cgu_name.as_str()); + { let cx = CodegenCx::new(tcx, cgu, &llvm_module); - let mono_items = cx.codegen_unit.items_in_deterministic_order(cx.tcx); for &(mono_item, data) in &mono_items { mono_item.predefine::>(&cx, data.linkage, data.visibility); @@ -134,11 +132,7 @@ pub fn compile_codegen_unit(tcx: TyCtxt<'_>, cgu_name: Symbol) -> (ModuleCodegen if cx.sess().opts.debuginfo != DebugInfo::None { cx.debuginfo_finalize(); } - - FxHashMap::default() - }; - - llvm_module.typetrees = typetrees; + } ModuleCodegen { name: cgu_name.to_string(), From 84968bf6f7d5b0bdc8609847e198a1df1d1ce425 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 3 Nov 2023 21:36:34 -0400 Subject: [PATCH 10/17] fix monomorphization, extra dbg output --- compiler/rustc_codegen_llvm/src/back/write.rs | 3 +++ compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 13 +++++++++++-- compiler/rustc_monomorphize/src/partitioning.rs | 9 ++++++++- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index f54c8eeb156af..7d32c5f339d36 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -712,6 +712,9 @@ pub(crate) unsafe fn enzyme_ad( let name2 = CString::new(rust_name2.clone()).unwrap(); let src_fnc = llvm::LLVMGetNamedFunction(llmod, name.as_c_str().as_ptr()).unwrap(); let target_fnc = llvm::LLVMGetNamedFunction(llmod, name2.as_ptr()).unwrap(); + let src_num_args = llvm::LLVMCountParams(src_fnc); + let target_num_args = llvm::LLVMCountParams(target_fnc); + assert!(src_num_args <= target_num_args); // create enzyme typetrees let llvm_data_layout = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 
651e89bd52d6a..0d9cf67cf9ac6 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -914,6 +914,8 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( assert!(ret_activity == CDIFFE_TYPE::DFT_CONSTANT || ret_activity == CDIFFE_TYPE::DFT_OUT_DIFF); let input_activity: Vec = input_activity.iter().map(|&x| cdiffe_from(x)).collect(); + dbg!(&fnc); + if ret_activity == CDIFFE_TYPE::DFT_DUP_ARG { if ret_primary_ret != true { dbg!("overwriting ret_primary_ret!"); @@ -931,6 +933,11 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( // We don't support volatile / extern / (global?) values. // Just because I didn't had time to test them, and it seems less urgent. let args_uncacheable = vec![0; input_tts.len()]; + assert!(args_uncacheable.len() == input_activity.len()); + let num_fnc_args = LLVMCountParams(fnc); + println!("num_fnc_args: {}", num_fnc_args); + println!("input_activity.len(): {}", input_activity.len()); + assert!(num_fnc_args == input_activity.len() as u32); let kv_tmp = IntList { data: std::ptr::null_mut(), size: 0 }; @@ -942,7 +949,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( KnownValues: known_values.as_mut_ptr(), }; - EnzymeCreatePrimalAndGradient( + let res = EnzymeCreatePrimalAndGradient( logic_ref, // Logic std::ptr::null(), std::ptr::null(), @@ -963,7 +970,9 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( args_uncacheable.len(), // uncacheable arguments std::ptr::null_mut(), // write augmented function to this 0, - ) + ); + dbg!(&res); + res } pub type GetSymbolsCallback = unsafe extern "C" fn(*mut c_void, *const c_char) -> *mut c_void; pub type GetSymbolsErrorCallback = unsafe extern "C" fn(*const c_char) -> *mut c_void; diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index b8e631f54586b..403e16ff49a60 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -247,7 
+247,14 @@ where &mut can_be_internalized, export_generics, ); - if visibility == Visibility::Hidden && can_be_internalized { + //if visibility == Visibility::Hidden && can_be_internalized { + + //dbg!(&characteristic_def_id); + let autodiff_active = characteristic_def_id + .map(|x| cx.tcx.autodiff_attrs(x).is_active()) + .unwrap_or(false); + + if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized { internalization_candidates.insert(mono_item); } let size_estimate = mono_item.size_estimate(cx.tcx); From d1c94affffd58219316d620ba14521ddd1286f59 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 5 Nov 2023 21:24:21 -0500 Subject: [PATCH 11/17] add and use better abstractions for Attribute --- compiler/rustc_codegen_llvm/src/back/write.rs | 15 +++++++-------- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 17 +++++++++++------ .../rustc_llvm/llvm-wrapper/RustWrapper.cpp | 13 +++++++++++++ 3 files changed, 31 insertions(+), 14 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 7d32c5f339d36..67addebbdf460 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -17,7 +17,7 @@ use crate::typetree::to_enzyme_typetree; use crate::LlvmCodegenBackend; use crate::ModuleLlvm; use crate::DiffTypeTree; -use llvm::{ +use llvm::{LLVMRustGetEnumAttributeAtIndex, LLVMRustAddEnumAttributeAtIndex, LLVMRustRemoveEnumAttributeAtIndex, enzyme_rust_forward_diff, enzyme_rust_reverse_diff, BasicBlock, CreateEnzymeLogic, CreateTypeAnalysis, EnzymeLogicRef, EnzymeTypeAnalysisRef, LLVMAddFunction, LLVMAppendBasicBlockInContext, LLVMBuildCall2, LLVMBuildExtractValue, LLVMBuildRet, @@ -25,9 +25,9 @@ use llvm::{ LLVMDisposeBuilder, LLVMGetBasicBlockTerminator, LLVMGetElementType, LLVMGetModuleContext, LLVMGetParams, LLVMGetReturnType, LLVMPositionBuilderAtEnd, LLVMSetValueName2, LLVMTypeOf, LLVMVoidTypeInContext, LLVMGlobalGetValueType, 
LLVMGetStringAttributeAtIndex, - LLVMIsStringAttribute, LLVMRemoveStringAttributeAtIndex, LLVMRemoveEnumAttributeAtIndex, AttributeKind, - LLVMGetFirstFunction, LLVMGetNextFunction, LLVMGetEnumAttributeAtIndex, LLVMIsEnumAttribute, - LLVMCreateStringAttribute, LLVMRustAddFunctionAttributes, LLVMCreateEnumAttribute, LLVMDumpModule, + LLVMIsStringAttribute, LLVMRemoveStringAttributeAtIndex, AttributeKind, + LLVMGetFirstFunction, LLVMGetNextFunction, LLVMIsEnumAttribute, + LLVMCreateStringAttribute, LLVMRustAddFunctionAttributes, LLVMDumpModule, LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols, }; use rustc_codegen_ssa::back::link::ensure_removed; @@ -831,7 +831,7 @@ pub(crate) unsafe fn differentiate( if LLVMIsStringAttribute(attr) { LLVMRemoveStringAttributeAtIndex(lf, c_uint::MAX, myhwattr.as_ptr() as *const c_char, myhwattr.as_bytes().len() as c_uint); } else { - LLVMRemoveEnumAttributeAtIndex(lf, c_uint::MAX, AttributeKind::SanitizeHWAddress); + LLVMRustRemoveEnumAttributeAtIndex(lf, c_uint::MAX, AttributeKind::SanitizeHWAddress); } @@ -876,13 +876,12 @@ pub(crate) unsafe fn optimize( f = LLVMGetNextFunction(lf); let myhwattr = "enzyme_hw"; let myhwv = ""; - let prevattr = LLVMGetEnumAttributeAtIndex(lf, c_uint::MAX, AttributeKind::SanitizeHWAddress); + let prevattr = LLVMRustGetEnumAttributeAtIndex(lf, c_uint::MAX, AttributeKind::SanitizeHWAddress); if LLVMIsEnumAttribute(prevattr) { let attr = LLVMCreateStringAttribute(llcx, myhwattr.as_ptr() as *const c_char, myhwattr.as_bytes().len() as c_uint, myhwv.as_ptr() as *const c_char, myhwv.as_bytes().len() as c_uint); LLVMRustAddFunctionAttributes(lf, c_uint::MAX, &attr, 1); } else { - let attr = LLVMCreateEnumAttribute(llcx, AttributeKind::SanitizeHWAddress, 0); - LLVMRustAddFunctionAttributes(lf, c_uint::MAX, &attr, 1); + LLVMRustAddEnumAttributeAtIndex(llcx, lf, c_uint::MAX, AttributeKind::SanitizeHWAddress); } } else { diff --git 
a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 0d9cf67cf9ac6..94285ed817ef5 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -191,7 +191,7 @@ pub enum AttributeKind { OptimizeNone = 24, ReturnsTwice = 25, ReadNone = 26, - SanitizeHWAddress = 51, + SanitizeHWAddress = 28, WillReturn = 29, StackProtectReq = 30, StackProtectStrong = 31, @@ -980,19 +980,21 @@ pub type GetSymbolsErrorCallback = unsafe extern "C" fn(*const c_char) -> *mut c extern "C" { // Enzyme - //pub fn LLVMReplaceAllUsesWith(old: &Value, new: &Value); - pub fn GibtsNicht(M: &Module) -> bool; pub fn LLVMIsStructTy(ty: &Type) -> bool; pub fn LLVMGetReturnType(T: &Type) -> &Type; pub fn LLVMDumpModule(M: &Module); pub fn LLVMCountStructElementTypes(T: &Type) -> c_uint; pub fn LLVMDeleteFunction(V: &Value); + + pub fn LLVMCreateEnumAttribute(C : &Context, Kind: Attribute, val:u64) -> &Attribute; pub fn LLVMRemoveStringAttributeAtIndex(F : &Value, Idx: c_uint, K: *const c_char, KLen : c_uint); pub fn LLVMGetStringAttributeAtIndex(F : &Value, Idx: c_uint, K: *const c_char, KLen : c_uint) -> &Attribute; - pub fn LLVMRemoveEnumAttributeAtIndex(F : &Value, Idx: c_uint, K: AttributeKind); - pub fn LLVMGetEnumAttributeAtIndex(F : &Value, Idx: c_uint, K: AttributeKind) -> &Attribute; + + pub fn LLVMAddAttributeAtIndex(F : &Value, Idx: c_uint, K: &Attribute); + pub fn LLVMRemoveEnumAttributeAtIndex(F : &Value, Idx: c_uint, K: Attribute); + pub fn LLVMGetEnumAttributeAtIndex(F : &Value, Idx: c_uint, K: Attribute) -> &Attribute; + pub fn LLVMIsEnumAttribute(A : &Attribute) -> bool; - pub fn LLVMCreateEnumAttribute(C : &Context, Kind: AttributeKind, val:u64) -> &Attribute; pub fn LLVMIsStringAttribute(A : &Attribute) -> bool; pub fn LLVMVerifyFunction(V: &Value, action: LLVMVerifierFailureAction) -> bool; pub fn LLVMGetParams(Fnc: &Value, parms: *mut &Value); @@ -1207,6 +1209,9 @@ extern "C" { // 
Operations on attributes pub fn LLVMRustCreateAttrNoValue(C: &Context, attr: AttributeKind) -> &Attribute; + pub fn LLVMRustAddEnumAttributeAtIndex(C: &Context, V: &Value, index: c_uint, attr: AttributeKind); + pub fn LLVMRustRemoveEnumAttributeAtIndex(V: &Value, index: c_uint, attr: AttributeKind); + pub fn LLVMRustGetEnumAttributeAtIndex(V: &Value, index: c_uint, attr: AttributeKind) ->&Attribute; pub fn LLVMCreateStringAttribute( C: &Context, Name: *const c_char, diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index e7db075aefa2f..f357127794438 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -319,6 +319,19 @@ extern "C" LLVMAttributeRef LLVMRustCreateAttrNoValue(LLVMContextRef C, return wrap(Attribute::get(*unwrap(C), fromRust(RustAttr))); } +extern "C" void LLVMRustRemoveEnumAttributeAtIndex(LLVMValueRef F, size_t index, LLVMRustAttribute RustAttr) { + LLVMRemoveEnumAttributeAtIndex(F, index, fromRust(RustAttr)); +} + +extern "C" void LLVMRustAddEnumAttributeAtIndex(LLVMContextRef C, LLVMValueRef F, size_t index, LLVMRustAttribute RustAttr) { + LLVMAddAttributeAtIndex(F, index, LLVMRustCreateAttrNoValue(C, RustAttr)); +} + +extern "C" LLVMAttributeRef LLVMRustGetEnumAttributeAtIndex(LLVMValueRef F, size_t index, + LLVMRustAttribute RustAttr) { + return LLVMGetEnumAttributeAtIndex(F, index, fromRust(RustAttr)); +} + extern "C" LLVMAttributeRef LLVMRustCreateAlignmentAttr(LLVMContextRef C, uint64_t Bytes) { return wrap(Attribute::getWithAlignment(*unwrap(C), llvm::Align(Bytes))); From dee82b30e4748b148d67f2ad71908ab6e6ef52b0 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 5 Nov 2023 23:35:52 -0500 Subject: [PATCH 12/17] wire up propper Rust error handler --- compiler/rustc_codegen_llvm/messages.ftl | 3 ++ compiler/rustc_codegen_llvm/src/back/write.rs | 30 ++++++++++++++++--- compiler/rustc_codegen_llvm/src/errors.rs | 
7 +++++ 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/compiler/rustc_codegen_llvm/messages.ftl b/compiler/rustc_codegen_llvm/messages.ftl index c0cfe39f1e0a4..efc4abc719ade 100644 --- a/compiler/rustc_codegen_llvm/messages.ftl +++ b/compiler/rustc_codegen_llvm/messages.ftl @@ -60,6 +60,9 @@ codegen_llvm_prepare_thin_lto_module_with_llvm_err = failed to prepare thin LTO codegen_llvm_run_passes = failed to run LLVM passes codegen_llvm_run_passes_with_llvm_err = failed to run LLVM passes: {$llvm_err} +codegen_llvm_prepare_autodiff = failed to prepare AutoDiff: src: {$src}, target: {$target}, {$error} +codegen_llvm_prepare_autodiff_with_llvm_err = failed to prepare AutoDiff: {$llvm_err}, src: {$src}, target: {$target}, {$error} + codegen_llvm_sanitizer_memtag_requires_mte = `-Zsanitizer=memtag` requires `-Ctarget-feature=+mte` diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 67addebbdf460..91b481074675d 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -698,6 +698,7 @@ pub(crate) unsafe fn extract_return_type<'a>( pub(crate) unsafe fn enzyme_ad( llmod: &llvm::Module, llcx: &llvm::Context, + diag_handler: &rustc_errors::Handler, item: AutoDiffItem, ) -> Result<(), FatalError> { let autodiff_mode = item.attrs.mode; @@ -710,8 +711,28 @@ pub(crate) unsafe fn enzyme_ad( // get target and source function let name = CString::new(rust_name.to_owned()).unwrap(); let name2 = CString::new(rust_name2.clone()).unwrap(); - let src_fnc = llvm::LLVMGetNamedFunction(llmod, name.as_c_str().as_ptr()).unwrap(); - let target_fnc = llvm::LLVMGetNamedFunction(llmod, name2.as_ptr()).unwrap(); + let src_fnc_opt = llvm::LLVMGetNamedFunction(llmod, name.as_c_str().as_ptr()); + let src_fnc = match src_fnc_opt { + Some(x) => x, + None => { + return Err(llvm_err(diag_handler, LlvmError::PrepareAutoDiff{ + src: rust_name.to_owned(), + target: 
rust_name2.to_owned(), + error: "could not find src function".to_owned(), + })); + } + }; + let target_fnc_opt = llvm::LLVMGetNamedFunction(llmod, name2.as_ptr()); + let target_fnc = match target_fnc_opt { + Some(x) => x, + None => { + return Err(llvm_err(diag_handler, LlvmError::PrepareAutoDiff{ + src: rust_name.to_owned(), + target: rust_name2.to_owned(), + error: "could not find target function".to_owned(), + })); + } + }; let src_num_args = llvm::LLVMCountParams(src_fnc); let target_num_args = llvm::LLVMCountParams(target_fnc); assert!(src_num_args <= target_num_args); @@ -791,13 +812,14 @@ pub(crate) unsafe fn enzyme_ad( pub(crate) unsafe fn differentiate( module: &ModuleCodegen, - _cgcx: &CodegenContext, + cgcx: &CodegenContext, diff_items: Vec, _typetrees: FxHashMap, _config: &ModuleConfig, ) -> Result<(), FatalError> { let llmod = module.module_llvm.llmod(); let llcx = &module.module_llvm.llcx; + let diag_handler = cgcx.create_diag_handler(); llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymeStrictAliasing), 0); @@ -818,7 +840,7 @@ pub(crate) unsafe fn differentiate( } for item in diff_items { - let res = enzyme_ad(llmod, llcx, item); + let res = enzyme_ad(llmod, llcx, &diag_handler, item); assert!(res.is_ok()); } diff --git a/compiler/rustc_codegen_llvm/src/errors.rs b/compiler/rustc_codegen_llvm/src/errors.rs index 665d195790c2d..6d317a400d4be 100644 --- a/compiler/rustc_codegen_llvm/src/errors.rs +++ b/compiler/rustc_codegen_llvm/src/errors.rs @@ -172,6 +172,12 @@ pub enum LlvmError<'a> { PrepareThinLtoModule, #[diag(codegen_llvm_parse_bitcode)] ParseBitcode, + #[diag(codegen_llvm_prepare_autodiff)] + PrepareAutoDiff { + src: String, + target: String, + error: String, + } } pub(crate) struct WithLlvmError<'a>(pub LlvmError<'a>, pub String); @@ -193,6 +199,7 @@ impl IntoDiagnostic<'_, EM> for WithLlvmError<'_> { } PrepareThinLtoModule => fluent::codegen_llvm_prepare_thin_lto_module_with_llvm_err, ParseBitcode => 
fluent::codegen_llvm_parse_bitcode_with_llvm_err, + PrepareAutoDiff { .. } => fluent::codegen_llvm_prepare_autodiff_with_llvm_err, }; let mut diag = self.0.into_diagnostic(sess); diag.set_primary_message(msg_with_llvm_err); From 50a8258a93817fa4365848bc44d7e3fe4df0838c Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 10 Nov 2023 15:54:09 -0500 Subject: [PATCH 13/17] move AD types from rustc_middle to rustc_ast --- compiler/rustc_ast/src/ast.rs | 2 + .../rustc_ast/src/expand/autodiff_attrs.rs | 149 +++++++++++++++ compiler/rustc_ast/src/expand/mod.rs | 2 + compiler/rustc_ast/src/expand/typetree.rs | 67 +++++++ compiler/rustc_builtin_macros/messages.ftl | 2 + compiler/rustc_builtin_macros/src/autodiff.rs | 178 ++++++++++++++++++ compiler/rustc_builtin_macros/src/errors.rs | 7 + compiler/rustc_builtin_macros/src/lib.rs | 2 + compiler/rustc_codegen_llvm/src/back/write.rs | 2 +- compiler/rustc_codegen_llvm/src/lib.rs | 2 +- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 2 +- compiler/rustc_codegen_llvm/src/typetree.rs | 2 +- compiler/rustc_codegen_ssa/src/back/lto.rs | 2 +- compiler/rustc_codegen_ssa/src/back/write.rs | 2 +- .../rustc_codegen_ssa/src/codegen_attrs.rs | 2 +- .../rustc_codegen_ssa/src/traits/write.rs | 2 +- compiler/rustc_expand/src/expand.rs | 1 + compiler/rustc_feature/src/builtin_attrs.rs | 5 + compiler/rustc_middle/src/arena.rs | 2 +- compiler/rustc_middle/src/middle/mod.rs | 1 - compiler/rustc_middle/src/query/mod.rs | 2 +- compiler/rustc_monomorphize/Cargo.toml | 1 + .../rustc_monomorphize/src/partitioning.rs | 4 +- compiler/rustc_span/src/symbol.rs | 1 + 24 files changed, 429 insertions(+), 13 deletions(-) create mode 100644 compiler/rustc_ast/src/expand/autodiff_attrs.rs create mode 100644 compiler/rustc_ast/src/expand/typetree.rs create mode 100644 compiler/rustc_builtin_macros/src/autodiff.rs diff --git a/compiler/rustc_ast/src/ast.rs b/compiler/rustc_ast/src/ast.rs index 146a4db200caa..62d3c0227a9c6 100644 --- 
a/compiler/rustc_ast/src/ast.rs +++ b/compiler/rustc_ast/src/ast.rs @@ -1592,11 +1592,13 @@ impl MacCall { } } +/// Manuel /// Arguments passed to an attribute macro. #[derive(Clone, Encodable, Decodable, Debug)] pub enum AttrArgs { /// No arguments: `#[attr]`. Empty, + /// Manuel autodiff /// Delimited arguments: `#[attr()/[]/{}]`. Delimited(DelimArgs), /// Arguments of a key-value attribute: `#[attr = "value"]`. diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs new file mode 100644 index 0000000000000..2a8a120d62323 --- /dev/null +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -0,0 +1,149 @@ +use super::typetree::TypeTree; +use std::str::FromStr; +use rustc_data_structures::stable_hasher::{HashStable, StableHasher};//, StableOrd}; +use crate::HashStableContext; + +#[allow(dead_code)] +#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug)] +pub enum DiffMode { + Inactive, + Source, + Forward, + Reverse, +} + +#[allow(dead_code)] +#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug)] +pub enum DiffActivity { + None, + Active, + Const, + Duplicated, + DuplicatedNoNeed, +} +fn clause_diffactivity_discriminant(value: &DiffActivity) -> usize { + match value { + DiffActivity::None => 0, + DiffActivity::Active => 1, + DiffActivity::Const => 2, + DiffActivity::Duplicated => 3, + DiffActivity::DuplicatedNoNeed => 4, + } +} +fn clause_diffmode_discriminant(value: &DiffMode) -> usize { + match value { + DiffMode::Inactive => 0, + DiffMode::Source => 1, + DiffMode::Forward => 2, + DiffMode::Reverse => 3, + } +} + + +impl HashStable for DiffMode { + fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { + clause_diffmode_discriminant(self).hash_stable(hcx, hasher); + } +} + +impl HashStable for DiffActivity { + fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { + clause_diffactivity_discriminant(self).hash_stable(hcx, hasher); + } +} + + +impl FromStr for 
DiffActivity { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "None" => Ok(DiffActivity::None), + "Active" => Ok(DiffActivity::Active), + "Const" => Ok(DiffActivity::Const), + "Duplicated" => Ok(DiffActivity::Duplicated), + "DuplicatedNoNeed" => Ok(DiffActivity::DuplicatedNoNeed), + _ => Err(()), + } + } +} + +#[allow(dead_code)] +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug)] +pub struct AutoDiffAttrs { + pub mode: DiffMode, + pub ret_activity: DiffActivity, + pub input_activity: Vec, +} + +impl HashStable for AutoDiffAttrs { + fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { + self.mode.hash_stable(hcx, hasher); + self.ret_activity.hash_stable(hcx, hasher); + self.input_activity.hash_stable(hcx, hasher); + } +} + +impl AutoDiffAttrs { + pub fn inactive() -> Self { + AutoDiffAttrs { + mode: DiffMode::Inactive, + ret_activity: DiffActivity::None, + input_activity: Vec::new(), + } + } + + pub fn is_active(&self) -> bool { + match self.mode { + DiffMode::Inactive => false, + _ => true, + } + } + + pub fn is_source(&self) -> bool { + match self.mode { + DiffMode::Source => true, + _ => false, + } + } + pub fn apply_autodiff(&self) -> bool { + match self.mode { + DiffMode::Inactive => false, + DiffMode::Source => false, + _ => true, + } + } + + pub fn into_item( + self, + source: String, + target: String, + inputs: Vec, + output: TypeTree, + ) -> AutoDiffItem { + AutoDiffItem { source, target, inputs, output, attrs: self } + } +} + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug)] +pub struct AutoDiffItem { + pub source: String, + pub target: String, + pub attrs: AutoDiffAttrs, + pub inputs: Vec, + pub output: TypeTree, +} + +impl HashStable for AutoDiffItem { + fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { + self.source.hash_stable(hcx, hasher); + self.target.hash_stable(hcx, hasher); + self.attrs.hash_stable(hcx, hasher); + for tt in &self.inputs { + tt.0.hash_stable(hcx, hasher); + } 
+ //self.inputs.hash_stable(hcx, hasher); + self.output.0.hash_stable(hcx, hasher); + } +} + diff --git a/compiler/rustc_ast/src/expand/mod.rs b/compiler/rustc_ast/src/expand/mod.rs index 942347383ce31..b8434374a3594 100644 --- a/compiler/rustc_ast/src/expand/mod.rs +++ b/compiler/rustc_ast/src/expand/mod.rs @@ -5,6 +5,8 @@ use rustc_span::{def_id::DefId, symbol::Ident}; use crate::MetaItem; pub mod allocator; +pub mod typetree; +pub mod autodiff_attrs; #[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)] pub struct StrippedCfgItem { diff --git a/compiler/rustc_ast/src/expand/typetree.rs b/compiler/rustc_ast/src/expand/typetree.rs new file mode 100644 index 0000000000000..2adcc724831a7 --- /dev/null +++ b/compiler/rustc_ast/src/expand/typetree.rs @@ -0,0 +1,67 @@ +use std::fmt; +use rustc_data_structures::stable_hasher::{HashStable, StableHasher};//, StableOrd}; +use crate::HashStableContext; + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug)] +pub enum Kind { + Anything, + Integer, + Pointer, + Half, + Float, + Double, + Unknown, +} +impl HashStable for Kind { + fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { + clause_kind_discriminant(self).hash_stable(hcx, hasher); + } +} +fn clause_kind_discriminant(value: &Kind) -> usize { + match value { + Kind::Anything => 0, + Kind::Integer => 1, + Kind::Pointer => 2, + Kind::Half => 3, + Kind::Float => 4, + Kind::Double => 5, + Kind::Unknown => 6, + } +} + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug)] +pub struct TypeTree(pub Vec); + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug)] +pub struct Type { + pub offset: isize, + pub size: usize, + pub kind: Kind, + pub child: TypeTree, +} + +impl HashStable for Type { + fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { + self.offset.hash_stable(hcx, hasher); + self.size.hash_stable(hcx, hasher); + self.kind.hash_stable(hcx, hasher); + self.child.0.hash_stable(hcx, hasher); + } +} + +impl Type 
{ + pub fn add_offset(self, add: isize) -> Self { + let offset = match self.offset { + -1 => add, + x => add + x, + }; + + Self { size: self.size, kind: self.kind, child: self.child, offset } + } +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + ::fmt(self, f) + } +} diff --git a/compiler/rustc_builtin_macros/messages.ftl b/compiler/rustc_builtin_macros/messages.ftl index dda466b026d91..07c9b588e1faf 100644 --- a/compiler/rustc_builtin_macros/messages.ftl +++ b/compiler/rustc_builtin_macros/messages.ftl @@ -1,6 +1,8 @@ builtin_macros_alloc_error_must_be_fn = alloc_error_handler must be a function builtin_macros_alloc_must_statics = allocators must be statics +builtin_macros_autodiff = autodiff must be applied to function + builtin_macros_asm_clobber_abi = clobber_abi builtin_macros_asm_clobber_no_reg = asm with `clobber_abi` must specify explicit registers for outputs builtin_macros_asm_clobber_outputs = generic outputs diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs new file mode 100644 index 0000000000000..2e5c57e14f9f3 --- /dev/null +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -0,0 +1,178 @@ +use crate::errors; +use crate::util::check_builtin_macro_attribute; +//use rustc_ast::expand::allocator::{ +// global_fn_name, AllocatorMethod, AllocatorMethodInput, AllocatorTy, ALLOCATOR_METHODS, +//}; +//use rustc_middle::middle::autodiff_attrs::{AutoDiffItem, DiffActivity, DiffMode}; +//use rustc_ast::ptr::P; +use rustc_ast::{self as ast};//, AttrVec, Expr, FnHeader, FnSig, Generics, Param, StmtKind}; +//use rustc_ast::{Fn, ItemKind, Mutability, Stmt, Ty, TyKind, Unsafe}; +use rustc_ast::ItemKind; +use rustc_expand::base::{Annotatable, ExtCtxt}; +//use rustc_span::symbol::{kw, sym, Ident, Symbol}; +use rustc_span::symbol::sym; +use rustc_span::Span; +//use thin_vec::{thin_vec, ThinVec}; + + +pub fn expand( + ecx: &mut ExtCtxt<'_>, + _span: Span, + 
meta_item: &ast::MetaItem, + item: Annotatable, +) -> Vec { + check_builtin_macro_attribute(ecx, meta_item, sym::autodiff); + let orig_item = item.clone(); + // FnSig + // inner_tokens + //ItemKind::Fn(Box) + // Allow using `#[autodiff(...)]` on a Fn + let (item, _ty_span) = if let Annotatable::Item(item) = &item + && let ItemKind::Fn(box ast::Fn { sig, .. }) = &item.kind + { + (item, ecx.with_def_site_ctxt(sig.span)) + } else { + ecx.sess + .parse_sess + .span_diagnostic + .emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); + return vec![orig_item]; + }; + let _span = ecx.with_def_site_ctxt(item.span); +// let f = AllocFnFactory { span, ty_span, global: item.ident, cx: ecx }; + return vec![orig_item]; +} +// +// // Generate a bunch of new items using the AllocFnFactory +// let f = AllocFnFactory { span, ty_span, global: item.ident, cx: ecx }; +// +// // Generate item statements for the allocator methods. +// let stmts = ALLOCATOR_METHODS.iter().map(|method| f.allocator_fn(method)).collect(); +// +// // Generate anonymous constant serving as container for the allocator methods. +// let const_ty = ecx.ty(ty_span, TyKind::Tup(ThinVec::new())); +// let const_body = ecx.expr_block(ecx.block(span, stmts)); +// let const_item = ecx.item_const(span, Ident::new(kw::Underscore, span), const_ty, const_body); +// let const_item = if is_stmt { +// Annotatable::Stmt(P(ecx.stmt_item(span, const_item))) +// } else { +// Annotatable::Item(const_item) +// }; +// +// // Return the original item and the new methods. 
+// vec![orig_item, const_item] +//} + +//struct AllocFnFactory<'a, 'b> { +// span: Span, +// ty_span: Span, +// global: Ident, +// cx: &'b ExtCtxt<'a>, +//} + +//impl AllocFnFactory<'_, '_> { +// fn allocator_fn(&self, method: &AllocatorMethod) -> Stmt { +// let mut abi_args = ThinVec::new(); +// let args = method.inputs.iter().map(|input| self.arg_ty(input, &mut abi_args)).collect(); +// let result = self.call_allocator(method.name, args); +// let output_ty = self.ret_ty(&method.output); +// let decl = self.cx.fn_decl(abi_args, ast::FnRetTy::Ty(output_ty)); +// let header = FnHeader { unsafety: Unsafe::Yes(self.span), ..FnHeader::default() }; +// let sig = FnSig { decl, header, span: self.span }; +// let body = Some(self.cx.block_expr(result)); +// let kind = ItemKind::Fn(Box::new(Fn { +// defaultness: ast::Defaultness::Final, +// sig, +// generics: Generics::default(), +// body, +// })); +// let item = self.cx.item( +// self.span, +// Ident::from_str_and_span(&global_fn_name(method.name), self.span), +// self.attrs(), +// kind, +// ); +// self.cx.stmt_item(self.ty_span, item) +// } +// +// fn call_allocator(&self, method: Symbol, mut args: ThinVec>) -> P { +// let method = self.cx.std_path(&[sym::alloc, sym::GlobalAlloc, method]); +// let method = self.cx.expr_path(self.cx.path(self.ty_span, method)); +// let allocator = self.cx.path_ident(self.ty_span, self.global); +// let allocator = self.cx.expr_path(allocator); +// let allocator = self.cx.expr_addr_of(self.ty_span, allocator); +// args.insert(0, allocator); +// +// self.cx.expr_call(self.ty_span, method, args) +// } +// +// fn attrs(&self) -> AttrVec { +// thin_vec![self.cx.attr_word(sym::rustc_std_internal_symbol, self.span)] +// } +// +// fn arg_ty(&self, input: &AllocatorMethodInput, args: &mut ThinVec) -> P { +// match input.ty { +// AllocatorTy::Layout => { +// // If an allocator method is ever introduced having multiple +// // Layout arguments, these argument names need to be +// // disambiguated 
somehow. Currently the generated code would +// // fail to compile with "identifier is bound more than once in +// // this parameter list". +// let size = Ident::from_str_and_span("size", self.span); +// let align = Ident::from_str_and_span("align", self.span); +// +// let usize = self.cx.path_ident(self.span, Ident::new(sym::usize, self.span)); +// let ty_usize = self.cx.ty_path(usize); +// args.push(self.cx.param(self.span, size, ty_usize.clone())); +// args.push(self.cx.param(self.span, align, ty_usize)); +// +// let layout_new = +// self.cx.std_path(&[sym::alloc, sym::Layout, sym::from_size_align_unchecked]); +// let layout_new = self.cx.expr_path(self.cx.path(self.span, layout_new)); +// let size = self.cx.expr_ident(self.span, size); +// let align = self.cx.expr_ident(self.span, align); +// let layout = self.cx.expr_call(self.span, layout_new, thin_vec![size, align]); +// layout +// } +// +// AllocatorTy::Ptr => { +// let ident = Ident::from_str_and_span(input.name, self.span); +// args.push(self.cx.param(self.span, ident, self.ptr_u8())); +// self.cx.expr_ident(self.span, ident) +// } +// +// AllocatorTy::Usize => { +// let ident = Ident::from_str_and_span(input.name, self.span); +// args.push(self.cx.param(self.span, ident, self.usize())); +// self.cx.expr_ident(self.span, ident) +// } +// +// AllocatorTy::ResultPtr | AllocatorTy::Unit => { +// panic!("can't convert AllocatorTy to an argument") +// } +// } +// } +// +// fn ret_ty(&self, ty: &AllocatorTy) -> P { +// match *ty { +// AllocatorTy::ResultPtr => self.ptr_u8(), +// +// AllocatorTy::Unit => self.cx.ty(self.span, TyKind::Tup(ThinVec::new())), +// +// AllocatorTy::Layout | AllocatorTy::Usize | AllocatorTy::Ptr => { +// panic!("can't convert `AllocatorTy` to an output") +// } +// } +// } +// +// fn usize(&self) -> P { +// let usize = self.cx.path_ident(self.span, Ident::new(sym::usize, self.span)); +// self.cx.ty_path(usize) +// } +// +// fn ptr_u8(&self) -> P { +// let u8 = 
self.cx.path_ident(self.span, Ident::new(sym::u8, self.span)); +// let ty_u8 = self.cx.ty_path(u8); +// self.cx.ty_ptr(self.span, ty_u8, Mutability::Mut) +// } +//} diff --git a/compiler/rustc_builtin_macros/src/errors.rs b/compiler/rustc_builtin_macros/src/errors.rs index fde4270334b67..75c42fcee350e 100644 --- a/compiler/rustc_builtin_macros/src/errors.rs +++ b/compiler/rustc_builtin_macros/src/errors.rs @@ -157,6 +157,13 @@ pub(crate) struct TestArgs { pub(crate) span: Span, } +#[derive(Diagnostic)] +#[diag(builtin_macros_autodiff)] +pub(crate) struct AutoDiffInvalidApplication { + #[primary_span] + pub(crate) span: Span, +} + #[derive(Diagnostic)] #[diag(builtin_macros_alloc_must_statics)] pub(crate) struct AllocMustStatics { diff --git a/compiler/rustc_builtin_macros/src/lib.rs b/compiler/rustc_builtin_macros/src/lib.rs index d84742c9b8293..e6d22709c83ff 100644 --- a/compiler/rustc_builtin_macros/src/lib.rs +++ b/compiler/rustc_builtin_macros/src/lib.rs @@ -30,6 +30,7 @@ use rustc_fluent_macro::fluent_messages; use rustc_span::symbol::sym; mod alloc_error_handler; +mod autodiff; mod assert; mod cfg; mod cfg_accessible; @@ -107,6 +108,7 @@ pub fn register_builtin_macros(resolver: &mut dyn ResolverExpand) { register_attr! 
{ alloc_error_handler: alloc_error_handler::expand, + autodiff: autodiff::expand, bench: test::expand_bench, cfg_accessible: cfg_accessible::Expander, cfg_eval: cfg_eval::expand, diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 91b481074675d..a40da5f5c0606 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -42,7 +42,7 @@ use rustc_data_structures::profiling::SelfProfilerRef; use rustc_data_structures::small_c_str::SmallCStr; use rustc_errors::{FatalError, Handler, Level}; use rustc_fs_util::{link_or_copy, path_to_c_string}; -use rustc_middle::middle::autodiff_attrs::{AutoDiffItem, DiffActivity, DiffMode}; +use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity, DiffMode}; use rustc_middle::ty::TyCtxt; use rustc_session::config::{self, Lto, OutputType, Passes, SplitDwarfKind, SwitchWithOptPath}; use rustc_session::Session; diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index 200cc4528883b..5b1b2fcd58b3d 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -45,7 +45,7 @@ use rustc_errors::{DiagnosticMessage, ErrorGuaranteed, FatalError, Handler, Subd use rustc_fluent_macro::fluent_messages; use rustc_metadata::EncodedMetadata; use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; -use rustc_middle::middle::autodiff_attrs::AutoDiffItem; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_middle::ty::TyCtxt; use rustc_middle::util::Providers; use rustc_session::config::{OptLevel, OutputFilenames, PrintKind, PrintRequest}; diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 94285ed817ef5..026d12121e535 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1,7 +1,7 @@ #![allow(non_camel_case_types)] 
#![allow(non_upper_case_globals)] -use rustc_middle::middle::autodiff_attrs::DiffActivity; +use rustc_ast::expand::autodiff_attrs::DiffActivity; use super::debuginfo::{ DIArray, DIBasicType, DIBuilder, DICompositeType, DIDerivedType, DIDescriptor, DIEnumerator, diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index 091ddaa3cf213..c45f5ec5e7005 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -1,5 +1,5 @@ use crate::llvm; -use rustc_middle::middle::typetree::{Kind, TypeTree}; +use rustc_ast::expand::typetree::{Kind, TypeTree}; pub fn to_enzyme_typetree( tree: TypeTree, diff --git a/compiler/rustc_codegen_ssa/src/back/lto.rs b/compiler/rustc_codegen_ssa/src/back/lto.rs index d7bcc92abc667..b8f75f71a112a 100644 --- a/compiler/rustc_codegen_ssa/src/back/lto.rs +++ b/compiler/rustc_codegen_ssa/src/back/lto.rs @@ -5,7 +5,7 @@ use crate::ModuleCodegen; use rustc_data_structures::{fx::FxHashMap, memmap::Mmap}; use rustc_errors::FatalError; -use rustc_middle::middle::autodiff_attrs::AutoDiffItem; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use std::ffi::CString; use std::sync::Arc; diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index ad4bad943e524..649421bd3929c 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -24,7 +24,7 @@ use rustc_incremental::{ use rustc_metadata::fs::copy_to_stdout; use rustc_metadata::EncodedMetadata; use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; -use rustc_middle::middle::autodiff_attrs::AutoDiffItem; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_middle::middle::exported_symbols::SymbolExportInfo; use rustc_middle::ty::TyCtxt; use rustc_session::config::{self, CrateType, Lto, OutFileName, OutputFilenames, OutputType}; diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs 
b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index 01aad2790b407..231ea4c8211ea 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -5,7 +5,7 @@ use rustc_hir as hir; use rustc_hir::def::DefKind; use rustc_hir::def_id::{DefId, LocalDefId, LOCAL_CRATE}; use rustc_hir::{lang_items, weak_lang_items::WEAK_LANG_ITEMS, LangItem}; -use rustc_middle::middle::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs}; use rustc_middle::mir::mono::Linkage; use rustc_middle::query::Providers; diff --git a/compiler/rustc_codegen_ssa/src/traits/write.rs b/compiler/rustc_codegen_ssa/src/traits/write.rs index 9c1be89580dc4..05995062435f5 100644 --- a/compiler/rustc_codegen_ssa/src/traits/write.rs +++ b/compiler/rustc_codegen_ssa/src/traits/write.rs @@ -5,7 +5,7 @@ use crate::{CompiledModule, ModuleCodegen}; use rustc_data_structures::fx::FxHashMap; use rustc_errors::{FatalError, Handler}; use rustc_middle::dep_graph::WorkProduct; -use rustc_middle::middle::autodiff_attrs::AutoDiffItem; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; pub trait WriteBackendMethods: 'static + Sized + Clone { type Module: Send + Sync; diff --git a/compiler/rustc_expand/src/expand.rs b/compiler/rustc_expand/src/expand.rs index f87f4aba2b9ea..93282c8b2b571 100644 --- a/compiler/rustc_expand/src/expand.rs +++ b/compiler/rustc_expand/src/expand.rs @@ -406,6 +406,7 @@ impl<'a, 'b> MacroExpander<'a, 'b> { } /// Recursively expand all macro invocations in this AST fragment. 
+ /// Manuel: Add autodiff pub fn fully_expand_fragment(&mut self, input_fragment: AstFragment) -> AstFragment { let orig_expansion_data = self.cx.current_expansion.clone(); let orig_force_mode = self.cx.force_mode; diff --git a/compiler/rustc_feature/src/builtin_attrs.rs b/compiler/rustc_feature/src/builtin_attrs.rs index 2ed334569995b..c0c883915b3c7 100644 --- a/compiler/rustc_feature/src/builtin_attrs.rs +++ b/compiler/rustc_feature/src/builtin_attrs.rs @@ -359,6 +359,11 @@ pub const BUILTIN_ATTRIBUTES: &[BuiltinAttribute] = &[ template!(Word, List: r#""...""#), DuplicatesOk, ), + ungated!( + autodiff, Normal, + template!(List: r#""...""#), + DuplicatesOk, + ), // Limits: ungated!(recursion_limit, CrateLevel, template!(NameValueStr: "N"), FutureWarnFollowing), diff --git a/compiler/rustc_middle/src/arena.rs b/compiler/rustc_middle/src/arena.rs index acb0a25f087eb..f341a645e7d81 100644 --- a/compiler/rustc_middle/src/arena.rs +++ b/compiler/rustc_middle/src/arena.rs @@ -97,7 +97,7 @@ macro_rules! 
arena_types { [] upvars_mentioned: rustc_data_structures::fx::FxIndexMap, [] object_safety_violations: rustc_middle::traits::ObjectSafetyViolation, [] codegen_unit: rustc_middle::mir::mono::CodegenUnit<'tcx>, - [] autodiff_item: rustc_middle::middle::autodiff_attrs::AutoDiffItem, + [] autodiff_item: rustc_ast::expand::autodiff_attrs::AutoDiffItem, [decode] attribute: rustc_ast::Attribute, [] name_set: rustc_data_structures::unord::UnordSet, [] ordered_name_set: rustc_data_structures::fx::FxIndexSet, diff --git a/compiler/rustc_middle/src/middle/mod.rs b/compiler/rustc_middle/src/middle/mod.rs index 43e60c2571cc0..786ad3361ff2e 100644 --- a/compiler/rustc_middle/src/middle/mod.rs +++ b/compiler/rustc_middle/src/middle/mod.rs @@ -1,4 +1,3 @@ -pub mod autodiff_attrs; pub mod codegen_fn_attrs; pub mod debugger_visualizer; pub mod dependency_format; diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs index 8d85928374b5c..b16cfaf0e01a6 100644 --- a/compiler/rustc_middle/src/query/mod.rs +++ b/compiler/rustc_middle/src/query/mod.rs @@ -10,7 +10,7 @@ use crate::dep_graph; use crate::infer::canonical::{self, Canonical}; use crate::lint::LintExpectation; use crate::metadata::ModChild; -use crate::middle::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem}; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem}; use crate::middle::codegen_fn_attrs::CodegenFnAttrs; use crate::middle::debugger_visualizer::DebuggerVisualizerFile; use crate::middle::exported_symbols::{ExportedSymbol, SymbolExportInfo}; diff --git a/compiler/rustc_monomorphize/Cargo.toml b/compiler/rustc_monomorphize/Cargo.toml index b75941e71989a..e2758cc330d8d 100644 --- a/compiler/rustc_monomorphize/Cargo.toml +++ b/compiler/rustc_monomorphize/Cargo.toml @@ -9,6 +9,7 @@ edition = "2021" serde = "1" serde_json = "1" tracing = "0.1" +rustc_ast = { path = "../rustc_ast" } rustc_data_structures = { path = "../rustc_data_structures" } rustc_errors = { path = 
"../rustc_errors" } rustc_hir = { path = "../rustc_hir" } diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index 403e16ff49a60..e30eec6787f06 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -93,7 +93,7 @@ //! inlining, even when they are not marked `#[inline]`. use rustc_symbol_mangling::symbol_name_for_instance_in_crate; -use rustc_middle::middle::typetree::{Kind, Type, TypeTree}; +use rustc_ast::expand::typetree::{Kind, Type, TypeTree}; use rustc_target::abi::FieldsShape; use std::cmp; @@ -107,7 +107,7 @@ use rustc_data_structures::sync; use rustc_hir::def::DefKind; use rustc_hir::def_id::{DefId, DefIdSet, LOCAL_CRATE}; use rustc_hir::definitions::DefPathDataName; -use rustc_middle::middle::autodiff_attrs::AutoDiffItem; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrFlags; use rustc_middle::middle::exported_symbols::{SymbolExportInfo, SymbolExportLevel}; use rustc_middle::mir::mono::{ diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index a461745d9162c..306860e30453f 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -437,6 +437,7 @@ symbols! 
{ attributes, augmented_assignments, auto_traits, + autodiff, autodiff_into, automatically_derived, avx, From c4cf7d63ed15c30793802841641b3ae514d78c48 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 10 Nov 2023 16:53:52 -0500 Subject: [PATCH 14/17] move AD types from rustc_middle to rustc_ast --- .github/workflows/enzyme-ci.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/enzyme-ci.yml b/.github/workflows/enzyme-ci.yml index 4064d4709a5ed..edc48879479e8 100644 --- a/.github/workflows/enzyme-ci.yml +++ b/.github/workflows/enzyme-ci.yml @@ -1,6 +1,6 @@ name: Rust CI -on: +on: push: branches: - master @@ -13,7 +13,7 @@ jobs: build: name: Rust Integration CI LLVM ${{ matrix.llvm }} ${{ matrix.build }} ${{ matrix.os }} runs-on: ${{ matrix.os }} - + strategy: fail-fast: false matrix: @@ -29,7 +29,7 @@ jobs: run: | mkdir build cd build - ../configure --enable-llvm-link-shared --enable-llvm-plugins --enable-llvm-enzyme --release-channel=nightly --enable-llvm-assertions --enable-clang --enable-lld --enable-option-checking --enable-ninja --disable-docs + ../configure --enable-llvm-link-shared --enable-llvm-plugins --enable-llvm-enzyme --release-channel=nightly --enable-llvm-assertions --enable-lld --enable-option-checking --enable-ninja --disable-docs ../x.py build --stage 1 library/std library/proc_macro library/test tools/rustdoc rustup toolchain link enzyme `pwd`/build/`rustup target list --installed`/stage1 rustup toolchain install nightly # enables -Z unstable-options From efad9fdf92d432441e88586b4092033bb3dfcb65 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 10 Nov 2023 17:14:19 -0500 Subject: [PATCH 15/17] remove old types from rustc_middle --- .../rustc_middle/src/middle/autodiff_attrs.rs | 94 ------------------- compiler/rustc_middle/src/middle/mod.rs | 1 - compiler/rustc_middle/src/middle/typetree.rs | 39 -------- 3 files changed, 134 deletions(-) delete mode 100644 
compiler/rustc_middle/src/middle/autodiff_attrs.rs delete mode 100644 compiler/rustc_middle/src/middle/typetree.rs diff --git a/compiler/rustc_middle/src/middle/autodiff_attrs.rs b/compiler/rustc_middle/src/middle/autodiff_attrs.rs deleted file mode 100644 index 2412df725fe2b..0000000000000 --- a/compiler/rustc_middle/src/middle/autodiff_attrs.rs +++ /dev/null @@ -1,94 +0,0 @@ -use crate::middle::typetree::TypeTree; -use std::str::FromStr; - -#[allow(dead_code)] -#[derive(Clone, Copy, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] -pub enum DiffMode { - Inactive, - Source, - Forward, - Reverse, -} - -#[allow(dead_code)] -#[derive(Clone, Copy, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] -pub enum DiffActivity { - None, - Active, - Const, - Duplicated, - DuplicatedNoNeed, -} - -impl FromStr for DiffActivity { - type Err = (); - - fn from_str(s: &str) -> Result { - match s { - "None" => Ok(DiffActivity::None), - "Active" => Ok(DiffActivity::Active), - "Const" => Ok(DiffActivity::Const), - "Duplicated" => Ok(DiffActivity::Duplicated), - "DuplicatedNoNeed" => Ok(DiffActivity::DuplicatedNoNeed), - _ => Err(()), - } - } -} - -#[allow(dead_code)] -#[derive(Clone, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] -pub struct AutoDiffAttrs { - pub mode: DiffMode, - pub ret_activity: DiffActivity, - pub input_activity: Vec, -} - -impl AutoDiffAttrs { - pub fn inactive() -> Self { - AutoDiffAttrs { - mode: DiffMode::Inactive, - ret_activity: DiffActivity::None, - input_activity: Vec::new(), - } - } - - pub fn is_active(&self) -> bool { - match self.mode { - DiffMode::Inactive => false, - _ => true, - } - } - - pub fn is_source(&self) -> bool { - match self.mode { - DiffMode::Source => true, - _ => false, - } - } - pub fn apply_autodiff(&self) -> bool { - match self.mode { - DiffMode::Inactive => false, - DiffMode::Source => false, - _ => true, - } - } - - pub fn into_item( - self, - source: String, - target: String, - inputs: Vec, - 
output: TypeTree, - ) -> AutoDiffItem { - AutoDiffItem { source, target, inputs, output, attrs: self } - } -} - -#[derive(Clone, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] -pub struct AutoDiffItem { - pub source: String, - pub target: String, - pub attrs: AutoDiffAttrs, - pub inputs: Vec, - pub output: TypeTree, -} diff --git a/compiler/rustc_middle/src/middle/mod.rs b/compiler/rustc_middle/src/middle/mod.rs index 786ad3361ff2e..85c5af9ca13cb 100644 --- a/compiler/rustc_middle/src/middle/mod.rs +++ b/compiler/rustc_middle/src/middle/mod.rs @@ -32,7 +32,6 @@ pub mod privacy; pub mod region; pub mod resolve_bound_vars; pub mod stability; -pub mod typetree; pub fn provide(providers: &mut crate::query::Providers) { limits::provide(providers); diff --git a/compiler/rustc_middle/src/middle/typetree.rs b/compiler/rustc_middle/src/middle/typetree.rs deleted file mode 100644 index 4049d32540bd2..0000000000000 --- a/compiler/rustc_middle/src/middle/typetree.rs +++ /dev/null @@ -1,39 +0,0 @@ -use std::fmt; -#[derive(Clone, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] -pub enum Kind { - Anything, - Integer, - Pointer, - Half, - Float, - Double, - Unknown, -} - -#[derive(Clone, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] -pub struct TypeTree(pub Vec); - -#[derive(Clone, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] -pub struct Type { - pub offset: isize, - pub size: usize, - pub kind: Kind, - pub child: TypeTree, -} - -impl Type { - pub fn add_offset(self, add: isize) -> Self { - let offset = match self.offset { - -1 => add, - x => add + x, - }; - - Self { size: self.size, kind: self.kind, child: self.child, offset } - } -} - -impl fmt::Display for Type { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - ::fmt(self, f) - } -} From d9e9c9ce1c98ace037c091c90db19d908f3043fd Mon Sep 17 00:00:00 2001 From: Jed Brown Date: Thu, 21 Dec 2023 20:00:30 -0700 Subject: [PATCH 16/17] autodiff: no_std support 
(switch std:: to core::) I can now do this no a device function and the IR looks okay by eyeball. argo +enzyme rustc --release --target=nvptx64-nvidia-cuda -Zbuild-std -- --emit=llvm-ir --- library/autodiff/src/gen.rs | 8 ++++---- .../autodiff/tests/expand/forward_duplicated.expanded.rs | 4 ++-- .../tests/expand/forward_duplicated_return.expanded.rs | 4 ++-- .../autodiff/tests/expand/reverse_duplicated.expanded.rs | 4 ++-- .../tests/expand/reverse_return_array.expanded.rs | 4 ++-- .../tests/expand/reverse_return_mixed.expanded.rs | 4 ++-- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/library/autodiff/src/gen.rs b/library/autodiff/src/gen.rs index 68aae56ea3311..59b37b997e8cf 100644 --- a/library/autodiff/src/gen.rs +++ b/library/autodiff/src/gen.rs @@ -107,11 +107,11 @@ pub(crate) fn adjoint_fnc(item: &DiffItem) -> TokenStream { res_inputs.push(input.clone()); match (item.header.mode, activity, is_ref_mut(&input)) { - (Mode::Forward, Activity::Duplicated|Activity::DuplicatedNoNeed, Some(true)) => { + (Mode::Forward, Activity::Duplicated | Activity::DuplicatedNoNeed, Some(true)) => { res_inputs.push(as_ref_mut(&input, "grad", true)); add_inputs.push(as_ref_mut(&input, "grad", true)); } - (Mode::Forward, Activity::Duplicated|Activity::DuplicatedNoNeed, Some(false)) => { + (Mode::Forward, Activity::Duplicated | Activity::DuplicatedNoNeed, Some(false)) => { res_inputs.push(as_ref_mut(&input, "dual", false)); add_inputs.push(as_ref_mut(&input, "dual", false)); out_type.clone().map(|x| outputs.push(x)); @@ -203,9 +203,9 @@ pub(crate) fn adjoint_fnc(item: &DiffItem) -> TokenStream { }; let body = quote!({ - std::hint::black_box((#call_ident(#(#inputs,)*), #(#add_inputs,)*)); + core::hint::black_box((#call_ident(#(#inputs,)*), #(#add_inputs,)*)); - std::hint::black_box(unsafe { std::mem::zeroed() }) + core::hint::black_box(unsafe { core::mem::zeroed() }) }); let header = generate_header(&item); diff --git 
a/library/autodiff/tests/expand/forward_duplicated.expanded.rs b/library/autodiff/tests/expand/forward_duplicated.expanded.rs index bf3890154ab8e..c3e30939f92d4 100644 --- a/library/autodiff/tests/expand/forward_duplicated.expanded.rs +++ b/library/autodiff/tests/expand/forward_duplicated.expanded.rs @@ -5,6 +5,6 @@ fn square(a: &Vec, b: &mut f32) { } #[autodiff_into(Forward, Const, Duplicated, Duplicated)] fn d_square(a: &Vec, dual_a: &Vec, b: &mut f32, grad_b: &mut f32) { - std::hint::black_box((square(a, b), dual_a, grad_b)); - std::hint::black_box(unsafe { std::mem::zeroed() }) + core::hint::black_box((square(a, b), dual_a, grad_b)); + core::hint::black_box(unsafe { core::mem::zeroed() }) } diff --git a/library/autodiff/tests/expand/forward_duplicated_return.expanded.rs b/library/autodiff/tests/expand/forward_duplicated_return.expanded.rs index a3754de7ab70b..12b3cb898797f 100644 --- a/library/autodiff/tests/expand/forward_duplicated_return.expanded.rs +++ b/library/autodiff/tests/expand/forward_duplicated_return.expanded.rs @@ -10,6 +10,6 @@ fn d_square2( b: &Vec, dual_b: &Vec, ) -> (f32, f32, f32) { - std::hint::black_box((square2(a, b), dual_a, dual_b)); - std::hint::black_box(unsafe { std::mem::zeroed() }) + core::hint::black_box((square2(a, b), dual_a, dual_b)); + core::hint::black_box(unsafe { core::mem::zeroed() }) } diff --git a/library/autodiff/tests/expand/reverse_duplicated.expanded.rs b/library/autodiff/tests/expand/reverse_duplicated.expanded.rs index 60c0d7f2f696b..04f462a09b2dd 100644 --- a/library/autodiff/tests/expand/reverse_duplicated.expanded.rs +++ b/library/autodiff/tests/expand/reverse_duplicated.expanded.rs @@ -5,6 +5,6 @@ fn square(a: &Vec, b: &mut f32) { } #[autodiff_into(Reverse, Const, Duplicated, Duplicated)] fn d_square(a: &Vec, grad_a: &mut Vec, b: &mut f32, grad_b: &f32) { - std::hint::black_box((square(a, b), grad_a, grad_b)); - std::hint::black_box(unsafe { std::mem::zeroed() }) + core::hint::black_box((square(a, b), grad_a, 
grad_b)); + core::hint::black_box(unsafe { core::mem::zeroed() }) } diff --git a/library/autodiff/tests/expand/reverse_return_array.expanded.rs b/library/autodiff/tests/expand/reverse_return_array.expanded.rs index 5b784157fea7b..48e0d99fd2797 100644 --- a/library/autodiff/tests/expand/reverse_return_array.expanded.rs +++ b/library/autodiff/tests/expand/reverse_return_array.expanded.rs @@ -5,6 +5,6 @@ fn array(arr: &[[[f32; 2]; 2]; 2]) -> f32 { } #[autodiff_into(Reverse, Active, Duplicated)] fn d_array(arr: &[[[f32; 2]; 2]; 2], grad_arr: &mut [[[f32; 2]; 2]; 2], tang_y: f32) { - std::hint::black_box((array(arr), grad_arr, tang_y)); - std::hint::black_box(unsafe { std::mem::zeroed() }) + core::hint::black_box((array(arr), grad_arr, tang_y)); + core::hint::black_box(unsafe { core::mem::zeroed() }) } diff --git a/library/autodiff/tests/expand/reverse_return_mixed.expanded.rs b/library/autodiff/tests/expand/reverse_return_mixed.expanded.rs index f49864fb7e9b9..3517912222615 100644 --- a/library/autodiff/tests/expand/reverse_return_mixed.expanded.rs +++ b/library/autodiff/tests/expand/reverse_return_mixed.expanded.rs @@ -12,6 +12,6 @@ fn d_sqrt( d: f32, tang_y: f32, ) -> (f32, f32) { - std::hint::black_box((sqrt(a, b, c, d), grad_b, tang_y)); - std::hint::black_box(unsafe { std::mem::zeroed() }) + core::hint::black_box((sqrt(a, b, c, d), grad_b, tang_y)); + core::hint::black_box(unsafe { core::mem::zeroed() }) } From 4d7371963c095a4bd2db0f06c3973a9d706bb1e4 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 18 Feb 2024 17:03:44 -0500 Subject: [PATCH 17/17] just touching a file --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 2ac2d3b38d679..7b880aeb22d4a 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# The Rust Programming Language +Enzyme +# The Rust Programming Language +Enzyme (git history) [![Rust 
Community](https://img.shields.io/badge/Rust_Community%20-Join_us-brightgreen?style=plastic&logo=rust)](https://www.rust-lang.org/community)