diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 8b58808cb8dee..071952eca7549 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -417,15 +417,9 @@ class Enzyme : public ModulePass { // AU.addRequiredID(llvm::LoopSimplifyID);//(); } - /// Return whether successful - bool HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, DerivativeMode mode, - bool sizeOnly) { - + Optional parseFunctionParameter(CallInst *CI) { Value *fn = CI->getArgOperand(0); - std::vector constants; - SmallVector args; - // determine function to differentiate if (CI->hasStructRetAttr()) { fn = CI->getArgOperand(1); @@ -444,25 +438,21 @@ class Enzyme : public ModulePass { EmitFailure("NoFunctionToDifferentiate", CI->getDebugLoc(), CI, "failed to find fn to differentiate", *CI, " - found - ", *fn); - return false; + return None; } if (cast(fn)->empty()) { EmitFailure("EmptyFunctionToDifferentiate", CI->getDebugLoc(), CI, "failed to find fn to differentiate", *CI, " - found - ", *fn); - return false; + return None; } - auto FT = cast(fn)->getFunctionType(); - assert(fn); - IRBuilder<> Builder(CI); - unsigned truei = 0; + return cast(fn); + } + + Optional parseWidthParameter(CallInst *CI) { unsigned width = 1; - std::map batchOffset; - bool returnUsed = !cast(fn)->getReturnType()->isVoidTy() && - !cast(fn)->getReturnType()->isEmptyTy(); - // determine width #if LLVM_VERSION_MAJOR >= 14 for (auto [i, found] = std::tuple{0u, false}; i < CI->arg_size(); ++i) #else @@ -478,7 +468,7 @@ class Enzyme : public ModulePass { EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI, "vector width declared more than once", *CI->getArgOperand(i), " in", *CI); - return false; + return None; } #if LLVM_VERSION_MAJOR >= 14 @@ -490,7 +480,7 @@ class Enzyme : public ModulePass { EmitFailure("MissingVectorWidth", CI->getDebugLoc(), CI, "constant integer followong enzyme_width is missing", *CI->getArgOperand(i), " in", *CI); - return false; + return None; } Value *width_arg = CI->getArgOperand(i + 1); @@ -501,23 +491,59 @@ class Enzyme : public ModulePass { EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI, "enzyme_width must be a constant integer", *CI->getArgOperand(i), " in", *CI); - return false; + return None; } if (!found) { EmitFailure("IllegalVectorWidth", CI->getDebugLoc(), CI, "illegal enzyme vector argument width ", *CI->getArgOperand(i), " in", *CI); - return false; + return None; } } } } + return width; + } + + /// Return whether successful + bool HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, DerivativeMode mode, + bool sizeOnly) { + + // determine function to differentiate + Function *fn; + auto parsedFunction = parseFunctionParameter(CI); + if (parsedFunction.hasValue()) { + fn = parsedFunction.getValue(); + } else { + return false; + } + + auto FT = fn->getFunctionType(); + assert(fn); + + IRBuilder<> Builder(CI); + unsigned truei = 0; + unsigned width = 1; + std::map batchOffset; + bool returnUsed = + !fn->getReturnType()->isVoidTy() && !fn->getReturnType()->isEmptyTy(); + + // find and handle enzyme_width + auto parsedWidth = parseWidthParameter(CI); + if (parsedWidth.hasValue()) { + width = parsedWidth.getValue(); + } else { + return false; + } + + std::vector constants; + SmallVector args; // handle different argument order for struct return. bool sret = CI->hasStructRetAttr() || - cast(fn)->hasParamAttribute(0, Attribute::StructRet); - if (cast(fn)->hasParamAttribute(0, Attribute::StructRet)) { + fn->hasParamAttribute(0, Attribute::StructRet); + if (fn->hasParamAttribute(0, Attribute::StructRet)) { Type *fnsrety = cast(FT->getParamType(0)); truei = 1; @@ -594,7 +620,7 @@ class Enzyme : public ModulePass { bool freeMemory = true; - DIFFE_TYPE retType = whatType(cast(fn)->getReturnType(), mode); + DIFFE_TYPE retType = whatType(fn->getReturnType(), mode); bool differentialReturn = (mode == DerivativeMode::ReverseModeCombined || mode == DerivativeMode::ReverseModeGradient) && @@ -627,7 +653,7 @@ class Enzyme : public ModulePass { differet = Builder.CreateLoad(differet); #endif } - assert(differet->getType() == cast(fn)->getReturnType()); + assert(differet->getType() == fn->getReturnType()); continue; } else if (tape == nullptr) { tape = res; @@ -869,7 +895,7 @@ class Enzyme : public ModulePass { } std::map volatile_args; - FnTypeInfo type_args(cast(fn)); + FnTypeInfo type_args(fn); for (auto &a : type_args.Function->args()) { volatile_args[&a] = !(mode == DerivativeMode::ReverseModeCombined); TypeTree dt; @@ -904,7 +930,7 @@ class Enzyme : public ModulePass { switch (mode) { case DerivativeMode::ForwardMode: newFunc = Logic.CreateForwardDiff( - cast(fn), retType, constants, TA, + fn, retType, constants, TA, /*should return*/ false, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, /*augmented*/ nullptr); @@ -912,10 +938,10 @@ class Enzyme : public ModulePass { case DerivativeMode::ForwardModeSplit: { bool forceAnonymousTape = !sizeOnly && allocatedTapeSize == -1; aug = &Logic.CreateAugmentedPrimal( - cast(fn), retType, constants, TA, + fn, retType, constants, TA, /*returnUsed*/ false, /*shadowReturnUsed*/ false, type_args, volatile_args, forceAnonymousTape, width, /*atomicAdd*/ AtomicAdd); - auto &DL = cast(fn)->getParent()->getDataLayout(); + auto &DL = fn->getParent()->getDataLayout(); if (!forceAnonymousTape) { assert(!aug->tapeType); if (aug->returns.find(AugmentedStruct::Tape) != aug->returns.end()) { @@ -948,7 +974,7 @@ class Enzyme : public ModulePass { tapeType = PointerType::getInt8PtrTy(fn->getContext()); } newFunc = Logic.CreateForwardDiff( - cast(fn), retType, constants, TA, + fn, retType, constants, TA, /*should return*/ false, mode, freeMemory, width, /*addedType*/ tapeType, type_args, volatile_args, aug); break; @@ -956,7 +982,7 @@ class Enzyme : public ModulePass { case DerivativeMode::ReverseModeCombined: assert(freeMemory); newFunc = Logic.CreatePrimalAndGradient( - (ReverseCacheKey){.todiff = cast(fn), + (ReverseCacheKey){.todiff = fn, .retType = retType, .constant_args = constants, .uncacheable_args = volatile_args, @@ -976,10 +1002,10 @@ class Enzyme : public ModulePass { bool shadowReturnUsed = returnUsed && (retType == DIFFE_TYPE::DUP_ARG || retType == DIFFE_TYPE::DUP_NONEED); aug = &Logic.CreateAugmentedPrimal( - cast(fn), retType, constants, TA, returnUsed, - shadowReturnUsed, type_args, volatile_args, forceAnonymousTape, width, + fn, retType, constants, TA, returnUsed, shadowReturnUsed, type_args, + volatile_args, forceAnonymousTape, width, /*atomicAdd*/ AtomicAdd); - auto &DL = cast(fn)->getParent()->getDataLayout(); + auto &DL = fn->getParent()->getDataLayout(); if (!forceAnonymousTape) { assert(!aug->tapeType); if (aug->returns.find(AugmentedStruct::Tape) != aug->returns.end()) { @@ -1015,7 +1041,7 @@ class Enzyme : public ModulePass { newFunc = aug->fn; else newFunc = Logic.CreatePrimalAndGradient( - (ReverseCacheKey){.todiff = cast(fn), + (ReverseCacheKey){.todiff = fn, .retType = retType, .constant_args = constants, .uncacheable_args = volatile_args, @@ -1037,19 +1063,16 @@ class Enzyme : public ModulePass { if (differentialReturn) { if (differet) args.push_back(differet); - else if (cast(fn)->getReturnType()->isFPOrFPVectorTy()) { - Constant *seed = - ConstantFP::get(cast(fn)->getReturnType(), 1.0); + else if (fn->getReturnType()->isFPOrFPVectorTy()) { + Constant *seed = ConstantFP::get(fn->getReturnType(), 1.0); if (width == 1) { args.push_back(seed); } else { - ArrayType *arrayType = - ArrayType::get(cast(fn)->getReturnType(), width); + ArrayType *arrayType = ArrayType::get(fn->getReturnType(), width); args.push_back(ConstantArray::get( arrayType, SmallVector(width, seed))); } - } else if (auto ST = dyn_cast( - cast(fn)->getReturnType())) { + } else if (auto ST = dyn_cast(fn->getReturnType())) { SmallVector csts; for (auto e : ST->elements()) { csts.push_back(ConstantFP::get(e, 1.0)); @@ -1061,7 +1084,7 @@ class Enzyme : public ModulePass { if ((mode == DerivativeMode::ReverseModeGradient || mode == DerivativeMode::ForwardModeSplit) && tape && tapeType) { - auto &DL = cast(fn)->getParent()->getDataLayout(); + auto &DL = fn->getParent()->getDataLayout(); if (tapeIsPointer) { tape = Builder.CreateBitCast( tape, PointerType::get( @@ -1212,7 +1235,7 @@ class Enzyme : public ModulePass { } CI->replaceAllUsesWith(newStruct); } else if (mode == DerivativeMode::ReverseModePrimal) { - auto &DL = cast(fn)->getParent()->getDataLayout(); + auto &DL = fn->getParent()->getDataLayout(); if (DL.getTypeSizeInBits(CI->getType()) >= DL.getTypeSizeInBits(diffret->getType())) { IRBuilder<> EB(