Skip to content

Commit

Permalink
WIP variadic
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 18, 2025
1 parent 5c632cc commit 3cab0c2
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 9 deletions.
3 changes: 3 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/Common.td
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def AssertingInactiveArg : InactiveArgSpec {
bit asserting = 1;
}

class Variadic<string getter_> {
string getter = getter_;
}

def Unimplemented {

Expand Down
17 changes: 15 additions & 2 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class MEnzymeLogic {
unsigned width;
mlir::Type additionalType;
const MFnTypeInfo typeInfo;
bool omp;

inline bool operator<(const MForwardCacheKey &rhs) const {
if (todiff < rhs.todiff)
Expand Down Expand Up @@ -100,6 +101,12 @@ class MEnzymeLogic {
return true;
if (rhs.typeInfo < typeInfo)
return false;

if (omp < rhs.omp)
return true;
if (rhs.omp < omp)
return false;

// equal
return false;
}
Expand All @@ -117,6 +124,7 @@ class MEnzymeLogic {
mlir::Type additionalType;
const MFnTypeInfo typeInfo;
const std::vector<bool> volatileArgs;
bool omp;

inline bool operator<(const MReverseCacheKey &rhs) const {
if (todiff < rhs.todiff)
Expand Down Expand Up @@ -182,6 +190,11 @@ class MEnzymeLogic {
if (rhs.volatileArgs < volatileArgs)
return false;

if (omp < rhs.omp)
return true;
if (rhs.omp < omp)
return false;

// equal
return false;
}
Expand All @@ -196,7 +209,7 @@ class MEnzymeLogic {
std::vector<bool> returnPrimals, DerivativeMode mode,
bool freeMemory, size_t width, mlir::Type addedType,
MFnTypeInfo type_args, std::vector<bool> volatile_args,
void *augmented, llvm::StringRef postpasses);
void *augmented, bool omp, llvm::StringRef postpasses);

FunctionOpInterface
CreateReverseDiff(FunctionOpInterface fn, std::vector<DIFFE_TYPE> retType,
Expand All @@ -205,7 +218,7 @@ class MEnzymeLogic {
std::vector<bool> returnShadows, DerivativeMode mode,
bool freeMemory, size_t width, mlir::Type addedType,
MFnTypeInfo type_args, std::vector<bool> volatile_args,
void *augmented, llvm::StringRef postpasses);
void *augmented, bool omp, llvm::StringRef postpasses);

void
initializeShadowValues(SmallVector<mlir::Block *> &dominatorToposortBlocks,
Expand Down
3 changes: 2 additions & 1 deletion enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(
std::vector<bool> returnPrimals, std::vector<bool> returnShadows,
DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType,
MFnTypeInfo type_args, std::vector<bool> volatile_args, void *augmented,
bool omp,
llvm::StringRef postpasses) {

if (fn.getFunctionBody().empty()) {
Expand Down Expand Up @@ -217,7 +218,7 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(

MGradientUtilsReverse *gutils = MGradientUtilsReverse::CreateFromClone(
*this, mode, width, fn, TA, type_args, returnPrimalsP, returnShadowsP,
retType, constants, addedType, postpasses);
retType, constants, addedType, omp, postpasses);

ReverseCachedFunctions[tup] = gutils->newFunc;

Expand Down
8 changes: 4 additions & 4 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ mlir::enzyme::MGradientUtilsReverse::MGradientUtilsReverse(
ArrayRef<DIFFE_TYPE> ReturnActivity, ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
IRMapping &originalToNewFn_,
std::map<Operation *, Operation *> &originalToNewFnOps_,
DerivativeMode mode_, unsigned width, StringRef postpasses)
DerivativeMode mode_, unsigned width, bool omp, StringRef postpasses)
: MDiffeGradientUtils(Logic, newFunc_, oldFunc_, TA_, /*MTypeResults*/ {},
invertedPointers_, returnPrimals, returnShadows,
constantvalues_, activevals_, ReturnActivity,
ArgDiffeTypes_, originalToNewFn_, originalToNewFnOps_,
mode_, width, /*omp*/ false, postpasses) {}
mode_, width, omp, postpasses) {}

Type mlir::enzyme::MGradientUtilsReverse::getIndexCacheType() {
Type indexType = getIndexType();
Expand Down Expand Up @@ -138,7 +138,7 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone(
FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo,
const ArrayRef<bool> returnPrimals, const ArrayRef<bool> returnShadows,
ArrayRef<DIFFE_TYPE> retType, ArrayRef<DIFFE_TYPE> constant_args,
mlir::Type additionalArg, llvm::StringRef postpasses) {
mlir::Type additionalArg, bool omp, llvm::StringRef postpasses) {
std::string prefix;

switch (mode_) {
Expand Down Expand Up @@ -174,5 +174,5 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone(
return new MGradientUtilsReverse(
Logic, newFunc, todiff, TA, invertedPointers, returnPrimals,
returnShadows, constant_values, nonconstant_values, retType,
constant_args, originalToNew, originalToNewOps, mode_, width, postpasses);
constant_args, originalToNew, originalToNewOps, mode_, width, omp, postpasses);
}
37 changes: 35 additions & 2 deletions enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1394,8 +1394,30 @@ static VariableSetting parseVariables(const DagInit *tree, ActionType intrinsic,
for (auto tree : ptree->getArgs()) {
SmallVector<unsigned, 2> next(prev.begin(), prev.end());
next.push_back(i);
if (auto dg = dyn_cast<DagInit>(tree))
if (auto dg = dyn_cast<DagInit>(tree)) {
if (ptree->getArgNameStr(i).size()) {
auto opName = dg->getOperator()->getAsString();
auto Def = cast<DefInit>(dg->getOperator())->getDef();
if (opName == "Variadic" || Def->isSubClassOf("Variadic")) {
auto expr = Def->getValueAsString("getter");
std::string op;
if (intrinsic != MLIRDerivatives)
op = (origName + "." + expr + "()").str();
else
op = (origName + "->" + expr + "()").str();
std::vector<int> extractions;
if (prev.size() > 0) {
for (unsigned i = 1; i < next.size(); i++) {
extractions.push_back(next[i]);
}
}
nameToOrdinal.insert(ptree->getArgNameStr(i), op, false,
extractions);
continue;
}
}
insert(dg, next);
}

if (ptree->getArgNameStr(i).size()) {
std::string op;
Expand Down Expand Up @@ -1580,8 +1602,19 @@ static void emitMLIRReverse(raw_ostream &os, const Record *pattern,
auto name = ptree->getArgNameStr(treeEn.index());
SmallVector<unsigned, 2> next(prev.begin(), prev.end());
next.push_back(treeEn.index());
if (auto dg = dyn_cast<DagInit>(tree))
if (auto dg = dyn_cast<DagInit>(tree)) {
if (name.size()) {
auto opName = dg->getOperator()->getAsString();
auto Def = cast<DefInit>(dg->getOperator())->getDef();
if (opName == "Variadic" || Def->isSubClassOf("Variadic")) {
auto expr = Def->getValueAsString("getter");
varNameToCondition[name] = std::make_tuple(
("llvm::is_contained(op->getOperand(idx), op." + expr + "())").str(), "", false);
continue;
}
}
insert(dg, next);
}

if (name.size()) {
varNameToCondition[name] = std::make_tuple(
Expand Down

0 comments on commit 3cab0c2

Please sign in to comment.