Skip to content

Commit

Permalink
Extract functions handling finding of the function to differentiate a…
Browse files Browse the repository at this point in the history
…nd enzyme_width (rust-lang#661)
  • Loading branch information
tgymnich authored May 27, 2022
1 parent d222f5e commit 302711a
Showing 1 changed file with 67 additions and 44 deletions.
111 changes: 67 additions & 44 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,15 +417,9 @@ class Enzyme : public ModulePass {
// AU.addRequiredID(llvm::LoopSimplifyID);//<LoopSimplifyWrapperPass>();
}

/// Return whether successful
bool HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, DerivativeMode mode,
bool sizeOnly) {

Optional<Function *> parseFunctionParameter(CallInst *CI) {
Value *fn = CI->getArgOperand(0);

std::vector<DIFFE_TYPE> constants;
SmallVector<Value *, 2> args;

// determine function to differentiate
if (CI->hasStructRetAttr()) {
fn = CI->getArgOperand(1);
Expand All @@ -444,25 +438,21 @@ class Enzyme : public ModulePass {
EmitFailure("NoFunctionToDifferentiate", CI->getDebugLoc(), CI,
"failed to find fn to differentiate", *CI, " - found - ",
*fn);
return false;
return None;
}
if (cast<Function>(fn)->empty()) {
EmitFailure("EmptyFunctionToDifferentiate", CI->getDebugLoc(), CI,
"failed to find fn to differentiate", *CI, " - found - ",
*fn);
return false;
return None;
}
auto FT = cast<Function>(fn)->getFunctionType();
assert(fn);

IRBuilder<> Builder(CI);
unsigned truei = 0;
return cast<Function>(fn);
}

Optional<unsigned> parseWidthParameter(CallInst *CI) {
unsigned width = 1;
std::map<unsigned, Value *> batchOffset;
bool returnUsed = !cast<Function>(fn)->getReturnType()->isVoidTy() &&
!cast<Function>(fn)->getReturnType()->isEmptyTy();

// determine width
#if LLVM_VERSION_MAJOR >= 14
for (auto [i, found] = std::tuple{0u, false}; i < CI->arg_size(); ++i)
#else
Expand All @@ -478,7 +468,7 @@ class Enzyme : public ModulePass {
EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI,
"vector width declared more than once",
*CI->getArgOperand(i), " in", *CI);
return false;
return None;
}

#if LLVM_VERSION_MAJOR >= 14
Expand All @@ -490,7 +480,7 @@ class Enzyme : public ModulePass {
EmitFailure("MissingVectorWidth", CI->getDebugLoc(), CI,
"constant integer followong enzyme_width is missing",
*CI->getArgOperand(i), " in", *CI);
return false;
return None;
}

Value *width_arg = CI->getArgOperand(i + 1);
Expand All @@ -501,23 +491,59 @@ class Enzyme : public ModulePass {
EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI,
"enzyme_width must be a constant integer",
*CI->getArgOperand(i), " in", *CI);
return false;
return None;
}

if (!found) {
EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI,
"illegal enzyme vector argument width ",
*CI->getArgOperand(i), " in", *CI);
return false;
return None;
}
}
}
}
return width;
}

