Skip to content

Commit 44ee51d

Browse files
authored
Rollup merge of #142444 - KMJ-007:autodiff-codegen-test, r=ZuseZ4
adding run-make test to autodiff r? `@ZuseZ4`
2 parents fff53f1 + 80224c3 commit 44ee51d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+1315
-0
lines changed

compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,10 @@ struct LLVMRustSanitizerOptions {
700700
#ifdef ENZYME
701701
extern "C" void registerEnzymeAndPassPipeline(llvm::PassBuilder &PB,
702702
/* augmentPassBuilder */ bool);
703+
704+
extern "C" {
705+
extern llvm::cl::opt<std::string> EnzymeFunctionToAnalyze;
706+
}
703707
#endif
704708

705709
extern "C" LLVMRustResult LLVMRustOptimize(
@@ -1069,6 +1073,15 @@ extern "C" LLVMRustResult LLVMRustOptimize(
10691073
return LLVMRustResult::Failure;
10701074
}
10711075

1076+
// Check if PrintTAFn was used and add type analysis pass if needed
1077+
if (!EnzymeFunctionToAnalyze.empty()) {
1078+
if (auto Err = PB.parsePassPipeline(MPM, "print-type-analysis")) {
1079+
std::string ErrMsg = toString(std::move(Err));
1080+
LLVMRustSetLastError(ErrMsg.c_str());
1081+
return LLVMRustResult::Failure;
1082+
}
1083+
}
1084+
10721085
if (PrintAfterEnzyme) {
10731086
// Handle the Rust flag `-Zautodiff=PrintModAfter`.
10741087
std::string Banner = "Module after EnzymeNewPM";
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Autodiff Type-Trees Type Analysis Tests
2+
3+
This directory contains run-make tests for the autodiff type-trees type analysis functionality. These tests verify that the autodiff compiler correctly analyzes and tracks type information for different Rust types during automatic differentiation.
4+
5+
## What These Tests Do
6+
7+
Each test compiles a simple Rust function with the `#[autodiff_reverse]` attribute and verifies that the compiler:
8+
9+
1. **Correctly identifies type information** in the generated LLVM IR
10+
2. **Tracks type annotations** for variables and operations
11+
3. **Preserves type context** through the autodiff transformation process
12+
13+
The tests capture the stdout from the autodiff compiler (which contains type analysis information) and verify it matches expected patterns using FileCheck.
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer}:{}
2+
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float, [-1,8]:Float@float}
3+
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
4+
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
5+
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 4, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float}
6+
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
7+
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
8+
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
9+
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 8, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
10+
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
11+
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
12+
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
13+
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
14+
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float, [-1,8]:Float@float}:{}
15+
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float, [-1,8]:Float@float}
16+
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
17+
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
18+
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 4, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float}
19+
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
20+
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
21+
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
22+
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 8, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
23+
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
24+
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
25+
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
26+
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#![feature(autodiff)]
2+
3+
use std::autodiff::autodiff_reverse;
4+
5+
#[autodiff_reverse(d_square, Duplicated, Active)]
6+
#[no_mangle]
7+
fn callee(x: &[f32; 3]) -> f32 {
8+
x[0] * x[0] + x[1] * x[1] + x[2] * x[2]
9+
}
10+
11+
fn main() {
12+
let x = [1.0f32, 2.0, 3.0];
13+
let mut df_dx = [0.0f32; 3];
14+
let out = callee(&x);
15+
let out_ = d_square(&x, &mut df_dx, 1.0);
16+
assert_eq!(out, out_);
17+
assert_eq!(2.0, df_dx[0]);
18+
assert_eq!(4.0, df_dx[1]);
19+
assert_eq!(6.0, df_dx[2]);
20+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//@ needs-enzyme
2+
//@ ignore-cross-compile
3+
4+
use std::fs;
5+
6+
use run_make_support::{llvm_filecheck, rfs, rustc};
7+
8+
fn main() {
9+
// Compile the Rust file with the required flags, capturing both stdout and stderr
10+
let output = rustc()
11+
.input("array.rs")
12+
.arg("-Zautodiff=Enable,PrintTAFn=callee")
13+
.arg("-Zautodiff=NoPostopt")
14+
.opt_level("3")
15+
.arg("-Clto=fat")
16+
.arg("-g")
17+
.run();
18+
19+
let stdout = output.stdout_utf8();
20+
let stderr = output.stderr_utf8();
21+
22+
// Write the outputs to files
23+
rfs::write("array.stdout", stdout);
24+
rfs::write("array.stderr", stderr);
25+
26+
// Run FileCheck on the stdout using the check file
27+
llvm_filecheck().patterns("array.check").stdin_buf(rfs::read("array.stdout")).run();
28+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer}:{}
2+
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
3+
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
4+
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float}
5+
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
6+
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float}
7+
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
8+
// CHECK-DAG: %{{[0-9]+}} = phi float [ 0.000000e+00, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float}
9+
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer}
10+
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer}
11+
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw [2 x [2 x float]], ptr %{{[0-9]+}}, i64 %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
12+
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
13+
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float}
14+
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
15+
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float}
16+
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer}
17+
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer}
18+
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw [2 x float], ptr %{{[0-9]+}}, i64 %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
19+
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
20+
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float}
21+
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer}
22+
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer}
23+
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw float, ptr %{{[0-9]+}}, i64 %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
24+
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
25+
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
26+
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
27+
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
28+
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer, [-1,0]:Float@float}:{}
29+
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
30+
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
31+
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float}
32+
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
33+
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float}
34+
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
35+
// CHECK-DAG: %{{[0-9]+}} = phi float [ 0.000000e+00, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float}
36+
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer}
37+
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer}
38+
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw [2 x [2 x float]], ptr %{{[0-9]+}}, i64 %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
39+
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
40+
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], !dbg !{{[0-9]+}}: {[-1]:Float@float}
41+
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
42+
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float}
43+
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer}
44+
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer}
45+
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw [2 x float], ptr %{{[0-9]+}}, i64 %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
46+
// CHECK-DAG: br label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
47+
// CHECK-DAG: %{{[0-9]+}} = phi float [ %{{[0-9]+}}, %{{[0-9]+}} ], [ %{{[0-9]+}}, %{{[0-9]+}} ]: {[-1]:Float@float}
48+
// CHECK-DAG: %{{[0-9]+}} = phi i1 [ true, %{{[0-9]+}} ], [ false, %{{[0-9]+}} ]: {[-1]:Integer}
49+
// CHECK-DAG: %{{[0-9]+}} = phi i64 [ 0, %{{[0-9]+}} ], [ 1, %{{[0-9]+}} ]: {[-1]:Integer}
50+
// CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw float, ptr %{{[0-9]+}}, i64 %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
51+
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
52+
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
53+
// CHECK-DAG: %{{[0-9]+}} = fadd float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
54+
// CHECK-DAG: br i1 %{{[0-9]+}}, label %{{[0-9]+}}, label %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#![feature(autodiff)]
2+
3+
use std::autodiff::autodiff_reverse;
4+
5+
#[autodiff_reverse(d_square, Duplicated, Active)]
6+
#[no_mangle]
7+
fn callee(x: &[[[f32; 2]; 2]; 2]) -> f32 {
8+
let mut sum = 0.0;
9+
for i in 0..2 {
10+
for j in 0..2 {
11+
for k in 0..2 {
12+
sum += x[i][j][k] * x[i][j][k];
13+
}
14+
}
15+
}
16+
sum
17+
}
18+
19+
fn main() {
20+
let x = [[[1.0f32, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]];
21+
let mut df_dx = [[[0.0f32; 2]; 2]; 2];
22+
let out = callee(&x);
23+
let out_ = d_square(&x, &mut df_dx, 1.0);
24+
assert_eq!(out, out_);
25+
for i in 0..2 {
26+
for j in 0..2 {
27+
for k in 0..2 {
28+
assert_eq!(df_dx[i][j][k], 2.0 * x[i][j][k]);
29+
}
30+
}
31+
}
32+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//@ needs-enzyme
2+
//@ ignore-cross-compile
3+
4+
use std::fs;
5+
6+
use run_make_support::{llvm_filecheck, rfs, rustc};
7+
8+
fn main() {
9+
// Compile the Rust file with the required flags, capturing both stdout and stderr
10+
let output = rustc()
11+
.input("array3d.rs")
12+
.arg("-Zautodiff=Enable,PrintTAFn=callee")
13+
.arg("-Zautodiff=NoPostopt")
14+
.opt_level("3")
15+
.arg("-Clto=fat")
16+
.arg("-g")
17+
.run();
18+
19+
let stdout = output.stdout_utf8();
20+
let stderr = output.stderr_utf8();
21+
22+
// Write the outputs to files
23+
rfs::write("array3d.stdout", stdout);
24+
rfs::write("array3d.stderr", stderr);
25+
26+
// Run FileCheck on the stdout using the check file
27+
llvm_filecheck().patterns("array3d.check").stdin_buf(rfs::read("array3d.stdout")).run();
28+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer}:{}
2+
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,0]:Float@float}
3+
// CHECK-DAG: %{{[0-9]+}} = load ptr, ptr %{{[0-9]+}}, align 8, !dbg !{{[0-9]+}}, !nonnull !{{[0-9]+}}, !align !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
4+
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
5+
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
6+
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
7+
// CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer, [-1,0]:Pointer, [-1,0,0]:Float@float}:{}
8+
// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,0]:Float@float}
9+
// CHECK-DAG: %{{[0-9]+}} = load ptr, ptr %{{[0-9]+}}, align 8, !dbg !{{[0-9]+}}, !nonnull !{{[0-9]+}}, !align !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@float}
10+
// CHECK-DAG: %{{[0-9]+}} = load float, ptr %{{[0-9]+}}, align 4, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {[-1]:Float@float}
11+
// CHECK-DAG: %{{[0-9]+}} = fmul float %{{[0-9]+}}, %{{[0-9]+}}, !dbg !{{[0-9]+}}: {[-1]:Float@float}
12+
// CHECK-DAG: ret float %{{[0-9]+}}, !dbg !{{[0-9]+}}: {}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#![feature(autodiff)]
2+
3+
use std::autodiff::autodiff_reverse;
4+
5+
#[autodiff_reverse(d_square, Duplicated, Active)]
6+
#[no_mangle]
7+
fn callee(x: &Box<f32>) -> f32 {
8+
**x * **x
9+
}
10+
11+
fn main() {
12+
let x = Box::new(7.0f32);
13+
let mut df_dx = Box::new(0.0f32);
14+
let out = callee(&x);
15+
let out_ = d_square(&x, &mut df_dx, 1.0);
16+
assert_eq!(out, out_);
17+
assert_eq!(14.0, *df_dx);
18+
}

0 commit comments

Comments
 (0)