Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tablegen: Add StaticSelect to select based on static condition #2206

Merged
merged 7 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions enzyme/Enzyme/InstructionDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ class ConstantFP<string val> : Operation</*primal*/0, /*shadow*/0> {
string value = val;
}

class StaticSelect<string condition_> : Operation</*primal*/0, /*shadow*/0, /*custom*/0> {
string condition = condition_;
}

def SelectIfActive : StaticSelect<"!gutils->isConstantValue(imVal)">;

class Attribute<string name_> {
string name = name_;
Expand Down Expand Up @@ -62,9 +67,6 @@ class Inst<string mnemonic> : Operation</*primal*/1, /*shadow*/0> {
def TypeOf : Operation</*primal*/0, /*shadow*/0> {
}
def VectorSize : Operation</*primal*/0, /*shadow*/0> {
}
def SelectIfActive : Operation</*primal*/0, /*shadow*/0, /*custom*/1> {

}

// Define ops to rewrite.
Expand Down
25 changes: 18 additions & 7 deletions enzyme/Enzyme/MLIR/Implementations/Common.td
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,20 @@ def DiffeRet : DiffeRetIndex<[-1]>;
def Shadow : Operation</*primal*/0, /*shadow*/1> {
}

class GlobalExpr<bit uses_primal, bit uses_shadow, string val> : Operation<uses_primal, uses_shadow>{
class GlobalExpr<bit uses_primal, bit uses_shadow, string val> : Operation<uses_primal, uses_shadow> {
string value = val;
}

// Class for a dag operator that generates either a or b
// It can then be used with a two or three arguments.
// The two arguments version is (StaticSelect a, b)
// The three arguments version accepts a name as a first argument
// which is then available in the condition as a `Value` under the
// variable `imVal`.
class StaticSelect<string condition_> : Operation</*usesPrimal*/0, /*usesShadow*/0, /*usesCustom*/0> {
string condition = condition_;
}

class Inst<string mnemonic, string dialect_, string postop_=""> : Operation</*primal*/1, /*shadow*/0> {
string name = mnemonic;
string dialect = dialect_;
Expand All @@ -99,13 +109,14 @@ class Inst<string mnemonic, string dialect_, string postop_=""> : Operation</*p
def Op {
}

def SelectIfActive : Operation</*primal*/0, /*shadow*/0, /*custom*/1> {

}

def SelectIfComplex : Operation</*primal*/1, /*shadow*/0, /*custom*/0> {
def SelectIfActive : StaticSelect<"!gutils->isConstantValue(imVal)">;

}
def SelectIfComplex : StaticSelect<[{
auto ty = imVal.getType();
ty.isa<ComplexType>() ||
ty.isa<TensorType>() &&
ty.cast<TensorType>().getElementType().isa<ComplexType>();
}]>;

class ConstantFP<string val, string dialect_, string op_, string type_=""> : Operation</*primal*/0, /*shadow*/0> {
string value = val;
Expand Down
Loading
Loading