/// Return whether successful
bool HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, DerivativeMode mode,
bool sizeOnly) {

// determine function to differentiate
Function *fn;
auto parsedFunction = parseFunctionParameter(CI);
if (parsedFunction.hasValue()) {
fn = parsedFunction.getValue();
} else {
return false;
}

auto FT = fn->getFunctionType();
assert(fn);

IRBuilder<> Builder(CI);
unsigned truei = 0;
unsigned width = 1;
std::map<unsigned, Value *> batchOffset;
bool returnUsed =
!fn->getReturnType()->isVoidTy() && !fn->getReturnType()->isEmptyTy();

// find and handle enzyme_width
auto parsedWidth = parseWidthParameter(CI);
if (parsedWidth.hasValue()) {
width = parsedWidth.getValue();
} else {
return false;
}

std::vector<DIFFE_TYPE> constants;
SmallVector<Value *, 2> args;

// handle different argument order for struct return.
bool sret = CI->hasStructRetAttr() ||
cast<Function>(fn)->hasParamAttribute(0, Attribute::StructRet);
if (cast<Function>(fn)->hasParamAttribute(0, Attribute::StructRet)) {
fn->hasParamAttribute(0, Attribute::StructRet);
if (fn->hasParamAttribute(0, Attribute::StructRet)) {
Type *fnsrety = cast<PointerType>(FT->getParamType(0));

truei = 1;
Expand Down Expand Up @@ -594,7 +620,7 @@ class Enzyme : public ModulePass {

bool freeMemory = true;

DIFFE_TYPE retType = whatType(cast<Function>(fn)->getReturnType(), mode);
DIFFE_TYPE retType = whatType(fn->getReturnType(), mode);

bool differentialReturn = (mode == DerivativeMode::ReverseModeCombined ||
mode == DerivativeMode::ReverseModeGradient) &&
Expand Down Expand Up @@ -627,7 +653,7 @@ class Enzyme : public ModulePass {
differet = Builder.CreateLoad(differet);
#endif
}
assert(differet->getType() == cast<Function>(fn)->getReturnType());
assert(differet->getType() == fn->getReturnType());
continue;
} else if (tape == nullptr) {
tape = res;
Expand Down Expand Up @@ -869,7 +895,7 @@ class Enzyme : public ModulePass {
}

std::map<Argument *, bool> volatile_args;
FnTypeInfo type_args(cast<Function>(fn));
FnTypeInfo type_args(fn);
for (auto &a : type_args.Function->args()) {
volatile_args[&a] = !(mode == DerivativeMode::ReverseModeCombined);
TypeTree dt;
Expand Down Expand Up @@ -904,18 +930,18 @@ class Enzyme : public ModulePass {
switch (mode) {
case DerivativeMode::ForwardMode:
newFunc = Logic.CreateForwardDiff(
cast<Function>(fn), retType, constants, TA,
fn, retType, constants, TA,
/*should return*/ false, mode, freeMemory, width,
/*addedType*/ nullptr, type_args, volatile_args,
/*augmented*/ nullptr);
break;
case DerivativeMode::ForwardModeSplit: {
bool forceAnonymousTape = !sizeOnly && allocatedTapeSize == -1;
aug = &Logic.CreateAugmentedPrimal(
cast<Function>(fn), retType, constants, TA,
fn, retType, constants, TA,
/*returnUsed*/ false, /*shadowReturnUsed*/ false, type_args,
volatile_args, forceAnonymousTape, width, /*atomicAdd*/ AtomicAdd);
auto &DL = cast<Function>(fn)->getParent()->getDataLayout();
auto &DL = fn->getParent()->getDataLayout();
if (!forceAnonymousTape) {
assert(!aug->tapeType);
if (aug->returns.find(AugmentedStruct::Tape) != aug->returns.end()) {
Expand Down Expand Up @@ -948,15 +974,15 @@ class Enzyme : public ModulePass {
tapeType = PointerType::getInt8PtrTy(fn->getContext());
}
newFunc = Logic.CreateForwardDiff(
cast<Function>(fn), retType, constants, TA,
fn, retType, constants, TA,
/*should return*/ false, mode, freeMemory, width,
/*addedType*/ tapeType, type_args, volatile_args, aug);
break;
}
case DerivativeMode::ReverseModeCombined:
assert(freeMemory);
newFunc = Logic.CreatePrimalAndGradient(
(ReverseCacheKey){.todiff = cast<Function>(fn),
(ReverseCacheKey){.todiff = fn,
.retType = retType,
.constant_args = constants,
.uncacheable_args = volatile_args,
Expand All @@ -976,10 +1002,10 @@ class Enzyme : public ModulePass {
bool shadowReturnUsed = returnUsed && (retType == DIFFE_TYPE::DUP_ARG ||
retType == DIFFE_TYPE::DUP_NONEED);
aug = &Logic.CreateAugmentedPrimal(
cast<Function>(fn), retType, constants, TA, returnUsed,
shadowReturnUsed, type_args, volatile_args, forceAnonymousTape, width,
fn, retType, constants, TA, returnUsed, shadowReturnUsed, type_args,
volatile_args, forceAnonymousTape, width,
/*atomicAdd*/ AtomicAdd);
auto &DL = cast<Function>(fn)->getParent()->getDataLayout();
auto &DL = fn->getParent()->getDataLayout();
if (!forceAnonymousTape) {
assert(!aug->tapeType);
if (aug->returns.find(AugmentedStruct::Tape) != aug->returns.end()) {
Expand Down Expand Up @@ -1015,7 +1041,7 @@ class Enzyme : public ModulePass {
newFunc = aug->fn;
else
newFunc = Logic.CreatePrimalAndGradient(
(ReverseCacheKey){.todiff = cast<Function>(fn),
(ReverseCacheKey){.todiff = fn,
.retType = retType,
.constant_args = constants,
.uncacheable_args = volatile_args,
Expand All @@ -1037,19 +1063,16 @@ class Enzyme : public ModulePass {
if (differentialReturn) {
if (differet)
args.push_back(differet);
else if (cast<Function>(fn)->getReturnType()->isFPOrFPVectorTy()) {
Constant *seed =
ConstantFP::get(cast<Function>(fn)->getReturnType(), 1.0);
else if (fn->getReturnType()->isFPOrFPVectorTy()) {
Constant *seed = ConstantFP::get(fn->getReturnType(), 1.0);
if (width == 1) {
args.push_back(seed);
} else {
ArrayType *arrayType =
ArrayType::get(cast<Function>(fn)->getReturnType(), width);
ArrayType *arrayType = ArrayType::get(fn->getReturnType(), width);
args.push_back(ConstantArray::get(
arrayType, SmallVector<Constant *, 3>(width, seed)));
}
} else if (auto ST = dyn_cast<StructType>(
cast<Function>(fn)->getReturnType())) {
} else if (auto ST = dyn_cast<StructType>(fn->getReturnType())) {
SmallVector<Constant *, 2> csts;
for (auto e : ST->elements()) {
csts.push_back(ConstantFP::get(e, 1.0));
Expand All @@ -1061,7 +1084,7 @@ class Enzyme : public ModulePass {
if ((mode == DerivativeMode::ReverseModeGradient ||
mode == DerivativeMode::ForwardModeSplit) &&
tape && tapeType) {
auto &DL = cast<Function>(fn)->getParent()->getDataLayout();
auto &DL = fn->getParent()->getDataLayout();
if (tapeIsPointer) {
tape = Builder.CreateBitCast(
tape, PointerType::get(
Expand Down Expand Up @@ -1212,7 +1235,7 @@ class Enzyme : public ModulePass {
}
CI->replaceAllUsesWith(newStruct);
} else if (mode == DerivativeMode::ReverseModePrimal) {
auto &DL = cast<Function>(fn)->getParent()->getDataLayout();
auto &DL = fn->getParent()->getDataLayout();
if (DL.getTypeSizeInBits(CI->getType()) >=
DL.getTypeSizeInBits(diffret->getType())) {
IRBuilder<> EB(
Expand Down

0 comments on commit 302711a

Please sign in to comment.