From 8abdb9f2336e9026e924ad8e12d727f65dd4f01e Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Thu, 28 Jul 2022 19:28:35 +0200 Subject: [PATCH] Cabs calling convention (#749) * Handle array types in TypeAnalysis * Handle array calling convention of cabs * Add tests --- enzyme/Enzyme/AdjointGenerator.h | 84 ++++++++++++++++--- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 16 +++- enzyme/test/Enzyme/ForwardMode/cabs-const.ll | 28 +++++++ enzyme/test/Enzyme/ForwardMode/cabs.ll | 31 +++++++ enzyme/test/Enzyme/ForwardMode/cabs2-const.ll | 38 +++++++++ enzyme/test/Enzyme/ForwardMode/cabs2.ll | 36 ++++++++ enzyme/test/Enzyme/ForwardModeVector/cabs.ll | 50 +++++++++++ enzyme/test/Enzyme/ForwardModeVector/cabs2.ll | 55 ++++++++++++ enzyme/test/Enzyme/ReverseMode/cabs-const.ll | 29 +++++++ enzyme/test/Enzyme/ReverseMode/cabs2-const.ll | 34 ++++++++ enzyme/test/Enzyme/ReverseMode/cabs2.ll | 36 ++++++++ 11 files changed, 424 insertions(+), 13 deletions(-) create mode 100644 enzyme/test/Enzyme/ForwardMode/cabs-const.ll create mode 100644 enzyme/test/Enzyme/ForwardMode/cabs.ll create mode 100644 enzyme/test/Enzyme/ForwardMode/cabs2-const.ll create mode 100644 enzyme/test/Enzyme/ForwardMode/cabs2.ll create mode 100644 enzyme/test/Enzyme/ForwardModeVector/cabs.ll create mode 100644 enzyme/test/Enzyme/ForwardModeVector/cabs2.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/cabs-const.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/cabs2-const.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/cabs2.ll diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 2af84e8f7d3d..80fad09b1a2f 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -9945,11 +9945,18 @@ class AdjointGenerator Value *d = Builder2.CreateCall(called, args); if (args.size() == 2) { - Value *op0 = diffe(orig->getArgOperand(0), Builder2); - - Value *op1 = diffe(orig->getArgOperand(1), Builder2); + Value *op0 = gutils->isConstantValue(orig->getArgOperand(0)) + ? nullptr + : diffe(orig->getArgOperand(0), Builder2); + Value *op1 = gutils->isConstantValue(orig->getArgOperand(1)) + ? nullptr + : diffe(orig->getArgOperand(1), Builder2); + + auto rule1 = [&](Value *op) { + return Builder2.CreateFMul(args[0], Builder2.CreateFDiv(op, d)); + }; - auto rule = [&](Value *op0, Value *op1) { + auto rule2 = [&](Value *op0, Value *op1) { Value *dif1 = Builder2.CreateFMul(args[0], Builder2.CreateFDiv(op0, d)); Value *dif2 = @@ -9957,14 +9964,46 @@ class AdjointGenerator return Builder2.CreateFAdd(dif1, dif2); }; - Value *dif = - applyChainRule(call.getType(), Builder2, rule, op0, op1); + Value *dif; + if (op0 && op1) + dif = applyChainRule(call.getType(), Builder2, rule2, op0, op1); + else if (op0) + dif = applyChainRule(call.getType(), Builder2, rule1, op0); + else if (op1) + dif = applyChainRule(call.getType(), Builder2, rule1, op1); + else + llvm_unreachable( + "trying to differentiate a constant instruction"); + setDiffe(orig, dif, Builder2); return; - } else { - llvm::errs() << *orig << "\n"; - llvm_unreachable("unknown calling convention found for cabs"); + } else if (args.size() == 1) { + if (auto AT = dyn_cast(args[0]->getType())) { + if (AT->getNumElements() == 2) { + Value *op = diffe(orig->getArgOperand(0), Builder2); + Value *args0 = Builder2.CreateExtractValue(args[0], 0); + Value *args1 = Builder2.CreateExtractValue(args[0], 1); + + auto rule = [&](Value *op) { + Value *op0 = Builder2.CreateExtractValue(op, 0); + Value *op1 = Builder2.CreateExtractValue(op, 1); + + Value *dif1 = + Builder2.CreateFMul(args0, Builder2.CreateFDiv(op0, d)); + Value *dif2 = + Builder2.CreateFMul(args1, Builder2.CreateFDiv(op1, d)); + return Builder2.CreateFAdd(dif1, dif2); + }; + + Value *dif = + applyChainRule(call.getType(), Builder2, rule, op); + setDiffe(orig, dif, Builder2); + return; + } + } } + llvm::errs() << *orig << "\n"; + llvm_unreachable("unknown calling convention found for cabs"); } case DerivativeMode::ReverseModeGradient: case DerivativeMode::ReverseModeCombined: { @@ -9998,10 +10037,31 @@ class AdjointGenerator Builder2.CreateFMul(args[i], div), Builder2, orig->getType()); return; - } else { - llvm::errs() << *orig << "\n"; - llvm_unreachable("unknown calling convention found for cabs"); + } else if (args.size() == 1) { + if (auto AT = dyn_cast(args[0]->getType())) { + if (AT->getNumElements() == 2) { + if (!gutils->isConstantValue(orig->getArgOperand(0))) { + Value *agg = UndefValue::get(args[0]->getType()); + agg = Builder2.CreateInsertValue( + agg, + Builder2.CreateFMul( + Builder2.CreateExtractValue(args[0], 0), div), + 0); + agg = Builder2.CreateInsertValue( + agg, + Builder2.CreateFMul( + Builder2.CreateExtractValue(args[0], 1), div), + 1); + + addToDiffe(orig->getArgOperand(0), agg, Builder2, + orig->getType()); + return; + } + } + } } + llvm::errs() << *orig << "\n"; + llvm_unreachable("unknown calling convention found for cabs"); } case DerivativeMode::ReverseModePrimal: { return; diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index ce78b307955f..81f0ff98b595 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -4240,7 +4240,21 @@ void TypeAnalyzer::visitCallInst(CallInst &call) { llvm::errs() << *T << " - " << call << "\n"; llvm_unreachable("Unknown type for libm"); } - + } else if (auto AT = dyn_cast(T)) { + assert(AT->getNumElements() >= 1); + if (AT->getElementType()->isFloatingPointTy()) + updateAnalysis( + call.getArgOperand(i), + TypeTree(ConcreteType(AT->getElementType()->getScalarType())) + .Only(-1), + &call); + else if (AT->getElementType()->isIntegerTy()) { + updateAnalysis(call.getArgOperand(i), + TypeTree(BaseType::Integer).Only(-1), &call); + } else { + llvm::errs() << *T << " - " << call << "\n"; + llvm_unreachable("Unknown type for libm"); + } } else { llvm::errs() << *T << " - " << call << "\n"; llvm_unreachable("Unknown type for libm"); diff --git a/enzyme/test/Enzyme/ForwardMode/cabs-const.ll b/enzyme/test/Enzyme/ForwardMode/cabs-const.ll new file mode 100644 index 000000000000..824eb7af09bf --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/cabs-const.ll @@ -0,0 +1,28 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %call = call double @cabs(double %x, double %y) + ret double %call +} + +define double @test_derivative(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, metadata !"enzyme_const", double %x, double %y, double 1.0) + ret double %0 +} + +declare double @cabs(double, double) + +; Function Attrs: nounwind +declare double @__enzyme_fwddiff(double (double, double)*, ...) + + +; CHECK: define internal double @fwddiffetester(double %x, double %y, double %"y'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast double @cabs(double %x, double %y) +; CHECK-NEXT: %1 = fdiv fast double %"y'", %0 +; CHECK-NEXT: %2 = fmul fast double %x, %1 +; CHECK-NEXT: ret double %2 +; CHECK-NEXT:} \ No newline at end of file diff --git a/enzyme/test/Enzyme/ForwardMode/cabs.ll b/enzyme/test/Enzyme/ForwardMode/cabs.ll new file mode 100644 index 000000000000..34899c8ebae7 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/cabs.ll @@ -0,0 +1,31 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %call = call double @cabs(double %x, double %y) + ret double %call +} + +define double @test_derivative(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, double %x, double 1.0, double %y, double 1.0) + ret double %0 +} + +declare double @cabs(double, double) + +; Function Attrs: nounwind +declare double @__enzyme_fwddiff(double (double, double)*, ...) + + +; CHECK: define internal double @fwddiffetester(double %x, double %"x'", double %y, double %"y'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast double @cabs(double %x, double %y) +; CHECK-NEXT: %1 = fdiv fast double %"x'", %0 +; CHECK-NEXT: %2 = fmul fast double %x, %1 +; CHECK-NEXT: %3 = fdiv fast double %"y'", %0 +; CHECK-NEXT: %4 = fmul fast double %y, %3 +; CHECK-NEXT: %5 = fadd fast double %2, %4 +; CHECK-NEXT: ret double %5 +; CHECK-NEXT: } \ No newline at end of file diff --git a/enzyme/test/Enzyme/ForwardMode/cabs2-const.ll b/enzyme/test/Enzyme/ForwardMode/cabs2-const.ll new file mode 100644 index 000000000000..1a6986b4e979 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/cabs2-const.ll @@ -0,0 +1,38 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s + +; Function Attrs: nounwind readnone willreturn +declare double @cabs([2 x double]) + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %agg0 = insertvalue [2 x double] undef, double %x, 0 + %agg1 = insertvalue [2 x double] %agg0, double %y, 1 + %call = call double @cabs([2 x double] %agg1) + ret double %call +} + +define double @test_derivative(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, metadata !"enzyme_const", double %x, double %y, double 1.0) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwddiff(double (double, double)*, ...) + + +; CHECK: define internal double @fwddiffetester(double %x, double %y, double %"y'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %agg0 = insertvalue [2 x double] undef, double %x, 0 +; CHECK-NEXT: %"agg1'ipiv" = insertvalue [2 x double] zeroinitializer, double %"y'", 1 +; CHECK-NEXT: %agg1 = insertvalue [2 x double] %agg0, double %y, 1 +; CHECK-NEXT: %0 = call fast double @cabs([2 x double] %agg1) +; CHECK-NEXT: %1 = extractvalue [2 x double] %"agg1'ipiv", 0 +; CHECK-NEXT: %2 = fdiv fast double %1, %0 +; CHECK-NEXT: %3 = fmul fast double %x, %2 +; CHECK-NEXT: %4 = fdiv fast double %"y'", %0 +; CHECK-NEXT: %5 = fmul fast double %y, %4 +; CHECK-NEXT: %6 = fadd fast double %3, %5 +; CHECK-NEXT: ret double %6 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardMode/cabs2.ll b/enzyme/test/Enzyme/ForwardMode/cabs2.ll new file mode 100644 index 000000000000..2b2d648c4aa0 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/cabs2.ll @@ -0,0 +1,36 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s + +; Function Attrs: nounwind readnone willreturn +declare double @cabs([2 x double]) + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %agg0 = insertvalue [2 x double] undef, double %x, 0 + %agg1 = insertvalue [2 x double] %agg0, double %y, 1 + %call = call double @cabs([2 x double] %agg1) + ret double %call +} + +define double @test_derivative(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, double %x, double 1.0, double %y, double 1.0) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwddiff(double (double, double)*, ...) + + +; CHECK: define internal double @fwddiffetester(double %x, double %"x'", double %y, double %"y'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %agg0 = insertvalue [2 x double] undef, double %x, 0 +; CHECK-NEXT: %agg1 = insertvalue [2 x double] %agg0, double %y, 1 +; CHECK-NEXT: %0 = call fast double @cabs([2 x double] %agg1) +; CHECK-NEXT: %1 = fdiv fast double %"x'", %0 +; CHECK-NEXT: %2 = fmul fast double %x, %1 +; CHECK-NEXT: %3 = fdiv fast double %"y'", %0 +; CHECK-NEXT: %4 = fmul fast double %y, %3 +; CHECK-NEXT: %5 = fadd fast double %2, %4 +; CHECK-NEXT: ret double %5 +; CHECK-NEXT: } \ No newline at end of file diff --git a/enzyme/test/Enzyme/ForwardModeVector/cabs.ll b/enzyme/test/Enzyme/ForwardModeVector/cabs.ll new file mode 100644 index 000000000000..0cdba08f6db1 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeVector/cabs.ll @@ -0,0 +1,50 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %call = call double @cabs(double %x, double %y) + ret double %call +} + +define [3 x double] @test_derivative(double %x, double %y) { +entry: + %0 = tail call [3 x double] (double (double, double)*, ...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, metadata !"enzyme_width", i64 3, double %x, double 1.0, double 1.3, double 2.0, double %y, double 1.0, double 0.0, double 2.0) + ret [3 x double] %0 +} + +declare double @cabs(double, double) + +; Function Attrs: nounwind +declare [3 x double] @__enzyme_fwddiff(double (double, double)*, ...) + + +; CHECK: define internal [3 x double] @fwddiffe3tester(double %x, [3 x double] %"x'", double %y, [3 x double] %"y'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast double @cabs(double %x, double %y) +; CHECK-NEXT: %1 = extractvalue [3 x double] %"x'", 0 +; CHECK-NEXT: %2 = extractvalue [3 x double] %"y'", 0 +; CHECK-NEXT: %3 = fdiv fast double %1, %0 +; CHECK-NEXT: %4 = fmul fast double %x, %3 +; CHECK-NEXT: %5 = fdiv fast double %2, %0 +; CHECK-NEXT: %6 = fmul fast double %y, %5 +; CHECK-NEXT: %7 = fadd fast double %4, %6 +; CHECK-NEXT: %8 = insertvalue [3 x double] undef, double %7, 0 +; CHECK-NEXT: %9 = extractvalue [3 x double] %"x'", 1 +; CHECK-NEXT: %10 = extractvalue [3 x double] %"y'", 1 +; CHECK-NEXT: %11 = fdiv fast double %9, %0 +; CHECK-NEXT: %12 = fmul fast double %x, %11 +; CHECK-NEXT: %13 = fdiv fast double %10, %0 +; CHECK-NEXT: %14 = fmul fast double %y, %13 +; CHECK-NEXT: %15 = fadd fast double %12, %14 +; CHECK-NEXT: %16 = insertvalue [3 x double] %8, double %15, 1 +; CHECK-NEXT: %17 = extractvalue [3 x double] %"x'", 2 +; CHECK-NEXT: %18 = extractvalue [3 x double] %"y'", 2 +; CHECK-NEXT: %19 = fdiv fast double %17, %0 +; CHECK-NEXT: %20 = fmul fast double %x, %19 +; CHECK-NEXT: %21 = fdiv fast double %18, %0 +; CHECK-NEXT: %22 = fmul fast double %y, %21 +; CHECK-NEXT: %23 = fadd fast double %20, %22 +; CHECK-NEXT: %24 = insertvalue [3 x double] %16, double %23, 2 +; CHECK-NEXT: ret [3 x double] %24 +; CHECK-NEXT: } \ No newline at end of file diff --git a/enzyme/test/Enzyme/ForwardModeVector/cabs2.ll b/enzyme/test/Enzyme/ForwardModeVector/cabs2.ll new file mode 100644 index 000000000000..a92cbd16420d --- /dev/null +++ b/enzyme/test/Enzyme/ForwardModeVector/cabs2.ll @@ -0,0 +1,55 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s + +; Function Attrs: nounwind readnone willreturn +declare double @cabs([2 x double]) #7 + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %agg0 = insertvalue [2 x double] undef, double %x, 0 + %agg1 = insertvalue [2 x double] %agg0, double %y, 1 + %call = call double @cabs([2 x double] %agg1) + ret double %call +} + +define [3 x double] @test_derivative(double %x, double %y) { +entry: + %0 = tail call [3 x double] (double (double, double)*, ...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, metadata !"enzyme_width", i64 3, double %x, double 1.0, double 1.3, double 2.0, double %y, double 1.0, double 0.0, double 2.0) + ret [3 x double] %0 +} + +; Function Attrs: nounwind +declare [3 x double] @__enzyme_fwddiff(double (double, double)*, ...) + + +; CHECK: define internal [3 x double] @fwddiffe3tester(double %x, [3 x double] %"x'", double %y, [3 x double] %"y'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = extractvalue [3 x double] %"x'", 0 +; CHECK-NEXT: %1 = extractvalue [3 x double] %"x'", 1 +; CHECK-NEXT: %2 = extractvalue [3 x double] %"x'", 2 +; CHECK-NEXT: %agg0 = insertvalue [2 x double] undef, double %x, 0 +; CHECK-NEXT: %3 = extractvalue [3 x double] %"y'", 0 +; CHECK-NEXT: %4 = extractvalue [3 x double] %"y'", 1 +; CHECK-NEXT: %5 = extractvalue [3 x double] %"y'", 2 +; CHECK-NEXT: %agg1 = insertvalue [2 x double] %agg0, double %y, 1 +; CHECK-NEXT: %6 = call fast double @cabs([2 x double] %agg1) +; CHECK-NEXT: %7 = fdiv fast double %0, %6 +; CHECK-NEXT: %8 = fmul fast double %x, %7 +; CHECK-NEXT: %9 = fdiv fast double %3, %6 +; CHECK-NEXT: %10 = fmul fast double %y, %9 +; CHECK-NEXT: %11 = fadd fast double %8, %10 +; CHECK-NEXT: %12 = insertvalue [3 x double] undef, double %11, 0 +; CHECK-NEXT: %13 = fdiv fast double %1, %6 +; CHECK-NEXT: %14 = fmul fast double %x, %13 +; CHECK-NEXT: %15 = fdiv fast double %4, %6 +; CHECK-NEXT: %16 = fmul fast double %y, %15 +; CHECK-NEXT: %17 = fadd fast double %14, %16 +; CHECK-NEXT: %18 = insertvalue [3 x double] %12, double %17, 1 +; CHECK-NEXT: %19 = fdiv fast double %2, %6 +; CHECK-NEXT: %20 = fmul fast double %x, %19 +; CHECK-NEXT: %21 = fdiv fast double %5, %6 +; CHECK-NEXT: %22 = fmul fast double %y, %21 +; CHECK-NEXT: %23 = fadd fast double %20, %22 +; CHECK-NEXT: %24 = insertvalue [3 x double] %18, double %23, 2 +; CHECK-NEXT: ret [3 x double] %24 +; CHECK-NEXT: } \ No newline at end of file diff --git a/enzyme/test/Enzyme/ReverseMode/cabs-const.ll b/enzyme/test/Enzyme/ReverseMode/cabs-const.ll new file mode 100644 index 000000000000..db8a855dfd5d --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/cabs-const.ll @@ -0,0 +1,29 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %call = call double @cabs(double %x, double %y) + ret double %call +} + +define double @test_derivative(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_autodiff(double (double, double)* nonnull @tester, metadata !"enzyme_const", double %x, double %y) + ret double %0 +} + +declare double @cabs(double, double) + +; Function Attrs: nounwind +declare double @__enzyme_autodiff(double (double, double)*, ...) + + +; CHECK: define internal { double } @diffetester(double %x, double %y, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast double @cabs(double %x, double %y) +; CHECK-NEXT: %1 = fdiv fast double %differeturn, %0 +; CHECK-NEXT: %2 = fmul fast double %y, %1 +; CHECK-NEXT: %3 = insertvalue { double } undef, double %2, 0 +; CHECK-NEXT: ret { double } %3 +; CHECK-NEXT: } \ No newline at end of file diff --git a/enzyme/test/Enzyme/ReverseMode/cabs2-const.ll b/enzyme/test/Enzyme/ReverseMode/cabs2-const.ll new file mode 100644 index 000000000000..96f7fabd8386 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/cabs2-const.ll @@ -0,0 +1,34 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s + +; Function Attrs: nounwind readnone willreturn +declare dso_local double @cabs([2 x double]) #7 + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %agg0 = insertvalue [2 x double] undef, double %x, 0 + %agg1 = insertvalue [2 x double] %agg0, double %y, 1 + %call = call double @cabs([2 x double] %agg1) + ret double %call +} + +define double @test_derivative(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_autodiff(double (double, double)* nonnull @tester, metadata !"enzyme_const", double %x, double %y) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_autodiff(double (double, double)*, ...) + + +; CHECK: define internal { double } @diffetester(double %x, double %y, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %agg0 = insertvalue [2 x double] undef, double %x, 0 +; CHECK-NEXT: %agg1 = insertvalue [2 x double] %agg0, double %y, 1 +; CHECK-NEXT: %0 = call fast double @cabs([2 x double] %agg1) +; CHECK-NEXT: %1 = fdiv fast double %differeturn, %0 +; CHECK-NEXT: %2 = fmul fast double %y, %1 +; CHECK-NEXT: %3 = insertvalue { double } undef, double %2, 0 +; CHECK-NEXT: ret { double } %3 +; CHECK-NEXT: } \ No newline at end of file diff --git a/enzyme/test/Enzyme/ReverseMode/cabs2.ll b/enzyme/test/Enzyme/ReverseMode/cabs2.ll new file mode 100644 index 000000000000..c23cd7df056b --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/cabs2.ll @@ -0,0 +1,36 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s + +; Function Attrs: nounwind readnone willreturn +declare double @cabs([2 x double]) #7 + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %agg0 = insertvalue [2 x double] undef, double %x, 0 + %agg1 = insertvalue [2 x double] %agg0, double %y, 1 + %call = call double @cabs([2 x double] %agg1) + ret double %call +} + +define double @test_derivative(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_autodiff(double (double, double)* nonnull @tester, double %x, double %y) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_autodiff(double (double, double)*, ...) + + +; CHECK: define internal { double, double } @diffetester(double %x, double %y, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %agg0 = insertvalue [2 x double] undef, double %x, 0 +; CHECK-NEXT: %agg1 = insertvalue [2 x double] %agg0, double %y, 1 +; CHECK-NEXT: %0 = call fast double @cabs([2 x double] %agg1) +; CHECK-NEXT: %1 = fdiv fast double %differeturn, %0 +; CHECK-NEXT: %2 = fmul fast double %x, %1 +; CHECK-NEXT: %3 = fmul fast double %y, %1 +; CHECK-NEXT: %4 = insertvalue { double, double } undef, double %2, 0 +; CHECK-NEXT: %5 = insertvalue { double, double } %4, double %3, 1 +; CHECK-NEXT: ret { double, double } %5 +; CHECK-NEXT: }