feat: add aggregate op and extend functions accordingly
This PR adds support for the `AggregateRel` from the Substrait spec in
the form of the `aggregate` op. This is arguably the most complex op
implemented so far. It has an optional enum argument that requires
custom parsing, several optional regions that require custom parsing, an
attribute that depends on the presence and contents of the regions and
requires custom parsing to omit it in the common case, and return types
that depend on the two regions and the attribute. What's more, the
current version of the spec is such that it is almost impossible to
interpret "grouping sets" because it relies on protobuf message
equality, which is something that protobuf does not offer. The current
implementation thus makes a best effort by using op equality
instead (but needs to run CSE during export to ensure op uniqueness).
Finally, the PR extends the `call` op to also represent
`AggregateFunction` messages (in addition to `ScalarFunction` messages),
which are used by the new `aggregate` op.
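
To illustrate the grouping-set representation described above, here is a minimal, self-contained C++ sketch. It is not the code from this PR: `Expr`, `Groupings`, and `dedupGroupings` are hypothetical names, and plain strings stand in for grouping expressions, so string equality plays the role that op equality (after CSE) plays in the actual implementation. The sketch collects the unique expressions in first-seen order and rewrites each grouping set as a list of indices into that list.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for a grouping expression: two expressions count as
// equal iff their strings are equal (the PR compares ops instead).
using Expr = std::string;

// Unique expressions plus grouping sets given as indices into `unique`.
struct Groupings {
  std::vector<Expr> unique;
  std::vector<std::vector<int64_t>> groupingSets;
};

// Deduplicates the expressions of all grouping sets (first-seen order) and
// rewrites each set as a list of indices into the deduplicated list.
Groupings dedupGroupings(const std::vector<std::vector<Expr>> &sets) {
  Groupings result;
  std::unordered_map<Expr, int64_t> indexOf;
  for (const std::vector<Expr> &set : sets) {
    std::vector<int64_t> indices;
    for (const Expr &expr : set) {
      // `emplace` only inserts if `expr` has not been seen before.
      auto inserted = indexOf.emplace(
          expr, static_cast<int64_t>(result.unique.size()));
      if (inserted.second)
        result.unique.push_back(expr);
      // Invariant: every recorded index refers to an existing unique entry.
      assert(inserted.first->second <
             static_cast<int64_t>(result.unique.size()));
      indices.push_back(inserted.first->second);
    }
    result.groupingSets.push_back(std::move(indices));
  }
  return result;
}
```

For the grouping sets `{a, b}`, `{b}`, `{a, c}`, this sketch yields the unique expressions `a, b, c` and the index sets `[[0, 1], [1], [0, 2]]`. An empty input yields an empty array of grouping sets, whereas an input consisting of one empty set yields one empty index set, which mirrors the distinction the op description draws between *no* `groupings` messages and an *empty* one.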
ingomueller-net committed Jan 24, 2025
1 parent 737d758 commit 1245335
Showing 13 changed files with 2,012 additions and 41 deletions.
18 changes: 18 additions & 0 deletions include/substrait-mlir/Dialect/Substrait/IR/SubstraitAttrs.td
@@ -19,6 +19,10 @@ class Substrait_Attr<string name, string typeMnemonic, list<Trait> traits = []>
 let mnemonic = typeMnemonic;
 }

+//===----------------------------------------------------------------------===//
+// Substrait attributes
+//===----------------------------------------------------------------------===//
+
 def Substrait_AdvancedExtensionAttr
 : Substrait_Attr<"AdvancedExtension", "advanced_extension"> {
 let summary = "Represents the `AdvancedExtension` message of Substrait";
@@ -93,6 +97,10 @@ def Substrait_TimestampTzAttr : Substrait_Attr<"TimestampTz", "timestamp_tz",
 }];
 }

+//===----------------------------------------------------------------------===//
+// Helpers and constraints
+//===----------------------------------------------------------------------===//
+
 /// Attributes of currently supported atomic types, listed in order of substrait
 /// specification.
 def Substrait_AtomicAttributes {
@@ -116,4 +124,14 @@ def Substrait_AtomicAttributes {
 /// Attribute of one of the currently supported atomic types.
 def Substrait_AtomicAttribute : AnyAttrOf<Substrait_AtomicAttributes.attrs>;

+/// `ArrayAttr` of `ArrayAttr`s of `i64`s.
+def I64ArrayArrayAttr : TypedArrayAttrBase<
+I64ArrayAttr, "64-bit integer array array attribute"
+>;
+
+/// `ArrayAttr` of `ArrayAttr`s of `i64`s with at least one element.
+def NonEmptyI64ArrayArrayAttr :
+ConfinedAttr<I64ArrayArrayAttr, [ArrayMinCount<1>]>;
+
+
 #endif // SUBSTRAIT_DIALECT_SUBSTRAIT_IR_SUBSTRAITATTRS
10 changes: 10 additions & 0 deletions include/substrait-mlir/Dialect/Substrait/IR/SubstraitEnums.td
@@ -15,6 +15,16 @@ def AggregationInvocationUnspecified: I32EnumAttrCase<"unspecified", 0>;
 def AggregationInvocationAll: I32EnumAttrCase<"all", 1>;
 def AggregationInvocationDistinct: I32EnumAttrCase<"distinct", 2>;

+/// Represents the `AggregationInvocation` protobuf enum.
+def AggregationInvocation : I32EnumAttr<
+"AggregationInvocation", "aggregate invocation type", [
+AggregationInvocationUnspecified,
+AggregationInvocationAll,
+AggregationInvocationDistinct
+]> {
+let cppNamespace = "::mlir::substrait";
+}
+
 /// Represents the `JoinType` protobuf enum.
 def JoinTypeKind : I32EnumAttr<"JoinTypeKind",
 "The enum values correspond to those in the JoinRel.JoinType message.", [
106 changes: 101 additions & 5 deletions include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td
@@ -231,6 +231,7 @@ def Substrait_PlanRelOp : Substrait_Op<"relation", [
 def Substrait_YieldOp : Substrait_Op<"yield", [
 Terminator,
 ParentOneOf<[
+"::mlir::substrait::AggregateOp",
 "::mlir::substrait::FilterOp",
 "::mlir::substrait::PlanRelOp",
 "::mlir::substrait::ProjectOp"
@@ -308,9 +309,11 @@ def Substrait_CallOp : Substrait_ExpressionOp<"call", [
 ]> {
 let summary = "Function call expression";
 let description = [{
-Represents a `ScalarFunction` message (or, in the future, other `*Function`
-messages) together with all messages it contains and the `Expression`
-message it is contained in.
+Represents a `ScalarFunction` or `AggregateFunction` message (or, in the
+future, a `WindowFunction` message) together with all messages it contains
+and, where applicable, the `Expression` message it is contained in. Which of
+the message types this op corresponds to depends on the presence of the
+(otherwise optional) aggregate or window-related attributes.

 Currently, the specification of the function, which is in an external YAML
 file, is not taken into account, for example, to verify whether a matching
@@ -332,11 +335,33 @@ def Substrait_CallOp : Substrait_ExpressionOp<"call", [
 // TODO(ingomueller): Add support for `enum` and `type` argument types.
 let arguments = (ins
 FlatSymbolRefAttr:$callee,
-Variadic<Substrait_FieldType>:$args
+Variadic<Substrait_FieldType>:$args,
+OptionalAttr<AggregationInvocation>:$aggregation_invocation
 );
 let results = (outs Substrait_FieldType:$result);
 let assemblyFormat = [{
-$callee `(` $args `)` attr-dict `:` `(` type($args) `)` `->` type($result)
+$callee `(` $args `)`
+(`aggregate` `` custom<AggregationInvocation>($aggregation_invocation)^)?
+attr-dict `:` `(` type($args) `)` `->` type($result)
 }];
+let builders = [
+OpBuilder<(ins "::mlir::Type":$result,
+"::mlir::FlatSymbolRefAttr":$callee,
+"::mlir::ValueRange":$args), [{
+build($_builder, $_state, result, callee, args,
+AggregationInvocationAttr());
+}]>,
+OpBuilder<(ins "::mlir::Type":$result, "::llvm::StringRef":$callee,
+"::mlir::ValueRange":$args), [{
+build($_builder, $_state, result, callee, args,
+AggregationInvocationAttr());
+}]>
+];
+let extraClassDeclaration = [{
+// Helpers to distinguish function types.
+bool isAggregate() { return getAggregationInvocation().has_value(); }
+bool isScalar() { return !isAggregate() && !isWindow(); }
+bool isWindow() { return false; } // TODO: change once supported.
+}];
 }

@@ -360,6 +385,77 @@ class Substrait_RelOp<string mnemonic, list<Trait> traits = []> :
 ]>>
 ]>;

+def Substrait_AggregateOp : Substrait_RelOp<"aggregate", [
+DeclareOpInterfaceMethods<OpAsmOpInterface, ["getDefaultDialect"]>,
+SingleBlockImplicitTerminator<"::mlir::substrait::YieldOp">,
+DeclareOpInterfaceMethods<InferTypeOpInterface>,
+]> {
+let summary = "Aggregate operation";
+let description = [{
+Represents an `AggregateRel` message together with the `RelCommon` and the
+messages it contains. The `measures` field is represented as a region where
+the yielded values correspond to the `AggregateFunction`s (and thus have
+to be produced by a `CallOp` representing an aggregate function). Filters
+are currently not supported. The `groupings` field is represented as a
+region yielding the unique (deduplicated) grouping expressions and an array
+of arrays of references to these expressions representing the grouping sets.
+An empty array of grouping sets corresponds to *no* `groupings` messages;
+an array with an empty grouping set corresponds to an *empty* `groupings`
+message. These two protobuf representations are different even though their
+semantics are equivalent. The op can only be exported to the protobuf format
+if the expressions yielded by the `groupings` region are all distinct after
+CSE. The assembly format omits an empty region of groupings, an empty region
+of measures, and the grouping sets attribute with one grouping set that
+consists of all values yielded from `groupings` (or the empty grouping set
+if that region is empty).
+
+Example:
+
+```mlir
+%0 = ...
+%1 = aggregate %0 : tuple<si32> -> tuple<si32, si32>
+groupings {
+^bb0(%arg : tuple<si32>):
+%2 = field_reference %arg[0] : tuple<si32>
+yield %2 : si32
+}
+grouping_sets [[0]]
+measures {
+^bb0(%arg : tuple<si32>):
+%2 = field_reference %arg[0] : tuple<si32>
+%3 = call @function(%2) aggregate : (si32) -> si32
+yield %3 : si32
+}
+```
+}];
+let arguments = (ins
+Substrait_Relation:$input,
+I64ArrayArrayAttr:$grouping_sets
+);
+let results = (outs Substrait_Relation:$result);
+let regions = (region
+AnyRegion:$groupings,
+AnyRegion:$measures
+);
+let assemblyFormat = [{
+$input attr-dict `:` type($input) `->` type($result)
+custom<AggregateRegions>($groupings, $measures, $grouping_sets)
+}];
+let hasRegionVerifier = 1;
+let builders = [
+OpBuilder<(ins
+"::mlir::Value":$input, "::mlir::ArrayAttr":$grouping_sets,
+"::mlir::Region *":$groupings, "::mlir::Region *":$measures
+)>,
+];
+let extraClassDefinition = [{
+/// Implement OpAsmOpInterface.
+::llvm::StringRef $cppClass::getDefaultDialect() {
+return SubstraitDialect::getDialectNamespace();
+}
+}];
+}
+
 def Substrait_CrossOp : Substrait_RelOp<"cross", [
 DeclareOpInterfaceMethods<InferTypeOpInterface>
 ]> {