From 7b97a9b6b6f92516a733d12c78a33f118e63091d Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Fri, 29 Nov 2024 20:00:10 +0000 Subject: [PATCH] tblgen: Implement SelectIfComplex (#2183) * selectifcomplex * remove ConjIfComplex this is unused and can now be implemented using SelectIfComplex * used primal --- enzyme/Enzyme/MLIR/Implementations/Common.td | 9 ++- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 65 ++++++++++++++------ 2 files changed, 49 insertions(+), 25 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index c40be825a827..9846f9ae6183 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -103,6 +103,10 @@ def SelectIfActive : Operation { } +def SelectIfComplex : Operation { + +} + class ConstantFP : Operation { string value = val; string dialect = dialect_; @@ -110,11 +114,6 @@ class ConstantFP : Ope string type = type_; } -class ConjIfComplex : Operation { - string dialect = dialect_; - string opName = op_; -} - def ResultTypes : GlobalExprgetResultTypes()">; def TypeOf : Operation { diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index b6848cedcebe..f789a125b222 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -488,29 +488,24 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, os << curIndent << INDENT << "imVal;\n"; os << curIndent << "})"; return true; - } else if (opName == "ConjIfComplex" || - Def->isSubClassOf("ConjIfComplex")) { - if (resultRoot->getNumArgs() != 1) + } else if (opName == "SelectIfComplex" || + Def->isSubClassOf("SelectIfComplex")) { + if (resultRoot->getNumArgs() != 3) PrintFatalError(pattern->getLoc(), - "only three op ConjIfComplex supported"); + "only three op SelectIfComplex supported"); os << "({\n"; - os << curIndent << INDENT << "// Computing ConjIfComplex\n"; + os << curIndent << INDENT << "// Computing SelectIfComplex\n"; if (intrinsic == MLIRDerivatives) - os << curIndent << INDENT << "mlir::Value imVal"; + os << curIndent << INDENT << "mlir::Value imVal = "; else - os << curIndent << INDENT << "llvm::Value *imVal"; - - os << curIndent << INDENT << "if (!gutils->isConstantValue("; + os << curIndent << INDENT << "llvm::Value *imVal = "; if (isa(resultRoot->getArg(0)) && resultRoot->getArgName(0)) { auto name = resultRoot->getArgName(0)->getAsUnquotedString(); auto [ord, isVec, ext] = nameToOrdinal.lookup(name, pattern, resultRoot); - os << ord; - assert(!ext.size()); - os << ord; - os << ";\n"; + os << ord << ";\n"; } else { handle(curIndent + INDENT + INDENT, argPattern + "_cic", os, pattern, resultRoot->getArg(0), builder, nameToOrdinal, lookup, retidx, @@ -518,15 +513,45 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, os << ";\n"; } - os << " (isa(imVal.getType()) || " + os << curIndent << INDENT + << "if (isa(imVal.getType()) || " "(isa(imVal.getType()) && " "isa(cast(imVal.getType()).getElementType(" - ")))) ? "; - os << builder << ".create<" - << cast(Def->getValueInit("dialect"))->getValue() - << "::" << cast(Def->getValueInit("opName"))->getValue() - << ">(op.getLoc(), imVal.getType(), imVal) : imVal;\n"; - os << curIndent << "})"; + ")))) {\n"; + + os << curIndent << INDENT << INDENT << "imVal = "; + if (isa(resultRoot->getArg(1)) && resultRoot->getArgName(1)) { + auto name = resultRoot->getArgName(1)->getAsUnquotedString(); + auto [ord, isVec, ext] = + nameToOrdinal.lookup(name, pattern, resultRoot); + assert(!ext.size()); + os << ord << ";\n"; + } else { + handle(curIndent + INDENT + INDENT, argPattern + "_cic", os, pattern, + resultRoot->getArg(1), builder, nameToOrdinal, lookup, retidx, + origName, newFromOriginal, intrinsic); + os << ";\n"; + } + + os << curIndent << INDENT << "} else {\n"; + + os << curIndent << INDENT << INDENT << "imVal = "; + if (isa(resultRoot->getArg(2)) && resultRoot->getArgName(2)) { + auto name = resultRoot->getArgName(2)->getAsUnquotedString(); + auto [ord, isVec, ext] = + nameToOrdinal.lookup(name, pattern, resultRoot); + assert(!ext.size()); + os << ord << ";\n"; + } else { + handle(curIndent + INDENT + INDENT, argPattern + "_cic", os, pattern, + resultRoot->getArg(2), builder, nameToOrdinal, lookup, retidx, + origName, newFromOriginal, intrinsic); + os << ";\n"; + } + + os << curIndent << INDENT << "}\n"; + os << curIndent << INDENT << "imVal;"; + os << curIndent << INDENT << "})\n"; return true; } else if (opName == "ConstantFP" || Def->isSubClassOf("ConstantFP")) { auto value = dyn_cast(Def->getValueInit("value"));