Skip to content

Commit

Permalink
Merge
Browse files Browse the repository at this point in the history
  • Loading branch information
rmoyard committed Jan 26, 2024
2 parents a02ce63 + 098ea88 commit 1bb9b90
Show file tree
Hide file tree
Showing 9 changed files with 210 additions and 58 deletions.
64 changes: 53 additions & 11 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1314,14 +1314,14 @@ class EnzymeBase {
return type_args;
}

bool HandleTruncate(CallInst *CI) {
bool HandleTruncateFunc(CallInst *CI) {
IRBuilder<> Builder(CI);
Function *F = parseFunctionParameter(CI);
if (!F)
return false;
if (CI->arg_size() != 3) {
EmitFailure("TooManyArgs", CI->getDebugLoc(), CI,
"Had incorrect number of args to __enzyme_truncate", *CI,
"Had incorrect number of args to __enzyme_truncate_func", *CI,
" - expected 3");
return false;
}
Expand All @@ -1330,7 +1330,7 @@ class EnzymeBase {
auto Cto = cast<ConstantInt>(CI->getArgOperand(2));
assert(Cto);
RequestContext context(CI, &Builder);
llvm::Value *res = Logic.CreateTruncate(
llvm::Value *res = Logic.CreateTruncateFunc(
context, F, (unsigned)Cfrom->getValue().getZExtValue(),
(unsigned)Cto->getValue().getZExtValue());
if (!res)
Expand All @@ -1341,6 +1341,28 @@ class EnzymeBase {
return true;
}

bool HandleTruncateValue(CallInst *CI, bool isTruncate) {
IRBuilder<> Builder(CI);
if (CI->arg_size() != 3) {
EmitFailure("TooManyArgs", CI->getDebugLoc(), CI,
"Had incorrect number of args to __enzyme_truncate_value",
*CI, " - expected 3");
return false;
}
auto Cfrom = cast<ConstantInt>(CI->getArgOperand(1));
assert(Cfrom);
auto Cto = cast<ConstantInt>(CI->getArgOperand(2));
assert(Cto);
auto Addr = CI->getArgOperand(0);
RequestContext context(CI, &Builder);
bool res = Logic.CreateTruncateValue(
context, Addr, (unsigned)Cfrom->getValue().getZExtValue(),
(unsigned)Cto->getValue().getZExtValue(), isTruncate);
if (!res)
return false;
return true;
}

bool HandleBatch(CallInst *CI) {
unsigned width = 1;
unsigned truei = 0;
Expand Down Expand Up @@ -2088,7 +2110,9 @@ class EnzymeBase {
MapVector<CallInst *, DerivativeMode> toVirtual;
MapVector<CallInst *, DerivativeMode> toSize;
SmallVector<CallInst *, 4> toBatch;
SmallVector<CallInst *, 4> toTruncate;
SmallVector<CallInst *, 4> toTruncateFunc;
SmallVector<CallInst *, 4> toTruncateValue;
SmallVector<CallInst *, 4> toExpandValue;
MapVector<CallInst *, ProbProgMode> toProbProg;
SetVector<CallInst *> InactiveCalls;
SetVector<CallInst *> IterCalls;
Expand Down Expand Up @@ -2398,7 +2422,9 @@ class EnzymeBase {
bool virtualCall = false;
bool sizeOnly = false;
bool batch = false;
bool truncate = false;
bool truncateFunc = false;
bool truncateValue = false;
bool expandValue = false;
bool probProg = false;
DerivativeMode derivativeMode;
ProbProgMode probProgMode;
Expand Down Expand Up @@ -2428,9 +2454,15 @@ class EnzymeBase {
} else if (Fn->getName().contains("__enzyme_batch")) {
enableEnzyme = true;
batch = true;
} else if (Fn->getName().contains("__enzyme_truncate")) {
} else if (Fn->getName().contains("__enzyme_truncate_func")) {
enableEnzyme = true;
truncate = true;
truncateFunc = true;
} else if (Fn->getName().contains("__enzyme_truncate_value")) {
enableEnzyme = true;
truncateValue = true;
} else if (Fn->getName().contains("__enzyme_expand_value")) {
enableEnzyme = true;
expandValue = true;
} else if (Fn->getName().contains("__enzyme_likelihood")) {
enableEnzyme = true;
probProgMode = ProbProgMode::Likelihood;
Expand Down Expand Up @@ -2488,8 +2520,12 @@ class EnzymeBase {
toSize[CI] = derivativeMode;
else if (batch)
toBatch.push_back(CI);
else if (truncate)
toTruncate.push_back(CI);
else if (truncateFunc)
toTruncateFunc.push_back(CI);
else if (truncateValue)
toTruncateValue.push_back(CI);
else if (expandValue)
toExpandValue.push_back(CI);
else if (probProg) {
toProbProg[CI] = probProgMode;
} else
Expand Down Expand Up @@ -2583,8 +2619,14 @@ class EnzymeBase {
for (auto call : toBatch) {
HandleBatch(call);
}
for (auto call : toTruncate) {
HandleTruncate(call);
for (auto call : toTruncateFunc) {
HandleTruncateFunc(call);
}
for (auto call : toTruncateValue) {
HandleTruncateValue(call, true);
}
for (auto call : toExpandValue) {
HandleTruncateValue(call, false);
}

for (auto &&[call, mode] : toProbProg) {
Expand Down
134 changes: 98 additions & 36 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4813,11 +4813,53 @@ Function *EnzymeLogic::CreateForwardDiff(
return nf;
}

static Type *getTypeForWidth(LLVMContext &ctx, unsigned width) {
switch (width) {
default:
return llvm::Type::getIntNTy(ctx, width);
case 64:
return llvm::Type::getDoubleTy(ctx);
case 32:
return llvm::Type::getFloatTy(ctx);
case 16:
return llvm::Type::getHalfTy(ctx);
}
}

static Value *floatTruncate(IRBuilderBase &B, Value *v, Value *tmpBlock,
unsigned fromwidth, unsigned towidth) {
Type *fromTy = getTypeForWidth(B.getContext(), fromwidth);
Type *toTy = getTypeForWidth(B.getContext(), towidth);
if (!tmpBlock)
tmpBlock = B.CreateAlloca(fromTy);
B.CreateStore(
v, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(v->getType())));
return B.CreateLoad(
toTy, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(toTy)));
}

static Value *floatExpand(IRBuilderBase &B, Value *v, Value *tmpBlock,
unsigned fromwidth, unsigned towidth) {
Type *fromTy = getTypeForWidth(B.getContext(), fromwidth);
if (!tmpBlock)
tmpBlock = B.CreateAlloca(fromTy);
auto c0 =
Constant::getNullValue(llvm::Type::getIntNTy(B.getContext(), fromwidth));
B.CreateStore(
c0, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(c0->getType())));
B.CreateStore(
v, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(v->getType())));
return B.CreateLoad(
fromTy, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(fromTy)));
}

class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator> {
private:
ValueToValueMapTy &originalToNewFn;
unsigned fromwidth;
unsigned towidth;
Type *fromType;
Type *toType;
Function *oldFunc;
Function *newFunc;
AllocaInst *tmpBlock;
Expand All @@ -4830,7 +4872,11 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator> {
: originalToNewFn(originalToNewFn), fromwidth(fromwidth),
towidth(towidth), oldFunc(oldFunc), newFunc(newFunc), Logic(Logic) {
IRBuilder<> B(&newFunc->getEntryBlock().front());
tmpBlock = B.CreateAlloca(getTypeForWidth(fromwidth));

fromType = getTypeForWidth(B.getContext(), fromwidth);
toType = getTypeForWidth(B.getContext(), towidth);

tmpBlock = B.CreateAlloca(fromType);
}

void visitInstruction(llvm::Instruction &inst) {
Expand All @@ -4848,42 +4894,16 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator> {
todo(inst);
}

Type *getTypeForWidth(unsigned width) {
switch (width) {
default:
return llvm::Type::getIntNTy(oldFunc->getContext(), width);
case 64:
return llvm::Type::getDoubleTy(oldFunc->getContext());
case 32:
return llvm::Type::getFloatTy(oldFunc->getContext());
case 16:
return llvm::Type::getHalfTy(oldFunc->getContext());
}
}
Type *getFromType() { return fromType; }

Type *getFromType() { return getTypeForWidth(fromwidth); }

Type *getToType() { return getTypeForWidth(towidth); }
Type *getToType() { return toType; }

Value *truncate(IRBuilder<> &B, Value *v) {
Type *nextType = getTypeForWidth(towidth);
B.CreateStore(
v, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(v->getType())));
return B.CreateLoad(
nextType,
B.CreatePointerCast(tmpBlock, PointerType::getUnqual(nextType)));
return floatTruncate(B, v, tmpBlock, fromwidth, towidth);
}

Value *expand(IRBuilder<> &B, Value *v) {
Type *origT = getFromType();
auto c0 = Constant::getNullValue(
llvm::Type::getIntNTy(oldFunc->getContext(), fromwidth));
B.CreateStore(c0, B.CreatePointerCast(
tmpBlock, PointerType::getUnqual(c0->getType())));
B.CreateStore(
v, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(v->getType())));
return B.CreateLoad(
origT, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(origT)));
return floatExpand(B, v, tmpBlock, fromwidth, towidth);
}

void todo(llvm::Instruction &I) {
Expand Down Expand Up @@ -5180,7 +5200,7 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator> {

Value *GetShadow(RequestContext &ctx, Value *v) {
if (auto F = dyn_cast<Function>(v))
return Logic.CreateTruncate(ctx, F, fromwidth, towidth);
return Logic.CreateTruncateFunc(ctx, F, fromwidth, towidth);
llvm::errs() << " unknown get truncated func: " << *v << "\n";
llvm_unreachable("unknown get truncated func");
return v;
Expand All @@ -5203,10 +5223,52 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator> {
}
};

llvm::Function *EnzymeLogic::CreateTruncate(RequestContext context,
llvm::Function *totrunc,
unsigned fromwidth,
unsigned towidth) {
bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v,
unsigned fromwidth, unsigned towidth,
bool isTruncate) {
assert(context.req && context.ip);

if (fromwidth == towidth) {
context.req->eraseFromParent();
return true;
}

if (fromwidth < towidth) {
std::string s;
llvm::raw_string_ostream ss(s);
ss << "Cannot truncate into a large width\n";
if (context.req) {
ss << " at context: " << *context.req;
EmitFailure("NoTruncate", context.req->getDebugLoc(), context.req,
ss.str());
return false;
}
llvm_unreachable("failed to truncate value");
}

IRBuilderBase &B = *context.ip;
Type *fromTy = getTypeForWidth(B.getContext(), fromwidth);
Type *toTy = getTypeForWidth(B.getContext(), towidth);

Value *converted = nullptr;
if (isTruncate)
converted =
floatExpand(B, B.CreateFPTrunc(v, toTy), nullptr, fromwidth, towidth);
else
converted =
B.CreateFPExt(floatTruncate(B, v, nullptr, fromwidth, towidth), fromTy);
assert(converted);

context.req->replaceAllUsesWith(converted);
context.req->eraseFromParent();

return true;
}

llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context,
llvm::Function *totrunc,
unsigned fromwidth,
unsigned towidth) {
if (fromwidth == towidth)
return totrunc;

Expand Down
10 changes: 7 additions & 3 deletions enzyme/Enzyme/EnzymeLogic.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@

extern "C" {
extern llvm::cl::opt<bool> EnzymePrint;
extern llvm::cl::opt<bool> EnzymeJuliaAddrLoad;
}

enum class AugmentedStruct { Tape, Return, DifferentialReturn };
Expand Down Expand Up @@ -512,9 +513,12 @@ class EnzymeLogic {

using TruncateCacheKey = std::tuple<llvm::Function *, unsigned, unsigned>;
std::map<TruncateCacheKey, llvm::Function *> TruncateCachedFunctions;
llvm::Function *CreateTruncate(RequestContext context,
llvm::Function *tobatch, unsigned fromwidth,
unsigned towidth);
llvm::Function *CreateTruncateFunc(RequestContext context,
llvm::Function *tobatch,
unsigned fromwidth, unsigned towidth);
bool CreateTruncateValue(RequestContext context, llvm::Value *addr,
unsigned fromwidth, unsigned towidth,
bool isTruncate);

/// Create a traced version of a function
/// \p context the instruction which requested this trace (or null).
Expand Down
7 changes: 7 additions & 0 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9325,6 +9325,13 @@ bool GradientUtils::needsCacheWholeAllocation(
if (found == knownRecomputeHeuristic.end())
continue;

// If caching a julia base object, this is fine as
// GC will deal with any issues with.
if (auto PT = dyn_cast<PointerType>(cur->getType()))
if (PT->getAddressSpace() == 10)
if (EnzymeJuliaAddrLoad)
continue;

// If caching this user, it cannot be a gep/cast of original
if (!found->second) {
llvm::errs() << " mod: " << *oldFunc->getParent() << "\n";
Expand Down
4 changes: 2 additions & 2 deletions enzyme/test/Enzyme/Truncate/cmp.ll
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ define i1 @f(double %x, double %y) {
ret i1 %res
}

declare i1 (double, double)* @__enzyme_truncate(...)
declare i1 (double, double)* @__enzyme_truncate_func(...)

define i1 @tester(double %x, double %y) {
entry:
%ptr = call i1 (double, double)* (...) @__enzyme_truncate(i1 (double, double)* @f, i64 64, i64 32)
%ptr = call i1 (double, double)* (...) @__enzyme_truncate_func(i1 (double, double)* @f, i64 64, i64 32)
%res = call i1 %ptr(double %x, double %y)
ret i1 %res
}
Expand Down
4 changes: 2 additions & 2 deletions enzyme/test/Enzyme/Truncate/intrinsic.ll
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ define double @f(double %x, double %y) {
ret double %res
}

declare double (double, double)* @__enzyme_truncate(...)
declare double (double, double)* @__enzyme_truncate_func(...)

define double @tester(double %x, double %y) {
entry:
%ptr = call double (double, double)* (...) @__enzyme_truncate(double (double, double)* @f, i64 64, i64 32)
%ptr = call double (double, double)* (...) @__enzyme_truncate_func(double (double, double)* @f, i64 64, i64 32)
%res = call double %ptr(double %x, double %y)
ret double %res
}
Expand Down
4 changes: 2 additions & 2 deletions enzyme/test/Enzyme/Truncate/select.ll
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ define double @f(double %x, double %y, i1 %cond) {
ret double %res
}

declare double (double, double, i1)* @__enzyme_truncate(...)
declare double (double, double, i1)* @__enzyme_truncate_func(...)

define double @tester(double %x, double %y, i1 %cond) {
entry:
%ptr = call double (double, double, i1)* (...) @__enzyme_truncate(double (double, double, i1)* @f, i64 64, i64 32)
%ptr = call double (double, double, i1)* (...) @__enzyme_truncate_func(double (double, double, i1)* @f, i64 64, i64 32)
%res = call double %ptr(double %x, double %y, i1 %cond)
ret double %res
}
Expand Down
Loading

0 comments on commit 1bb9b90

Please sign in to comment.