diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index f437e52355ef0..ad4bd05a08b46 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -1093,14 +1093,11 @@ pub(crate) unsafe fn differentiate( } // Before dumping the module, we want all the tt to become part of the module. - for item in &diff_items { + for (i, item) in diff_items.iter().enumerate() { let llvm_data_layout = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; let llvm_data_layout = std::str::from_utf8(unsafe { CStr::from_ptr(llvm_data_layout) }.to_bytes()) .expect("got a non-UTF8 data-layout from LLVM"); - //let input_tts: Vec = - // item.inputs.iter().map(|x| to_enzyme_typetree(x.clone(), llvm_data_layout, llcx)).collect(); - //let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx); let tt: FncTree = FncTree { args: item.inputs.clone(), ret: item.output.clone(), @@ -1108,9 +1105,18 @@ pub(crate) unsafe fn differentiate( let name = CString::new(item.source.clone()).unwrap(); let fn_def: &llvm::Value = llvm::LLVMGetNamedFunction(llmod, name.as_ptr()).unwrap(); crate::builder::add_tt2(llmod, llcx, fn_def, tt); + + // Before dumping the module, we also might want to add dummy functions, which will + // trigger the LLVMEnzyme pass to run on them, if we invoke the opt binary. + // This is super helpfull if we want to create a MWE bug reproducer, e.g. to run in + // Enzyme's compiler explorer. TODO: Can we run llvm-extract on the module to remove all other functions? + if std::env::var("ENZYME_OPT").is_ok() { + dbg!("Enable extra debug helper to debug Enzyme through the opt plugin"); + crate::builder::add_opt_dbg_helper(llmod, llcx, fn_def, item.attrs.clone(), i); + } } - if std::env::var("ENZYME_PRINT_MOD_BEFORE").is_ok() { + if std::env::var("ENZYME_PRINT_MOD_BEFORE").is_ok() || std::env::var("ENZYME_OPT").is_ok(){ unsafe { LLVMDumpModule(llmod); } diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 733dd76ed105c..c4341250e121e 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -8,6 +8,7 @@ use crate::type_::Type; use crate::type_of::LayoutLlvmExt; use crate::value::Value; use libc::{c_char, c_uint}; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity}; use rustc_codegen_ssa::common::{IntPredicate, RealPredicate, SynchronizationScope, TypeKind}; use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; use rustc_codegen_ssa::mir::place::PlaceRef; @@ -30,6 +31,7 @@ use std::iter; use std::ops::Deref; use std::ptr; +use rustc_ast::expand::autodiff_attrs::DiffMode; use crate::typetree::to_enzyme_typetree; use rustc_ast::expand::typetree::{TypeTree, FncTree}; @@ -136,6 +138,7 @@ macro_rules! builder_methods_for_value_instructions { })+ } } + pub fn add_tt2<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, fn_def: &'ll Value, tt: FncTree) { let inputs = tt.args; let ret_tt: TypeTree = tt.ret; @@ -180,6 +183,107 @@ pub fn add_tt2<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, fn_def: } } +#[allow(unused)] +pub fn add_opt_dbg_helper<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, val: &'ll Value, attrs: AutoDiffAttrs, i: usize) { + //pub mode: DiffMode, + //pub ret_activity: DiffActivity, + //pub input_activity: Vec, + let inputs = attrs.input_activity; + let outputs = attrs.ret_activity; + let ad_name = match attrs.mode { + DiffMode::Forward => "__enzyme_fwddiff", + DiffMode::Reverse => "__enzyme_autodiff", + DiffMode::ForwardFirst => "__enzyme_fwddiff", + DiffMode::ReverseFirst => "__enzyme_autodiff", + _ => panic!("Why are we here?"), + }; + + // Assuming that our val is the fnc square, want to generate the following llvm-ir: + // declare double @__enzyme_autodiff(...) + // + // define double @dsquare(double %x) { + // entry: + // %0 = tail call double (...) @__enzyme_autodiff(double (double)* nonnull @square, double %x) + // ret double %0 + // } + + let mut final_num_args; + unsafe { + let fn_ty = llvm::LLVMRustGetFunctionType(val); + let ret_ty = llvm::LLVMGetReturnType(fn_ty); + + // First we add the declaration of the __enzyme function + let enzyme_ty = llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True); + let ad_fn = llvm::LLVMRustGetOrInsertFunction( + llmod, + ad_name.as_ptr() as *const c_char, + ad_name.len().try_into().unwrap(), + enzyme_ty, + ); + + let wrapper_name = String::from("enzyme_opt_helper_") + i.to_string().as_str(); + let wrapper_fn = llvm::LLVMRustGetOrInsertFunction( + llmod, + wrapper_name.as_ptr() as *const c_char, + wrapper_name.len().try_into().unwrap(), + fn_ty, + ); + let entry = llvm::LLVMAppendBasicBlockInContext(llcx, wrapper_fn, "entry".as_ptr() as *const c_char); + let builder = llvm::LLVMCreateBuilderInContext(llcx); + llvm::LLVMPositionBuilderAtEnd(builder, entry); + let num_args = llvm::LLVMCountParams(wrapper_fn); + let mut args = Vec::with_capacity(num_args as usize + 1); + args.push(val); + // metadata !"enzyme_const" + let enzyme_const = llvm::LLVMMDStringInContext(llcx, "enzyme_const".as_ptr() as *const c_char, 12); + let enzyme_out = llvm::LLVMMDStringInContext(llcx, "enzyme_out".as_ptr() as *const c_char, 10); + let enzyme_dup = llvm::LLVMMDStringInContext(llcx, "enzyme_dup".as_ptr() as *const c_char, 10); + let enzyme_dupnoneed = llvm::LLVMMDStringInContext(llcx, "enzyme_dupnoneed".as_ptr() as *const c_char, 16); + final_num_args = num_args * 2 + 1; + for i in 0..num_args { + let arg = llvm::LLVMGetParam(wrapper_fn, i); + let activity = inputs[i as usize]; + let (activity, duplicated): (&Value, bool) = match activity { + DiffActivity::None => panic!(), + DiffActivity::Const => (enzyme_const, false), + DiffActivity::Active => (enzyme_out, false), + DiffActivity::ActiveOnly => (enzyme_out, false), + DiffActivity::Dual => (enzyme_dup, true), + DiffActivity::DualOnly => (enzyme_dupnoneed, true), + DiffActivity::Duplicated => (enzyme_dup, true), + DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true), + DiffActivity::FakeActivitySize => (enzyme_const, false), + }; + args.push(activity); + args.push(arg); + if duplicated { + final_num_args += 1; + args.push(arg); + } + } + + // declare void @__enzyme_autodiff(...) + + // define void @enzyme_opt_helper_0(ptr %0, ptr %1) { + // call void (...) @__enzyme_autodiff(ptr @ffff, ptr %0, ptr %1) + // ret void + // } + + let call = llvm::LLVMBuildCall2(builder, enzyme_ty, ad_fn, args.as_mut_ptr(), final_num_args as usize, ad_name.as_ptr() as *const c_char); + let void_ty = llvm::LLVMVoidTypeInContext(llcx); + if llvm::LLVMTypeOf(call) != void_ty { + llvm::LLVMBuildRet(builder, call); + } else { + llvm::LLVMBuildRetVoid(builder); + } + llvm::LLVMDisposeBuilder(builder); + + let _fnc_ok = + llvm::LLVMVerifyFunction(wrapper_fn, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction); + } + +} + fn add_tt<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context,val: &'ll Value, tt: FncTree) { let inputs = tt.args; let _ret: TypeTree = tt.ret;