Skip to content

Commit

Permalink
Add log1p (rust-lang#667)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored May 28, 2022
1 parent 302711a commit e290aab
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 0 deletions.
54 changes: 54 additions & 0 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -8562,6 +8562,60 @@ class AdjointGenerator
llvm_unreachable("unhandled openmp function");
}

if (funcName == "log1p" || funcName == "log1pf" || funcName == "log1pl") {
if (gutils->knownRecomputeHeuristic.find(orig) !=
gutils->knownRecomputeHeuristic.end()) {
if (!gutils->knownRecomputeHeuristic[orig]) {
gutils->cacheForReverse(BuilderZ, newCall,
getIndex(orig, CacheType::Self));
}
}
eraseIfUnused(*orig);
if (gutils->isConstantInstruction(orig))
return;

switch (Mode) {
case DerivativeMode::ForwardModeSplit:
case DerivativeMode::ForwardMode: {
IRBuilder<> Builder2(&call);
getForwardBuilder(Builder2);
Value *x = gutils->getNewFromOriginal(orig->getArgOperand(0));
Value *onePx =
Builder2.CreateFAdd(ConstantFP::get(x->getType(), 1.0), x);

Value *op = diffe(orig->getArgOperand(0), Builder2);

auto rule = [&](Value *op) { return Builder2.CreateFDiv(op, onePx); };
Value *dif0 = applyChainRule(call.getType(), Builder2, rule, op);
setDiffe(orig, dif0, Builder2);
return;
}
case DerivativeMode::ReverseModeGradient:
case DerivativeMode::ReverseModeCombined: {
IRBuilder<> Builder2(call.getParent());
getReverseBuilder(Builder2);
Value *x = lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)),
Builder2);
Value *onePx =
Builder2.CreateFAdd(ConstantFP::get(x->getType(), 1.0), x);

auto rule = [&](Value *dorig) {
return Builder2.CreateFDiv(dorig, onePx);
};

Value *dorig = diffe(orig, Builder2);
Value *dif0 = applyChainRule(orig->getArgOperand(0)->getType(),
Builder2, rule, dorig);

addToDiffe(orig->getArgOperand(0), dif0, Builder2, x->getType());
return;
}
case DerivativeMode::ReverseModePrimal: {
return;
}
}
}

if (funcName == "asin" || funcName == "asinf" || funcName == "asinl") {
if (gutils->knownRecomputeHeuristic.find(orig) !=
gutils->knownRecomputeHeuristic.end()) {
Expand Down
27 changes: 27 additions & 0 deletions enzyme/test/Enzyme/ForwardMode/log1p.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -gvn -simplifycfg -instcombine -S | FileCheck %s

; Function Attrs: nounwind readnone uwtable
define double @tester(double %x) {
entry:
%0 = tail call double @log1p(double %x)
ret double %0
}

define double @test_derivative(double %x) {
entry:
%0 = tail call double (double (double)*, ...) @__enzyme_fwddiff(double (double)* nonnull @tester, double %x, double 1.0)
ret double %0
}

; Function Attrs: nounwind readnone speculatable
declare double @log1p(double)

; Function Attrs: nounwind
declare double @__enzyme_fwddiff(double (double)*, ...)

; CHECK: define internal double @fwddiffetester(double %x, double %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = fadd fast double %x, 1.000000e+00
; CHECK-NEXT: %1 = fdiv fast double %"x'", %0
; CHECK-NEXT: ret double %1
; CHECK-NEXT: }
37 changes: 37 additions & 0 deletions enzyme/test/Enzyme/ForwardModeVector/log1p.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -gvn -simplifycfg -instcombine -S | FileCheck %s

%struct.Gradients = type { double, double, double }

; Function Attrs: nounwind
declare %struct.Gradients @__enzyme_fwddiff(double (double)*, ...)

; Function Attrs: nounwind readnone uwtable
define double @tester(double %x) {
entry:
%0 = tail call double @log1p(double %x)
ret double %0
}

define %struct.Gradients @test_derivative(double %x) {
entry:
%0 = tail call %struct.Gradients (double (double)*, ...) @__enzyme_fwddiff(double (double)* nonnull @tester, metadata !"enzyme_width", i64 3, double %x, double 1.0, double 2.0, double 3.0)
ret %struct.Gradients %0
}

; Function Attrs: nounwind readnone speculatable
declare double @log1p(double)

; CHECK: define internal [3 x double] @fwddiffe3tester(double %x, [3 x double] %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = fadd fast double %x, 1.000000e+00
; CHECK-NEXT: %1 = extractvalue [3 x double] %"x'", 0
; CHECK-NEXT: %2 = fdiv fast double %1, %0
; CHECK-NEXT: %3 = insertvalue [3 x double] undef, double %2, 0
; CHECK-NEXT: %4 = extractvalue [3 x double] %"x'", 1
; CHECK-NEXT: %5 = fdiv fast double %4, %0
; CHECK-NEXT: %6 = insertvalue [3 x double] %3, double %5, 1
; CHECK-NEXT: %7 = extractvalue [3 x double] %"x'", 2
; CHECK-NEXT: %8 = fdiv fast double %7, %0
; CHECK-NEXT: %9 = insertvalue [3 x double] %6, double %8, 2
; CHECK-NEXT: ret [3 x double] %9
; CHECK-NEXT: }
28 changes: 28 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/log1p.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -gvn -instcombine -simplifycfg -S | FileCheck %s

; Function Attrs: nounwind readnone uwtable
define double @tester(double %x) {
entry:
%0 = tail call fast double @log1p(double %x)
ret double %0
}

define double @test_derivative(double %x) {
entry:
%0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x)
ret double %0
}

; Function Attrs: nounwind readnone speculatable
declare double @log1p(double)

; Function Attrs: nounwind
declare double @__enzyme_autodiff(double (double)*, ...)

; CHECK: define internal { double } @diffetester(double %x, double %differeturn)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = fadd fast double %x, 1.000000e+00
; CHECK-NEXT: %1 = fdiv fast double %differeturn, %0
; CHECK-NEXT: %2 = insertvalue { double } undef, double %1, 0
; CHECK-NEXT: ret { double } %2
; CHECK-NEXT: }

0 comments on commit e290aab

Please sign in to comment.