Skip to content

Commit

Permalink
Fix memset for floats (rust-lang#765)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Aug 2, 2022
1 parent c832e60 commit 0d1c595
Show file tree
Hide file tree
Showing 12 changed files with 1,002 additions and 171 deletions.
392 changes: 359 additions & 33 deletions enzyme/Enzyme/ActivityAnalysis.cpp

Large diffs are not rendered by default.

337 changes: 300 additions & 37 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Constants.h"
Expand Down Expand Up @@ -2669,19 +2670,19 @@ class AdjointGenerator
/// Forwards the given MemSetInst to visitMemSetCommon, which does the
/// actual work; this method adds no logic of its own.
void visitMemSetInst(llvm::MemSetInst &MS) {
  visitMemSetCommon(MS);
}

void visitMemSetCommon(llvm::CallInst &MS) {
IRBuilder<> BuilderZ(&MS);
getForwardBuilder(BuilderZ);

IRBuilder<> Builder2(&MS);
if (Mode == DerivativeMode::ReverseModeGradient ||
Mode == DerivativeMode::ReverseModeCombined)
getReverseBuilder(Builder2);

eraseIfUnused(MS);

Value *orig_op0 = MS.getArgOperand(0);
Value *orig_op1 = MS.getArgOperand(1);

// TODO this should 1) assert that the value being memset is constant
// 2) duplicate the memset for the inverted pointer

if (gutils->isConstantInstruction(&MS) &&
Mode != DerivativeMode::ForwardMode) {
return;
}

// If constant destination then no operation needs doing
if (gutils->isConstantValue(orig_op0)) {
return;
Expand All @@ -2701,36 +2702,10 @@ class AdjointGenerator
report_fatal_error("non constant in memset");
}

bool backwardsShadow = false;
bool forwardsShadow = true;
for (auto pair : gutils->backwardsOnlyShadows) {
if (pair.second.stores.count(&MS)) {
backwardsShadow = true;
forwardsShadow = pair.second.primalInitialize;
if (auto inst = dyn_cast<Instruction>(pair.first))
if (!forwardsShadow && pair.second.LI &&
pair.second.LI->contains(inst->getParent()))
backwardsShadow = false;
}
}

if ((Mode == DerivativeMode::ReverseModePrimal && forwardsShadow) ||
(Mode == DerivativeMode::ReverseModeGradient && backwardsShadow) ||
(Mode == DerivativeMode::ReverseModeCombined &&
(forwardsShadow && backwardsShadow)) ||
Mode == DerivativeMode::ForwardMode) {
IRBuilder<> BuilderZ(&MS);
getForwardBuilder(BuilderZ);

bool forwardMode = Mode == DerivativeMode::ForwardMode;

if (Mode == DerivativeMode::ForwardMode) {
Value *op0 = gutils->invertPointerM(orig_op0, BuilderZ);
Value *op1 = gutils->getNewFromOriginal(MS.getArgOperand(1));
if (!forwardMode)
op1 = gutils->lookupM(op1, BuilderZ);
Value *op2 = gutils->getNewFromOriginal(MS.getArgOperand(2));
if (!forwardMode)
op2 = gutils->lookupM(op2, BuilderZ);
Value *op3 = nullptr;
#if LLVM_VERSION_MAJOR >= 14
if (3 < MS.arg_size())
Expand All @@ -2739,8 +2714,6 @@ class AdjointGenerator
#endif
{
op3 = gutils->getNewFromOriginal(MS.getOperand(3));
if (!forwardMode)
op3 = gutils->lookupM(op3, BuilderZ);
}

auto Defs =
Expand All @@ -2763,6 +2736,296 @@ class AdjointGenerator
cal->setDebugLoc(gutils->getNewFromOriginal(MS.getDebugLoc()));
},
op0);
return;
}

bool backwardsShadow = false;
bool forwardsShadow = true;
for (auto pair : gutils->backwardsOnlyShadows) {
if (pair.second.stores.count(&MS)) {
backwardsShadow = true;
forwardsShadow = pair.second.primalInitialize;
if (auto inst = dyn_cast<Instruction>(pair.first))
if (!forwardsShadow && pair.second.LI &&
pair.second.LI->contains(inst->getParent()))
backwardsShadow = false;
}
}

size_t size = 1;
if (auto ci = dyn_cast<ConstantInt>(MS.getOperand(2))) {
size = ci->getLimitedValue();
}

// TODO note that we only handle memset of ONE type (aka memset of {int,
// double} not allowed)

if (size == 0) {
llvm::errs() << MS << "\n";
}
assert(size != 0);

auto &DL = gutils->newFunc->getParent()->getDataLayout();
auto vd = TR.query(MS.getOperand(0)).Data0().ShiftIndices(DL, 0, size, 0);

if (!vd.isKnownPastPointer()) {
// If unknown type results, consider the intersection of all incoming.
if (isa<PHINode>(MS.getOperand(0)) || isa<SelectInst>(MS.getOperand(0))) {
SmallVector<Value *, 2> todo = {MS.getOperand(0)};
bool set = false;
SmallSet<Value *, 2> seen;
TypeTree vd2;
while (todo.size()) {
Value *cur = todo.back();
todo.pop_back();
if (seen.count(cur))
continue;
seen.insert(cur);
if (auto PN = dyn_cast<PHINode>(cur)) {
for (size_t i = 0, end = PN->getNumIncomingValues(); i < end; i++) {
todo.push_back(PN->getIncomingValue(i));
}
continue;
}
if (auto S = dyn_cast<SelectInst>(cur)) {
todo.push_back(S->getTrueValue());
todo.push_back(S->getFalseValue());
continue;
}
if (auto CE = dyn_cast<ConstantExpr>(cur)) {
if (CE->isCast()) {
todo.push_back(CE->getOperand(0));
continue;
}
}
if (auto CI = dyn_cast<CastInst>(cur)) {
todo.push_back(CI->getOperand(0));
continue;
}
if (isa<ConstantPointerNull>(cur))
continue;
if (auto CI = dyn_cast<ConstantInt>(cur))
if (CI->isZero())
continue;
auto curTT = TR.query(cur).Data0().ShiftIndices(DL, 0, size, 0);
if (!set)
vd2 = curTT;
else
vd2 &= curTT;
set = true;
}
vd = vd2;
}
}
if (!vd.isKnownPastPointer()) {
if (looseTypeAnalysis) {
if (auto CI = dyn_cast<CastInst>(MS.getOperand(0))) {
if (auto PT = dyn_cast<PointerType>(CI->getSrcTy())) {
auto ET = PT->getPointerElementType();
while (1) {
if (auto ST = dyn_cast<StructType>(ET)) {
if (ST->getNumElements()) {
ET = ST->getElementType(0);
continue;
}
}
if (auto AT = dyn_cast<ArrayType>(ET)) {
ET = AT->getElementType();
continue;
}
break;
}
if (ET->isFPOrFPVectorTy()) {
vd = TypeTree(ConcreteType(ET->getScalarType())).Only(0);
goto known;
}
if (ET->isPointerTy()) {
vd = TypeTree(BaseType::Pointer).Only(0);
goto known;
}
if (ET->isIntOrIntVectorTy()) {
vd = TypeTree(BaseType::Integer).Only(0);
goto known;
}
}
}
if (auto gep = dyn_cast<GetElementPtrInst>(MS.getOperand(0))) {
if (auto AT = dyn_cast<ArrayType>(gep->getSourceElementType())) {
if (AT->getElementType()->isIntegerTy()) {
vd = TypeTree(BaseType::Integer).Only(0);
goto known;
}
}
}
EmitWarning("CannotDeduceType", MS.getDebugLoc(), gutils->oldFunc,
MS.getParent(), &MS, "failed to deduce type of memset ",
MS);
vd = TypeTree(BaseType::Pointer).Only(0);
goto known;
}
if (CustomErrorHandler) {
std::string str;
raw_string_ostream ss(str);
ss << "Cannot deduce type of memset " << MS;
CustomErrorHandler(str.c_str(), wrap(&MS), ErrorType::NoType,
&TR.analyzer);
}
EmitFailure("CannotDeduceType", MS.getDebugLoc(), &MS,
"failed to deduce type of memset ", MS);

TR.firstPointer(size, MS.getOperand(0), /*errifnotfound*/ true,
/*pointerIntSame*/ true);
llvm_unreachable("bad msi");
}
known:;

#if 0
#if LLVM_VERSION_MAJOR >= 10
unsigned dstalign = dstAlign.valueOrOne().value();
unsigned srcalign = srcAlign.valueOrOne().value();
#else
unsigned dstalign = dstAlign;
unsigned srcalign = srcAlign;
#endif
#endif

unsigned start = 0;

Value *op1 = gutils->getNewFromOriginal(MS.getArgOperand(1));
Value *new_size = gutils->getNewFromOriginal(MS.getArgOperand(2));
Value *op3 = nullptr;
#if LLVM_VERSION_MAJOR >= 14
if (3 < MS.arg_size())
#else
if (3 < MS.getNumArgOperands())
#endif
{
op3 = gutils->getNewFromOriginal(MS.getOperand(3));
}

while (1) {
unsigned nextStart = size;

auto dt = vd[{-1}];
for (size_t i = start; i < size; ++i) {
bool Legal = true;
dt.checkedOrIn(vd[{(int)i}], /*PointerIntSame*/ true, Legal);
if (!Legal) {
nextStart = i;
break;
}
}
if (!dt.isKnown()) {
TR.dump();
llvm::errs() << " vd:" << vd.str() << " start:" << start
<< " size: " << size << " dt:" << dt.str() << "\n";
}
assert(dt.isKnown());

Value *length = new_size;
if (nextStart != size) {
length = ConstantInt::get(new_size->getType(), nextStart);
}
if (start != 0)
length = BuilderZ.CreateSub(
length, ConstantInt::get(new_size->getType(), start));

#if 0
unsigned subdstalign = dstalign;
// todo make better alignment calculation
if (dstalign != 0) {
if (start % dstalign != 0) {
dstalign = 1;
}
}
unsigned subsrcalign = srcalign;
// todo make better alignment calculation
if (srcalign != 0) {
if (start % srcalign != 0) {
srcalign = 1;
}
}
#endif

Value *shadow_dst = gutils->invertPointerM(MS.getOperand(0), BuilderZ);

// TODO ponder forward split mode
Type *secretty = dt.isFloat();
if (!secretty &&
((Mode == DerivativeMode::ReverseModePrimal && forwardsShadow) ||
(Mode == DerivativeMode::ReverseModeCombined && forwardsShadow) ||
(Mode == DerivativeMode::ReverseModeGradient && backwardsShadow) ||
(Mode == DerivativeMode::ForwardModeSplit && backwardsShadow))) {
auto Defs =
gutils->getInvertedBundles(&MS,
{ValueType::Shadow, ValueType::Primal,
ValueType::Primal, ValueType::Primal},
BuilderZ, /*lookup*/ false);
auto rule = [&](Value *op0) {
if (start != 0) {
Value *idxs[] = {
ConstantInt::get(Type::getInt32Ty(op0->getContext()), start)};
#if LLVM_VERSION_MAJOR > 7
op0 = BuilderZ.CreateInBoundsGEP(
op0->getType()->getPointerElementType(), op0, idxs);
#else
op0 = BuilderZ.CreateInBoundsGEP(op0, idxs);
#endif
}
SmallVector<Value *, 4> args = {op0, op1, length};
if (op3)
args.push_back(op3);
auto cal = BuilderZ.CreateCall(MS.getCalledFunction(), args, Defs);
cal->copyMetadata(MS, MD_ToCopy);
cal->setAttributes(MS.getAttributes());
cal->setCallingConv(MS.getCallingConv());
cal->setTailCallKind(MS.getTailCallKind());
cal->setDebugLoc(gutils->getNewFromOriginal(MS.getDebugLoc()));
};

applyChainRule(BuilderZ, rule, shadow_dst);
}
if (secretty && (Mode == DerivativeMode::ReverseModeGradient ||
Mode == DerivativeMode::ReverseModeCombined)) {

auto Defs =
gutils->getInvertedBundles(&MS,
{ValueType::Shadow, ValueType::Primal,
ValueType::Primal, ValueType::Primal},
BuilderZ, /*lookup*/ true);
Value *op1l = gutils->lookupM(op1, Builder2);
Value *op3l = op3;
if (op3l)
op3l = gutils->lookupM(op3l, BuilderZ);
length = gutils->lookupM(length, Builder2);
auto rule = [&](Value *op0) {
if (start != 0) {
Value *idxs[] = {
ConstantInt::get(Type::getInt32Ty(op0->getContext()), start)};
#if LLVM_VERSION_MAJOR > 7
op0 = Builder2.CreateInBoundsGEP(
op0->getType()->getPointerElementType(), op0, idxs);
#else
op0 = Builder2.CreateInBoundsGEP(op0, idxs);
#endif
}
SmallVector<Value *, 4> args = {op0, op1l, length};
if (op3l)
args.push_back(op3l);
auto cal = Builder2.CreateCall(MS.getCalledFunction(), args, Defs);
cal->copyMetadata(MS, MD_ToCopy);
cal->setAttributes(MS.getAttributes());
cal->setCallingConv(MS.getCallingConv());
cal->setTailCallKind(MS.getTailCallKind());
cal->setDebugLoc(gutils->getNewFromOriginal(MS.getDebugLoc()));
};

applyChainRule(Builder2, rule, gutils->lookupM(shadow_dst, Builder2));
}

if (nextStart == size)
break;
start = nextStart;
}
}

Expand Down
Loading

0 comments on commit 0d1c595

Please sign in to comment.