diff --git a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitAttrs.td b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitAttrs.td index 2098ca7c..467682bf 100644 --- a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitAttrs.td +++ b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitAttrs.td @@ -19,6 +19,10 @@ class Substrait_Attr traits = []> let mnemonic = typeMnemonic; } +//===----------------------------------------------------------------------===// +// Substrait attributes +//===----------------------------------------------------------------------===// + def Substrait_AdvancedExtensionAttr : Substrait_Attr<"AdvancedExtension", "advanced_extension"> { let summary = "Represents the `AdvancedExtenssion` message of Substrait"; @@ -93,6 +97,10 @@ def Substrait_TimestampTzAttr : Substrait_Attr<"TimestampTz", "timestamp_tz", }]; } +//===----------------------------------------------------------------------===// +// Helpers and constraints +//===----------------------------------------------------------------------===// + /// Attributes of currently supported atomic types, listed in order of substrait /// specification. def Substrait_AtomicAttributes { @@ -116,4 +124,14 @@ def Substrait_AtomicAttributes { /// Attribute of one of the currently supported atomic types. def Substrait_AtomicAttribute : AnyAttrOf; +/// `ArrayAttr` of `ArrayAttr`s of `i64`s. +def I64ArrayArrayAttr : TypedArrayAttrBase< + I64ArrayAttr, "64-bit integer array array attribute" + >; + +/// `ArrayAttr` of `ArrayAttr`s of `i64`s with at least one element. 
+def NonEmptyI64ArrayArrayAttr : + ConfinedAttr]>; + + #endif // SUBSTRAIT_DIALECT_SUBSTRAIT_IR_SUBSTRAITATTRS diff --git a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitEnums.td b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitEnums.td index c3721b56..7dcb1b80 100644 --- a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitEnums.td +++ b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitEnums.td @@ -15,6 +15,16 @@ def AggregationInvocationUnspecified: I32EnumAttrCase<"unspecified", 0>; def AggregationInvocationAll: I32EnumAttrCase<"all", 1>; def AggregationInvocationDistinct: I32EnumAttrCase<"distinct", 2>; +/// Represents the `AggregationInvocation` protobuf enum. +def AggregationInvocation : I32EnumAttr< + "AggregationInvocation", "aggregate invocation type", [ + AggregationInvocationUnspecified, + AggregationInvocationAll, + AggregationInvocationDistinct + ]> { + let cppNamespace = "::mlir::substrait"; +} + /// Represents the `JoinType` protobuf enum. def JoinTypeKind : I32EnumAttr<"JoinTypeKind", "The enum values correspond to those in the JoinRel.JoinType message.", [ diff --git a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td index 0a5651d8..ffe2e776 100644 --- a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td +++ b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td @@ -231,6 +231,7 @@ def Substrait_PlanRelOp : Substrait_Op<"relation", [ def Substrait_YieldOp : Substrait_Op<"yield", [ Terminator, ParentOneOf<[ + "::mlir::substrait::AggregateOp", "::mlir::substrait::FilterOp", "::mlir::substrait::PlanRelOp", "::mlir::substrait::ProjectOp" @@ -308,9 +309,11 @@ def Substrait_CallOp : Substrait_ExpressionOp<"call", [ ]> { let summary = "Function call expression"; let description = [{ - Represents a `ScalarFunction` message (or, in the future, other `*Function` - messages) together with all messages it contains and the `Expression` - message it is 
contained in. + Represents a `ScalarFunction` or `AggregateFunction` message (or, in the + future, a `WindowFunction` message) together with all messages it contains + and, where applicable, the `Expression` message it is contained in. Which of + the message types this op corresponds to depends on the presence of the + (otherwise optional) aggregate or window-related attributes. Currently, the specification of the function, which is in an external YAML file, is not taken into account, for example, to verify whether a matching @@ -332,11 +335,33 @@ def Substrait_CallOp : Substrait_ExpressionOp<"call", [ // TODO(ingomueller): Add support for `enum` and `type` argument types. let arguments = (ins FlatSymbolRefAttr:$callee, - Variadic:$args + Variadic:$args, + OptionalAttr:$aggregation_invocation ); let results = (outs Substrait_FieldType:$result); let assemblyFormat = [{ - $callee `(` $args `)` attr-dict `:` `(` type($args) `)` `->` type($result) + $callee `(` $args `)` + (`aggregate` `` custom($aggregation_invocation)^)? + attr-dict `:` `(` type($args) `)` `->` type($result) + }]; + let builders = [ + OpBuilder<(ins "::mlir::Type":$result, + "::mlir::FlatSymbolRefAttr":$callee, + "::mlir::ValueRange":$args), [{ + build($_builder, $_state, result, callee, args, + AggregationInvocationAttr()); + }]>, + OpBuilder<(ins "::mlir::Type":$result, "::llvm::StringRef":$callee, + "::mlir::ValueRange":$args), [{ + build($_builder, $_state, result, callee, args, + AggregationInvocationAttr()); + }]> + ]; + let extraClassDeclaration = [{ + // Helpers to distinguish function types. + bool isAggregate() { return getAggregationInvocation().has_value(); } + bool isScalar() { return !isAggregate() && !isWindow(); } + bool isWindow() { return false; } // TODO: change once supported. 
}]; } @@ -360,6 +385,77 @@ class Substrait_RelOp traits = []> : ]>> ]>; +def Substrait_AggregateOp : Substrait_RelOp<"aggregate", [ + DeclareOpInterfaceMethods, + SingleBlockImplicitTerminator<"::mlir::substrait::YieldOp">, + DeclareOpInterfaceMethods, + ]> { + let summary = "Aggregate operation"; + let description = [{ + Represents an `AggregateRel ` message together with the `RelCommon` and the + messages it contains. The `measures` field is represented as a region where + the yielded values correspond to the `AggregateFunction`s (and thus have + to be produced by a `CallOp` representing an aggregate function). Filters + are currently not supported. The `groupings` field is represented as a + region yielding the unique (deduplicated) grouping expressions and an array + of array of references to these expressions representing the grouping sets. + An empty array of grouping sets corresponds to *no* `groupings` messages; + an array with an empty grouping set corresponds to an *empty* `groupings` + messages. These two protobuf representations are different even though their + semantic is equivalent. The op can only be exported to the protobuf format + if the expressions yielded by the `groupings` region are all distinct after + CSE. The assembly format omits an empty region of groupings, an empty region + of measures, and the grouping sets attribute with one grouping set that + consists of all values yielded from `groupings` (or the empty grouping set + if that region is empty). + + Example: + + ```mlir + %0 = ... 
+ %1 = aggregate %0 : tuple -> tuple + groupings { + ^bb0(%arg : tuple): + %2 = field_reference %arg[0] : tuple + yield %2 : si32 + } + grouping_sets [[0]] + measures { + ^bb0(%arg : tuple): + %2 = field_reference %arg[0] : tuple + %3 = call @function(%2) aggregate : (si32) -> si32 + yield %3 : si32 + } + ``` + }]; + let arguments = (ins + Substrait_Relation:$input, + I64ArrayArrayAttr:$grouping_sets + ); + let results = (outs Substrait_Relation:$result); + let regions = (region + AnyRegion:$groupings, + AnyRegion:$measures + ); + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($result) + custom($groupings, $measures, $grouping_sets) + }]; + let hasRegionVerifier = 1; + let builders = [ + OpBuilder<(ins + "::mlir::Value":$input, "::mlir::ArrayAttr":$grouping_sets, + "::mlir::Region *":$groupings, "::mlir::Region *":$measures + )>, + ]; + let extraClassDefinition = [{ + /// Implement OpAsmOpInterface. + ::llvm::StringRef $cppClass::getDefaultDialect() { + return SubstraitDialect::getDialectNamespace(); + } + }]; +} + def Substrait_CrossOp : Substrait_RelOp<"cross", [ DeclareOpInterfaceMethods ]> { diff --git a/lib/Dialect/Substrait/IR/Substrait.cpp b/lib/Dialect/Substrait/IR/Substrait.cpp index 7a306bc6..c6dbfa0a 100644 --- a/lib/Dialect/Substrait/IR/Substrait.cpp +++ b/lib/Dialect/Substrait/IR/Substrait.cpp @@ -100,12 +100,311 @@ void printCountAsAll(OpAsmPrinter &printer, Operation *op, IntegerAttr count) { // Substrait operations //===----------------------------------------------------------------------===// +namespace mlir { +namespace substrait { + +static ParseResult +parseAggregationInvocation(OpAsmParser &parser, + AggregationInvocationAttr &aggregationInvocation); +static void +printAggregationInvocation(OpAsmPrinter &printer, CallOp op, + AggregationInvocationAttr aggregationInvocation); +static ParseResult parseAggregateRegions(OpAsmParser &parser, + Region &groupingsRegion, + Region &measuresRegion, + ArrayAttr &groupingSetsAttr); 
+static void printAggregateRegions(OpAsmPrinter &printer, AggregateOp op, + Region &groupingsRegion, + Region &measuresRegion, + ArrayAttr groupingSetsAttr); + +} // namespace substrait +} // namespace mlir + #define GET_OP_CLASSES #include "substrait-mlir/Dialect/Substrait/IR/SubstraitOps.cpp.inc" namespace mlir { namespace substrait { +ParseResult +parseAggregationInvocation(OpAsmParser &parser, + AggregationInvocationAttr &aggregationInvocation) { + // This is essentially copied from `FieldParser` but + // sets the default `unspecified` case if no invocation type is present. + + MLIRContext *context = parser.getContext(); + std::string keyword; + if (failed(parser.parseOptionalKeywordOrString(&keyword))) { + // No keyword parse --> use default value. + aggregationInvocation = AggregationInvocationAttr::get( + context, AggregationInvocation::unspecified); + return success(); + } + + // Symbolize the keyword. + if (std::optional attr = + symbolizeAggregationInvocation(keyword)) { + aggregationInvocation = + AggregationInvocationAttr::get(parser.getContext(), attr.value()); + return success(); + } + + // Symbolization failed. + auto loc = parser.getCurrentLocation(); + return parser.emitError(loc) + << "has invalid aggregate invocation type specification: " << keyword; +} + +void printAggregationInvocation( + OpAsmPrinter &printer, CallOp op, + AggregationInvocationAttr aggregationInvocation) { + if (aggregationInvocation && + aggregationInvocation.getValue() != AggregationInvocation::unspecified) { + // The whitespace printed here compensates the trimming of whitespace in + // the declarative assembly format. + printer << " " << aggregationInvocation.getValue(); + } +} + +ParseResult parseAggregateRegions(OpAsmParser &parser, Region &groupingsRegion, + Region &measuresRegion, + ArrayAttr &groupingSetsAttr) { + MLIRContext *context = parser.getContext(); + + // Parse `measures` and `groupings` regions as well as `grouping_sets` attr. 
+ bool hasMeasures = false; + bool hasGroupings = false; + bool hasGroupingSets = false; + { + auto ensureOneOccurrance = [&](bool &hasParsed, + StringRef name) -> LogicalResult { + if (hasParsed) { + SMLoc loc = parser.getCurrentLocation(); + return parser.emitError(loc, llvm::Twine("can only have one ") + name); + } + hasParsed = true; + return success(); + }; + + StringRef keyword; + while (succeeded(parser.parseOptionalKeyword( + &keyword, {"measures", "groupings", "grouping_sets"}))) { + if (keyword == "measures") { + if (failed(ensureOneOccurrance(hasMeasures, "'measures' region")) || + failed(parser.parseRegion(measuresRegion))) + return failure(); + } else if (keyword == "groupings") { + if (failed(ensureOneOccurrance(hasGroupings, "'groupings' region")) || + failed(parser.parseRegion(groupingsRegion))) + return failure(); + } else if (keyword == "grouping_sets") { + if (failed(ensureOneOccurrance(hasGroupingSets, + "'grouping_sets' attribute")) || + failed(parser.parseAttribute(groupingSetsAttr))) + return failure(); + } + } + } + + // Create default value of `grouping_sets` attr if not provided. + if (!hasGroupingSets) { + // If there is no `groupings` region, create only the empty grouping set. + if (!hasGroupings) + groupingSetsAttr = ArrayAttr::get(context, ArrayAttr::get(context, {})); + // Otherwise, create the grouping set with all grouping columns. 
+ else if (!groupingsRegion.empty()) { + auto yieldOp = + llvm::dyn_cast(groupingsRegion.front().getTerminator()); + if (yieldOp) { + unsigned numColumns = yieldOp->getNumOperands(); + SmallVector allColumns; + llvm::append_range(allColumns, llvm::seq(0u, numColumns)); + IRRewriter rewriter(context); + ArrayAttr allColumnsAttr = rewriter.getI64ArrayAttr(allColumns); + groupingSetsAttr = rewriter.getArrayAttr({allColumnsAttr}); + } + } + } + + return success(); +} + +void printAggregateRegions(OpAsmPrinter &printer, AggregateOp op, + Region &groupingsRegion, Region &measuresRegion, + ArrayAttr groupingSetsAttr) { + printer.increaseIndent(); + + // `groupings` region. + if (!groupingsRegion.empty()) { + printer.printNewline(); + printer.printKeywordOrString("groupings"); + printer << " "; + printer.printRegion(groupingsRegion); + } + + // `grouping_sets` attribute. + if (groupingSetsAttr.size() != 1) { + // Note: A single grouping set is always of the form `seq(0, size)`. + printer.printNewline(); + printer.printKeywordOrString("grouping_sets"); + printer << " "; + printer.printAttribute(groupingSetsAttr); + } + + // `measures` regions. + if (!measuresRegion.empty()) { + printer.printNewline(); + printer.printKeywordOrString("measures"); + printer << " "; + printer.printRegion(measuresRegion); + } + + printer.decreaseIndent(); +} + +void AggregateOp::build(OpBuilder &builder, OperationState &result, Value input, + ArrayAttr groupingSets, Region *groupings, + Region *measures) { + + MLIRContext *context = builder.getContext(); + auto loc = UnknownLoc::get(context); + AggregateOp::Properties properties; + properties.grouping_sets = groupingSets; + SmallVector regions = {groupings, measures}; + + // Infer `returnTypes` from provided arguments. If that fails, then + // `returnType` will be empty. The rest of this function will continue to + // work, but the op that is built in the end will not verify and the + // diagnostics of `inferReturnType` will have been emitted. 
+ SmallVector returnTypes; + (void)AggregateOp::inferReturnTypes(context, loc, input, {}, + OpaqueProperties(&properties), regions, + returnTypes); + + // Call existing `build` function and move bodies into the new regions. + AggregateOp::build(builder, result, returnTypes, input, groupingSets); + result.regions[0]->takeBody(*groupings); + result.regions[1]->takeBody(*measures); +} + +LogicalResult AggregateOp::inferReturnTypes( + MLIRContext *context, std::optional loc, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + llvm::SmallVectorImpl &inferredReturnTypes) { + auto *typedProperties = properties.as(); + assert(typedProperties && "could not get typed properties"); + Region *groupings = regions[0]; + Region *measures = regions[1]; + SmallVector fieldTypes; + if (!loc) + loc = UnknownLoc::get(context); + + // The left-most output columns are the `groupings` columns, then the + // `measures` columns. + for (Region *region : {groupings, measures}) { + if (region->empty()) + continue; + auto yieldOp = llvm::cast(region->front().getTerminator()); + llvm::append_range(fieldTypes, yieldOp.getOperandTypes()); + } + + // If there is more than one `grouping_set`, then we also have an additional + // `si32` column for the grouping set ID. + if (typedProperties->grouping_sets.size() > 1) { + auto si32 = IntegerType::get(context, /*width=*/32, IntegerType::Signed); + fieldTypes.push_back(si32); + } + + // Build tuple type from field types. + auto resultType = TupleType::get(context, fieldTypes); + inferredReturnTypes.push_back(resultType); + + return success(); +} + +LogicalResult AggregateOp::verifyRegions() { + // Verify properties that need to hold for both regions. + auto inputTupleType = getInput().getType(); + for (auto [idx, region] : llvm::enumerate(getRegions())) { + if (region->empty()) // Regions are allowed to be empty. + continue; + + // Verify that the regions have the input tuple as argument. 
+ if (region->getArgumentTypes() != TypeRange{inputTupleType}) + return emitOpError() << "has region #" << idx + << " with invalid argument types (expected: " + << inputTupleType + << ", got: " << region->getArgumentTypes() << ")"; + + // Verify that at least one value is yielded. + auto yieldOp = llvm::cast(region->front().getTerminator()); + if (yieldOp->getNumOperands() == 0) + return emitOpError() + << "has region #" << idx + << " that yields no values (use empty region instead)"; + } + + // Verify that the grouping sets refer to values yielded from `groupings`, + // that all yielded values are referred to, and that the references are in the + // correct order. + { + int64_t numGroupingColumns = 0; + if (!getGroupings().empty()) { + auto yieldOp = + llvm::cast(getGroupings().front().getTerminator()); + numGroupingColumns = yieldOp->getNumOperands(); + } + + // Check bounds, collect grouping columns. + llvm::SmallSet allGroupingRefs; + for (auto [groupingSetIdx, groupingSet] : + llvm::enumerate(getGroupingSets())) { + for (auto [refIdx, refAttr] : + llvm::enumerate(cast(groupingSet))) { + auto ref = cast(refAttr).getInt(); + if (ref < 0 || ref >= numGroupingColumns) + return emitOpError() << "has invalid grouping set #" << groupingSetIdx + << ": column reference " << ref << " (column #" + << refIdx << ") is out of bounds"; + auto [_, hasInserted] = allGroupingRefs.insert(ref); + if (hasInserted && + ref != static_cast(allGroupingRefs.size() - 1)) + return emitOpError() + << "has invalid grouping sets: the first occerrences of the " + "column references must be densely increasing"; + } + } + + // Check that all grouping columns are used. 
+ if (static_cast(allGroupingRefs.size()) != numGroupingColumns) { + for (int64_t i : llvm::seq(0, numGroupingColumns)) { + if (!allGroupingRefs.contains(i)) + return emitOpError() << "has 'groupings' region whose operand #" << i + << " is not contained in any 'grouping_set'"; + } + } + } + + // Verify that `measures` region yields only values produced by + // `AggregateFunction`s. + if (!getMeasures().empty()) { + for (Value value : getMeasures().front().getTerminator()->getOperands()) { + auto callOp = llvm::dyn_cast_or_null(value.getDefiningOp()); + if (!callOp || !callOp.isAggregate()) + return emitOpError() << "yields value from 'measures' region that was " + "not produced by an aggregate function: " + << value; + } + } + + if (getGroupings().empty() && getMeasures().empty()) + return emitOpError() + << "one of 'groupings' or 'measures' must be specified"; + + return success(); +} + /// Implement `SymbolOpInterface`. ::mlir::LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTables) { diff --git a/lib/Target/SubstraitPB/CMakeLists.txt b/lib/Target/SubstraitPB/CMakeLists.txt index 4e0301e6..bc957b31 100644 --- a/lib/Target/SubstraitPB/CMakeLists.txt +++ b/lib/Target/SubstraitPB/CMakeLists.txt @@ -7,6 +7,7 @@ add_mlir_translation_library(MLIRTargetSubstraitPB MLIRIR MLIRSubstraitDialect MLIRSupport + MLIRTransforms MLIRTranslateLib substrait_proto protobuf::libprotobuf diff --git a/lib/Target/SubstraitPB/Export.cpp b/lib/Target/SubstraitPB/Export.cpp index 4990c2de..ef46c62a 100644 --- a/lib/Target/SubstraitPB/Export.cpp +++ b/lib/Target/SubstraitPB/Export.cpp @@ -9,7 +9,10 @@ #include "substrait-mlir/Target/SubstraitPB/Export.h" #include "ProtobufUtils.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/CSE.h" #include "substrait-mlir/Dialect/Substrait/IR/Substrait.h" #include "substrait-mlir/Target/SubstraitPB/Options.h" 
#include "llvm/ADT/TypeSwitch.h" @@ -44,6 +47,7 @@ class SubstraitExporter { #define DECLARE_EXPORT_FUNC(OP_TYPE, MESSAGE_TYPE) \ FailureOr> exportOperation(OP_TYPE op); + DECLARE_EXPORT_FUNC(AggregateOp, Rel) DECLARE_EXPORT_FUNC(CallOp, Expression) DECLARE_EXPORT_FUNC(CrossOp, Rel) DECLARE_EXPORT_FUNC(EmitOp, Rel) @@ -60,6 +64,17 @@ class SubstraitExporter { DECLARE_EXPORT_FUNC(RelOpInterface, Rel) DECLARE_EXPORT_FUNC(SetOp, Rel) + // Common export logic for aggregate, scalar, and window functions. + template + FailureOr> exportCallOpCommon(CallOp op); + + // Special handling for aggregate, scalar, and window functions, which have + // the same argument types but different return types. + FailureOr> + exportCallOpAggregate(CallOp op); + FailureOr> exportCallOpScalar(CallOp op); + FailureOr> exportCallOpWindow(CallOp op); + std::unique_ptr exportAny(StringAttr attr); FailureOr> exportOperation(Operation *op); FailureOr> exportType(Location loc, @@ -312,51 +327,144 @@ SubstraitExporter::exportType(Location loc, mlir::Type mlirType) { return emitError(loc) << "could not export unsupported type " << mlirType; } -FailureOr> -SubstraitExporter::exportOperation(CallOp op) { - using ScalarFunction = Expression::ScalarFunction; +FailureOr> +SubstraitExporter::exportOperation(AggregateOp op) { + // Build `RelCommon` message. + auto relCommon = std::make_unique(); + auto direct = std::make_unique(); + relCommon->set_allocated_direct(direct.release()); - Location loc = op.getLoc(); + // Build `input` message. + auto inputOp = + llvm::dyn_cast_if_present(op.getInput().getDefiningOp()); + if (!inputOp) + return op->emitOpError("input was not produced by Substrait relation op"); - // Build `ScalarFunction` message. - // TODO(ingomueller): Support other `*Function` messages. 
- auto scalarFunction = std::make_unique(); - int32_t anchor = lookupAnchor(op, op.getCallee()); - scalarFunction->set_function_reference(anchor); + FailureOr> inputRel = exportOperation(inputOp); + if (failed(inputRel)) + return failure(); - // Build messages for arguments. - for (auto [i, operand] : llvm::enumerate(op->getOperands())) { - // Build `Expression` message for operand. - auto definingOp = llvm::dyn_cast_if_present( - operand.getDefiningOp()); - if (!definingOp) - return op->emitOpError() - << "with operand " << i - << " that was not produced by Substrait relation op"; + // Build `AggregateRel` message. + auto aggregateRel = std::make_unique(); + aggregateRel->set_allocated_common(relCommon.release()); + aggregateRel->set_allocated_input(inputRel->release()); - FailureOr> expression = - exportOperation(definingOp); - if (failed(expression)) - return failure(); + // Build `groupings` field. + { + // Make sure grouping expressions are distinct after CSE. + if (!op.getGroupings().empty()) { + // Set up rewriter to make temporary copy. + IRRewriter rewriter(op.getContext()); + IRRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(op); + + // Create a temporary copy that gets *erased* by the rewriter when it goes + // out of scope. + auto eraseOp = [&](Operation *op) { rewriter.eraseOp(op); }; + std::unique_ptr opCopy(rewriter.clone(*op), + eraseOp); + AggregateOp aggrOpCopy = mlir::cast(opCopy.get()); + + // Run CSE on the copy. + { + DominanceInfo domInfo; + mlir::eliminateCommonSubExpressions(rewriter, domInfo, opCopy.get()); + } - // Build `FunctionArgument` message and add to arguments. - FunctionArgument arg; - arg.set_allocated_value(expression->release()); - *scalarFunction->add_arguments() = arg; + // Make sure that all yielded values are different. If they are not, then + // some of them would result in equivalent grouping expressions in the + // protobuf format, which would change the semantics of the op. 
+ auto yieldOp = llvm::cast( + aggrOpCopy.getGroupings().front().getTerminator()); + ValueRange yieldedValues = yieldOp->getOperands(); + DenseSet distinctYieldedValues; + distinctYieldedValues.insert(yieldedValues.begin(), yieldedValues.end()); + if (yieldedValues.size() != distinctYieldedValues.size()) + return op.emitOpError() + << "cannot be exported: values yielded from 'groupings' region " + "are not all distinct after CSE"; + } + + // Export values yielded from `groupings` region as `Expression` messages. + SmallVector> columnExpressions; + { + // Get grouping expressions if any. + ArrayRef emptyValueRange; + ValueRange columnValues = emptyValueRange; + if (!op.getGroupings().empty()) { + auto yieldOp = + llvm::cast(op.getGroupings().front().getTerminator()); + columnValues = yieldOp->getOperands(); + } + + columnExpressions.reserve(columnValues.size()); + for (auto [columnIdx, columnVal] : llvm::enumerate(columnValues)) { + // Build `Expression` message for operand. + auto definingOp = llvm::dyn_cast_if_present( + columnVal.getDefiningOp()); + if (!definingOp) + return op->emitOpError() + << "yields grouping column " << columnIdx + << " that was not produced by Substrait expression op"; + + FailureOr> columnExpr = + exportOperation(definingOp); + if (failed(columnExpr)) + return failure(); + + columnExpressions.push_back(std::move(columnExpr.value())); + } + } + + // Populate repeated `groupings` field according to grouping sets. + for (auto groupingSet : op.getGroupingSets().getAsRange()) { + AggregateRel::Grouping *grouping = aggregateRel->add_groupings(); + for (auto columnIdxAttr : groupingSet.getAsRange()) { + // Look up exported expression and add as `grouping_expression`. + int64_t columnIdx = columnIdxAttr.getInt(); + Expression *columnExpr = columnExpressions[columnIdx].get(); + *grouping->add_grouping_expressions() = *columnExpr; + } + } } - // Build message for `output_type`. 
- FailureOr> outputType = - exportType(loc, op.getResult().getType()); - if (failed(outputType)) - return failure(); - scalarFunction->set_allocated_output_type(outputType->release()); + // Export measures if any. + if (!op.getMeasures().empty()) { + auto yieldOp = + llvm::cast(op.getMeasures().front().getTerminator()); + for (auto [measureIdx, measureVal] : + llvm::enumerate(yieldOp->getOperands())) { + // Build `Expression` message for operand. + auto callOp = llvm::cast(measureVal.getDefiningOp()); + assert(callOp.isAggregate() && "expected aggregate function"); + + FailureOr> aggregateFunction = + exportCallOpAggregate(callOp); + if (failed(aggregateFunction)) + return failure(); - // Build `Expression` message. - auto expression = std::make_unique(); - expression->set_allocated_scalar_function(scalarFunction.release()); + // Add `AggregateFunction` to `measures`. + AggregateRel::Measure *measure = aggregateRel->add_measures(); + measure->set_allocated_measure(aggregateFunction.value().release()); + } + } - return expression; + // Build `Rel` message. + auto rel = std::make_unique(); + rel->set_allocated_aggregate(aggregateRel.release()); + + return rel; +} + +FailureOr> +SubstraitExporter::exportOperation(CallOp op) { + if (op.isScalar()) + return exportCallOpScalar(op); + if (op.isWindow()) + return op.emitError() << "has a window function, which is currently not " + "supported for export"; + assert(op.isAggregate() && "unexpected function type"); + return op->emitOpError() << "with aggregate function not expected here"; } FailureOr> SubstraitExporter::exportOperation(CrossOp op) { @@ -1031,6 +1139,90 @@ SubstraitExporter::exportOperation(ProjectOp op) { return rel; } +template +FailureOr> +SubstraitExporter::exportCallOpCommon(CallOp op) { + Location loc = op.getLoc(); + + // Build main message. 
+ auto function = std::make_unique(); + int32_t anchor = lookupAnchor(op, op.getCallee()); + function->set_function_reference(anchor); + + // Build messages for arguments. + for (auto [i, operand] : llvm::enumerate(op->getOperands())) { + // Build `Expression` message for operand. + auto definingOp = llvm::dyn_cast_if_present( + operand.getDefiningOp()); + if (!definingOp) + return op->emitOpError() + << "with operand " << i + << " that was not produced by Substrait expression op"; + + FailureOr> expression = + exportOperation(definingOp); + if (failed(expression)) + return failure(); + + // Build `FunctionArgument` message and add to arguments. + FunctionArgument arg; + arg.set_allocated_value(expression->release()); + *function->add_arguments() = arg; + } + + // Build message for `output_type`. + FailureOr> outputType = + exportType(loc, op.getResult().getType()); + if (failed(outputType)) + return failure(); + function->set_allocated_output_type(outputType->release()); + + return function; +} + +FailureOr> +SubstraitExporter::exportCallOpAggregate(CallOp op) { + assert(op.isAggregate() && "expected aggregate function"); + + // Export common fields. + FailureOr> maybeAggregateFunction = + exportCallOpCommon(op); + if (failed(maybeAggregateFunction)) + return failure(); + std::unique_ptr aggregateFunction = + std::move(maybeAggregateFunction.value()); + + // Add aggregation-specific fields. + AggregationInvocation invocation = op.getAggregationInvocation().value(); + aggregateFunction->set_invocation( + static_cast(invocation)); + + return aggregateFunction; +} + +FailureOr> +SubstraitExporter::exportCallOpScalar(CallOp op) { + using ScalarFunction = Expression::ScalarFunction; + assert(op.isScalar() && "expected scalar function"); + + // Export common fields. + FailureOr> scalarFunction = + exportCallOpCommon(op); + if (failed(scalarFunction)) + return failure(); + + // Build `Expression` message. 
+ auto expression = std::make_unique(); + expression->set_allocated_scalar_function(scalarFunction.value().release()); + + return expression; +} + +FailureOr> +SubstraitExporter::exportCallOpWindow(CallOp op) { + llvm_unreachable("not implemented"); +} + FailureOr> SubstraitExporter::exportOperation(SetOp op) { // Build `RelCommon` message. auto relCommon = std::make_unique(); @@ -1069,6 +1261,7 @@ SubstraitExporter::exportOperation(RelOpInterface op) { return llvm::TypeSwitch>>(op) .Case< // clang-format off + AggregateOp, CrossOp, EmitOp, FetchOp, diff --git a/lib/Target/SubstraitPB/Import.cpp b/lib/Target/SubstraitPB/Import.cpp index 6e40950b..18aca537 100644 --- a/lib/Target/SubstraitPB/Import.cpp +++ b/lib/Target/SubstraitPB/Import.cpp @@ -14,6 +14,7 @@ #include "mlir/IR/OwningOpRef.h" #include "substrait-mlir/Dialect/Substrait/IR/Substrait.h" #include "substrait-mlir/Target/SubstraitPB/Options.h" +#include "llvm/ADT/SmallSet.h" #include #include @@ -32,6 +33,30 @@ namespace pb = google::protobuf; namespace { +// Copied from +// https://github.com/llvm/llvm-project/blob/dea33c/mlir/lib/Transforms/CSE.cpp. +struct SimpleOperationInfo : public llvm::DenseMapInfo { + static unsigned getHashValue(const Operation *opC) { + return OperationEquivalence::computeHash( + const_cast(opC), + /*hashOperands=*/OperationEquivalence::directHashValue, + /*hashResults=*/OperationEquivalence::ignoreHashValue, + OperationEquivalence::IgnoreLocations); + } + static bool isEqual(const Operation *lhsC, const Operation *rhsC) { + auto *lhs = const_cast(lhsC); + auto *rhs = const_cast(rhsC); + if (lhs == rhs) + return true; + if (lhs == getTombstoneKey() || lhs == getEmptyKey() || + rhs == getTombstoneKey() || rhs == getEmptyKey()) + return false; + return OperationEquivalence::isEquivalentTo( + const_cast(lhsC), const_cast(rhsC), + OperationEquivalence::IgnoreLocations); + } +}; + // Forward declaration for the import function of the given message type. 
// // We need one such function for most message types that we want to import. The @@ -46,6 +71,8 @@ namespace { static FailureOr import##MESSAGE_TYPE(ImplicitLocOpBuilder builder, \ const ARG_TYPE &message); +DECLARE_IMPORT_FUNC(AggregateFunction, AggregateFunction, CallOp) +DECLARE_IMPORT_FUNC(AggregateRel, Rel, AggregateOp) DECLARE_IMPORT_FUNC(Any, pb::Any, StringAttr) DECLARE_IMPORT_FUNC(CrossRel, Rel, CrossOp) DECLARE_IMPORT_FUNC(FetchRel, Rel, FetchOp) @@ -64,6 +91,10 @@ DECLARE_IMPORT_FUNC(ReadRel, Rel, RelOpInterface) DECLARE_IMPORT_FUNC(Rel, Rel, RelOpInterface) DECLARE_IMPORT_FUNC(ScalarFunction, Expression::ScalarFunction, CallOp) +template +static FailureOr importFunctionCommon(ImplicitLocOpBuilder builder, + const MessageType &message); + FailureOr importAny(ImplicitLocOpBuilder builder, const pb::Any &message) { MLIRContext *context = builder.getContext(); @@ -153,6 +184,154 @@ static mlir::FailureOr importType(MLIRContext *context, } } +mlir::FailureOr +importAggregateFunction(ImplicitLocOpBuilder builder, + const AggregateFunction &message) { + MLIRContext *context = builder.getContext(); + Location loc = UnknownLoc::get(context); + + FailureOr maybeCallOp = importFunctionCommon(builder, message); + if (failed(maybeCallOp)) + return failure(); + CallOp callOp = maybeCallOp.value(); + + // Import `invocation` field. 
+ AggregateFunction::AggregationInvocation invocation = message.invocation(); + std::optional invocationEnum = + symbolizeAggregationInvocation(invocation); + if (!invocationEnum.has_value()) + return emitError(loc) + << "unsupported enum value for aggregate function invocation"; + callOp.setAggregationInvocation(invocationEnum); + + assert(callOp.isAggregate() && "expected to build aggregate function"); + return callOp; +} + +static mlir::FailureOr +importAggregateRel(ImplicitLocOpBuilder builder, const Rel &message) { + using Grouping = AggregateRel::Grouping; + using Measure = AggregateRel::Measure; + + MLIRContext *context = builder.getContext(); + Location loc = UnknownLoc::get(context); + + const AggregateRel &aggregateRel = message.aggregate(); + + // Import input. + const Rel &inputRel = aggregateRel.input(); + mlir::FailureOr inputOp = importRel(builder, inputRel); + if (failed(inputOp)) + return failure(); + Value inputVal = inputOp.value()->getResult(0); + + // Import measures if any. + auto measuresRegion = std::make_unique(); + if (aggregateRel.measures_size() > 0) { + Block *measuresBlock = &measuresRegion->emplaceBlock(); + measuresBlock->addArgument(inputVal.getType(), loc); + + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(measuresBlock); + SmallVector measuresValues; + measuresValues.reserve(aggregateRel.measures_size()); + for (const Measure &measure : aggregateRel.measures()) { + const AggregateFunction &aggrFunc = measure.measure(); + + // Import measure as `CallOp`. + FailureOr callOp = importAggregateFunction(builder, aggrFunc); + if (failed(callOp)) + return failure(); + + measuresValues.push_back(callOp.value().getResult()); + } + + builder.create(measuresValues); + } + + // Import groupings if any. 
+ auto groupingsRegion = std::make_unique(); + SmallVector groupingSetsAttrs; + if (aggregateRel.groupings_size() > 0) { + Block *groupingsBlock = &groupingsRegion->emplaceBlock(); + groupingsBlock->addArgument(inputVal.getType(), loc); + + // Grouping expressions, i.e., values yielded from `groupings`. + SmallVector groupingExprValues; + groupingSetsAttrs.reserve(aggregateRel.groupings_size()); + + // Ops that produce unique grouping expressions. In the protobuf messages, + // each grouping set repeats the grouping expressions whereas the + // `AggregateOp` yields unique grouping expressions from the `groupings` + // region and represents the grouping sets as references to those. + llvm::SmallDenseMap + groupingExprOps; + + // Import one grouping set at a time. + for (const Grouping &grouping : aggregateRel.groupings()) { + // Collect references of grouping expressions for this grouping set. + SmallVector expressionRefs; + expressionRefs.reserve(grouping.grouping_expressions_size()); + for (const Expression &expression : grouping.grouping_expressions()) { + // Import expression message into `groupings` region. + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(groupingsBlock); + FailureOr exprOp = + importExpression(builder, expression); + if (failed(exprOp)) + return failure(); + + // Create or look-up reference. + auto [it, hasInserted] = groupingExprOps.try_emplace(exprOp.value()); + + // If it's a new expression, assign new reference. + if (hasInserted) { + it->second = groupingExprOps.size() - 1; + groupingExprValues.emplace_back(exprOp.value()->getResult(0)); + } else { + // Otherwise, undo import by removing ops recursively. 
+ llvm::SmallVector worklist; + worklist.push_back(exprOp.value()); + while (!worklist.empty()) { + Operation *nextOp = worklist.pop_back_val(); + for (Value v : nextOp->getOperands()) { + if (Operation *defOp = v.getDefiningOp()) + worklist.push_back(defOp); + } + nextOp->erase(); + } + } + + // Remember reference for grouping set attribute. + expressionRefs.push_back(it->second); + } + + // Create `ArrayAttr` for current grouping set. + ArrayAttr groupingSet = builder.getI64ArrayAttr(expressionRefs); + groupingSetsAttrs.push_back(groupingSet); + } + + // Assemble `YieldOp` of groupings region if there are grouping expressions. + if (!groupingExprOps.empty()) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToEnd(groupingsBlock); + builder.create(loc, groupingExprValues); + } else { + // If there aren't any, we should clear the `groupings` region. + groupingsRegion->getBlocks().clear(); + } + } + + // Create attribute for grouping sets. + auto groupingSets = ArrayAttr::get(builder.getContext(), groupingSetsAttrs); + + // Build `AggregateOp` and move regions into it. 
+ auto aggregateOp = builder.create( + inputVal, groupingSets, groupingsRegion.get(), measuresRegion.get()); + + return aggregateOp; +} + static mlir::FailureOr importCrossRel(ImplicitLocOpBuilder builder, const Rel &message) { const CrossRel &crossRel = message.cross(); @@ -710,6 +889,9 @@ static mlir::FailureOr importRel(ImplicitLocOpBuilder builder, Rel::RelTypeCase relType = message.rel_type_case(); FailureOr maybeOp; switch (relType) { + case Rel::RelTypeCase::kAggregate: + maybeOp = importAggregateRel(builder, message); + break; case Rel::RelTypeCase::kCross: maybeOp = importCrossRel(builder, message); break; @@ -767,6 +949,15 @@ static mlir::FailureOr importRel(ImplicitLocOpBuilder builder, static mlir::FailureOr importScalarFunction(ImplicitLocOpBuilder builder, const Expression::ScalarFunction &message) { + FailureOr callOp = importFunctionCommon(builder, message); + assert((failed(callOp) || callOp->isScalar()) && + "expected to build scalar function"); + return callOp; +} + +template +FailureOr importFunctionCommon(ImplicitLocOpBuilder builder, + const MessageType &message) { MLIRContext *context = builder.getContext(); Location loc = UnknownLoc::get(context); diff --git a/lib/Target/SubstraitPB/ProtobufUtils.cpp b/lib/Target/SubstraitPB/ProtobufUtils.cpp index 397fcb0b..d118579d 100644 --- a/lib/Target/SubstraitPB/ProtobufUtils.cpp +++ b/lib/Target/SubstraitPB/ProtobufUtils.cpp @@ -27,6 +27,8 @@ static const RelCommon *getCommon(const RelType &rel) { FailureOr getCommon(const Rel &rel, Location loc) { Rel::RelTypeCase relType = rel.rel_type_case(); switch (relType) { + case Rel::RelTypeCase::kAggregate: + return getCommon(rel.aggregate()); case Rel::RelTypeCase::kCross: return getCommon(rel.cross()); case Rel::RelTypeCase::kFetch: @@ -56,6 +58,8 @@ static RelCommon *getMutableCommon(RelType *rel) { FailureOr getMutableCommon(Rel *rel, Location loc) { Rel::RelTypeCase relType = rel->rel_type_case(); switch (relType) { + case Rel::RelTypeCase::kAggregate: 
+ return getMutableCommon(rel->mutable_aggregate()); case Rel::RelTypeCase::kCross: return getMutableCommon(rel->mutable_cross()); case Rel::RelTypeCase::kFetch: diff --git a/test/Dialect/Substrait/aggregate-invalid.mlir b/test/Dialect/Substrait/aggregate-invalid.mlir new file mode 100644 index 00000000..15cca43e --- /dev/null +++ b/test/Dialect/Substrait/aggregate-invalid.mlir @@ -0,0 +1,176 @@ +// RUN: substrait-opt -verify-diagnostics -split-input-file %s + +// Verify that wrong arg type to `groupings` is detected. + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a"] : tuple + // expected-error@+1 {{'substrait.aggregate' op has region #0 with invalid argument types (expected: 'tuple', got: 'tuple')}} + %1 = aggregate %0 : tuple -> tuple + groupings { + ^bb0(%arg : tuple): + %2 = literal 0 : si1 + yield %2 : si1 + } + yield %1 : tuple + } +} + +// ----- + +// Verify that wrong arg type to `measures` is detected. + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a"] : tuple + // expected-error@+1 {{'substrait.aggregate' op has region #1 with invalid argument types (expected: 'tuple', got: 'tuple')}} + %1 = aggregate %0 : tuple -> tuple + measures { + ^bb0(%arg : tuple): + %2 = literal 0 : si1 + %3 = call @function(%2) aggregate : (si1) -> si1 + yield %3 : si1 + } + yield %1 : tuple + } +} + +// ----- + +// Verify that out-of-bound column refs in grouping sets are detected. + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a"] : tuple + // expected-error@+1 {{'substrait.aggregate' op has invalid grouping set #0: column reference 1 (column #0) is out of bounds}} + %1 = aggregate %0 : tuple -> tuple + groupings { + ^bb0(%arg : tuple): + %2 = literal 0 : si1 + yield %2 : si1 + } + grouping_sets [[1]] + yield %1 : tuple + } +} + +// ----- + +// Verify that it's detected if first occurrences of column references are not +// densely increasing. 
+ +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a"] : tuple + // expected-error@+1 {{'substrait.aggregate' op has invalid grouping sets: the first occerrences of the column references must be densely increasing}} + %1 = aggregate %0 : tuple -> tuple + groupings { + ^bb0(%arg : tuple): + %2 = literal 0 : si1 + yield %2, %2 : si1, si1 + } + grouping_sets [[1, 0]] + yield %1 : tuple + } +} + +// ----- + +// Verify that yielded value unused by all grouping sets is detected. + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a"] : tuple + // expected-error@+1 {{'substrait.aggregate' op has 'groupings' region whose operand #1 is not contained in any 'grouping_set'}} + %1 = aggregate %0 : tuple -> tuple + groupings { + ^bb0(%arg : tuple): + %2 = literal 0 : si1 + yield %2, %2 : si1, si1 + } + grouping_sets [[0]] + yield %1 : tuple + } +} + +// ----- + +// Verify that missing `groupings` *and* `measures` regions is detected. + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a"] : tuple + // expected-error@+1 {{one of 'groupings' or 'measures' must be specified}} + %1 = aggregate %0 : tuple -> tuple<> + grouping_sets [[]] + yield %1 : tuple<> + } +} + +// ----- + +// Verify that unaggregated measure is detected. + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a"] : tuple + // expected-error-re@+1 {{'substrait.aggregate' op yields value from 'measures' region that was not produced by an aggregate function: {{.*}}substrait.call{{.*}}}} + %1 = aggregate %0 : tuple -> tuple + measures { + ^bb0(%arg : tuple): + %2 = literal 0 : si32 + %3 = call @function(%2) : (si32) -> si32 + yield %3 : si32 + } + yield %1 : tuple + } +} + +// ----- + +// Verify that invalid aggregation invocation mode is detected. 
+ +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a"] : tuple + %1 = aggregate %0 : tuple -> tuple + measures { + ^bb0(%arg : tuple): + %2 = literal 0 : si32 + // expected-error@+1 {{custom op 'substrait.call' has invalid aggregate invocation type specification: foo}} + %3 = call @function(%2) aggregate foo : (si32) -> si32 + yield %3 : si32 + } + yield %1 : tuple + } +} + +// ----- + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a"] : tuple + // expected-error@+1 {{'substrait.aggregate' op has region #1 that yields no values (use empty region instead)}} + %1 = aggregate %0 : tuple -> tuple<> + measures { + ^bb0(%arg : tuple): + yield + } + yield %1 : tuple<> + } +} + +// ----- + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a"] : tuple + // expected-error@+1 {{'substrait.aggregate' op has region #0 that yields no values (use empty region instead)}} + %1 = aggregate %0 : tuple -> tuple<> + groupings { + ^bb0(%arg : tuple): + yield + } + yield %1 : tuple<> + } +} diff --git a/test/Dialect/Substrait/aggregate.mlir b/test/Dialect/Substrait/aggregate.mlir new file mode 100644 index 00000000..7e4121a0 --- /dev/null +++ b/test/Dialect/Substrait/aggregate.mlir @@ -0,0 +1,258 @@ +// RUN: substrait-opt -split-input-file %s \ +// RUN: | FileCheck %s + +// Check complete op with all regions and attributes. 
+ +// CHECK-LABEL: substrait.plan +// CHECK: relation +// CHECK: %[[V0:.*]] = named_table +// CHECK-NEXT: %[[V1:.*]] = aggregate %[[V0]] : tuple -> tuple +// CHECK-NEXT: groupings { +// CHECK-NEXT: ^[[BB0:.*]](%[[ARG0:.*]]: tuple): +// CHECK-NEXT: %[[V2:.*]] = literal 0 : si1 +// CHECK-NEXT: yield %[[V2]], %[[V2]] : si1, si1 +// CHECK-NEXT: } +// CHECK-NEXT: grouping_sets {{\[}}[0], [0, 1], [1], []] +// CHECK-NEXT: measures { +// CHECK-NEXT: ^[[BB1:.*]](%[[ARG1:.*]]: tuple): +// CHECK-DAG: %[[V3:.*]] = field_reference %[[ARG1]][0] +// CHECK-DAG: %[[V4:.*]] = literal 0 +// CHECK-DAG: %[[V5:.*]] = call @function(%[[V3]]) aggregate : +// CHECK-DAG: %[[V6:.*]] = call @function(%[[V4]]) aggregate : +// CHECK-NEXT: yield %[[V5]], %[[V6]] : si32, si1 +// CHECK-NEXT: } +// CHECK-NEXT: yield %[[V1]] + +substrait.plan version 0 : 42 : 1 { + extension_uri @extension at "http://some.url/with/extensions.yml" + extension_function @function at @extension["somefunc"] + relation { + %0 = named_table @t1 as ["a"] : tuple + %1 = aggregate %0 : tuple -> tuple + groupings { + ^bb0(%arg : tuple): + %2 = literal 0 : si1 + yield %2, %2 : si1, si1 + } + grouping_sets [[0], [0, 1], [1], []] + measures { + ^bb0(%arg : tuple): + %2 = field_reference %arg[0] : tuple + %3 = literal 0 : si1 + %4 = call @function(%2) aggregate : (si32) -> si32 + %5 = call @function(%3) aggregate unspecified : (si1) -> si1 + yield %4, %5 : si32, si1 + } + yield %1 : tuple + } +} + +// ----- + +// Check complete op with different order. 
+ +// CHECK-LABEL: substrait.plan +// CHECK: relation +// CHECK: %[[V0:.*]] = named_table +// CHECK-NEXT: %[[V1:.*]] = aggregate %[[V0]] +// CHECK-NEXT: groupings { +// CHECK: } +// CHECK-NEXT: grouping_sets +// CHECK-NEXT: measures { +// CHECK-NEXT: ^[[BB1:.*]](%[[ARG1:.*]]: tuple): +// CHECK-DAG: %[[V3:.*]] = field_reference %[[ARG1]][0] +// CHECK-DAG: %[[V4:.*]] = literal 0 +// CHECK-DAG: %[[V5:.*]] = call @function(%[[V3]]) aggregate all : +// CHECK-DAG: %[[V6:.*]] = call @function(%[[V4]]) aggregate distinct : +// CHECK-NEXT: yield %[[V5]], %[[V6]] : si32, si1 +// CHECK-NEXT: } +// CHECK-NEXT: yield %[[V1]] + +substrait.plan version 0 : 42 : 1 { + extension_uri @extension at "http://some.url/with/extensions.yml" + extension_function @function at @extension["somefunc"] + relation { + %0 = named_table @t1 as ["a"] : tuple + %1 = aggregate %0 : tuple -> tuple + measures { + ^bb0(%arg : tuple): + %2 = field_reference %arg[0] : tuple + %3 = literal 0 : si1 + %4 = call @function(%2) aggregate all : (si32) -> si32 + %5 = call @function(%3) aggregate distinct : (si1) -> si1 + yield %4, %5 : si32, si1 + } + grouping_sets [[0], [0, 1], [1], []] + groupings { + ^bb0(%arg : tuple): + %2 = literal 0 : si1 + yield %2, %2 : si1, si1 + } + yield %1 : tuple + } +} + +// ----- + +// Check op without measures. + +// CHECK-LABEL: substrait.plan +// CHECK: relation +// CHECK: %[[V0:.*]] = named_table +// CHECK-NEXT: %[[V1:.*]] = aggregate %[[V0]] +// CHECK-NEXT: groupings { +// CHECK: } +// CHECK-NEXT: grouping_sets +// CHECK-NEXT: yield %[[V1]] + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a"] : tuple + %1 = aggregate %0 : tuple -> tuple + groupings { + ^bb0(%arg : tuple): + %2 = literal 0 : si1 + yield %2, %2 : si1, si1 + } + grouping_sets [[0], [0, 1], [1], []] + yield %1 : tuple + } +} + +// ----- + +// Check op with explicit single grouping_set. 
+ +// CHECK-LABEL: substrait.plan +// CHECK: relation +// CHECK: %[[V0:.*]] = named_table +// CHECK-NEXT: %[[V1:.*]] = aggregate %[[V0]] +// CHECK-NEXT: groupings { +// CHECK: } +// CHECK-NEXT: yield %[[V1]] + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a"] : tuple + %1 = aggregate %0 : tuple -> tuple + groupings { + ^bb0(%arg : tuple): + %2 = literal 0 : si1 + yield %2, %2 : si1, si1 + } + grouping_sets [[0, 1]] + yield %1 : tuple + } +} + +// ----- + +// Check op with implicit single grouping_set. + +// CHECK-LABEL: substrait.plan +// CHECK: relation +// CHECK: %[[V0:.*]] = named_table +// CHECK-NEXT: %[[V1:.*]] = aggregate %[[V0]] +// CHECK-NEXT: groupings { +// CHECK: } +// CHECK-NEXT: yield %[[V1]] + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a"] : tuple + %1 = aggregate %0 : tuple -> tuple + groupings { + ^bb0(%arg : tuple): + %2 = literal 0 : si1 + yield %2, %2 : si1, si1 + } + yield %1 : tuple + } +} + +// ----- + +// Check op without `grouping` and no grouping sets. + +// CHECK-LABEL: substrait.plan +// CHECK: relation +// CHECK: %[[V0:.*]] = named_table +// CHECK-NEXT: %[[V1:.*]] = aggregate %[[V0]] +// CHECK-NEXT: grouping_sets [] +// CHECK-NEXT: measures { +// CHECK: } +// CHECK-NEXT: yield %[[V1]] + +substrait.plan version 0 : 42 : 1 { + extension_uri @extension at "http://some.url/with/extensions.yml" + extension_function @function at @extension["somefunc"] + relation { + %0 = named_table @t1 as ["a"] : tuple + %1 = aggregate %0 : tuple -> tuple + grouping_sets [] + measures { + ^bb0(%arg : tuple): + %2 = field_reference %arg[0] : tuple + %3 = call @function(%2) aggregate : (si32) -> si32 + yield %3 : si32 + } + yield %1 : tuple + } +} + +// ----- + +// Check op without `grouping` and implicit (empty) grouping set. 
+ +// CHECK-LABEL: substrait.plan +// CHECK: relation +// CHECK: %[[V0:.*]] = named_table +// CHECK-NEXT: %[[V1:.*]] = aggregate %[[V0]] +// CHECK-NEXT: measures { +// CHECK: } +// CHECK-NEXT: yield %[[V1]] + +substrait.plan version 0 : 42 : 1 { + extension_uri @extension at "http://some.url/with/extensions.yml" + extension_function @function at @extension["somefunc"] + relation { + %0 = named_table @t1 as ["a"] : tuple + %1 = aggregate %0 : tuple -> tuple + measures { + ^bb0(%arg : tuple): + %2 = field_reference %arg[0] : tuple + %3 = call @function(%2) aggregate : (si32) -> si32 + yield %3 : si32 + } + yield %1 : tuple + } +} + +// ----- + +// Check op without `grouping` and explicit empty grouping set. + +// CHECK-LABEL: substrait.plan +// CHECK: relation +// CHECK: %[[V0:.*]] = named_table +// CHECK-NEXT: %[[V1:.*]] = aggregate %[[V0]] +// CHECK-NEXT: measures { +// CHECK: } +// CHECK-NEXT: yield %[[V1]] + +substrait.plan version 0 : 42 : 1 { + extension_uri @extension at "http://some.url/with/extensions.yml" + extension_function @function at @extension["somefunc"] + relation { + %0 = named_table @t1 as ["a"] : tuple + %1 = aggregate %0 : tuple -> tuple + grouping_sets [[]] + measures { + ^bb0(%arg : tuple): + %2 = field_reference %arg[0] : tuple + %3 = call @function(%2) aggregate : (si32) -> si32 + yield %3 : si32 + } + yield %1 : tuple + } +} diff --git a/test/Target/SubstraitPB/Export/aggregate-invalid.mlir b/test/Target/SubstraitPB/Export/aggregate-invalid.mlir new file mode 100644 index 00000000..06352dc4 --- /dev/null +++ b/test/Target/SubstraitPB/Export/aggregate-invalid.mlir @@ -0,0 +1,20 @@ +// RUN: substrait-translate -verify-diagnostics -split-input-file %s \ +// RUN: -substrait-to-protobuf + +// The groupings aren't unique after CSE. This has a different meaning once +// exported to protobuf. 
+ +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a"] : tuple + // expected-error@+1 {{'substrait.aggregate' op cannot be exported: values yielded from 'groupings' region are not all distinct after CSE}} + %1 = aggregate %0 : tuple -> tuple + groupings { + ^bb0(%arg : tuple): + %2 = literal 0 : si1 + %3 = literal 0 : si1 + yield %2, %3 : si1, si1 + } + yield %1 : tuple + } +} diff --git a/test/Target/SubstraitPB/Export/aggregate.mlir b/test/Target/SubstraitPB/Export/aggregate.mlir new file mode 100644 index 00000000..dcda07e9 --- /dev/null +++ b/test/Target/SubstraitPB/Export/aggregate.mlir @@ -0,0 +1,229 @@ +// RUN: substrait-translate -substrait-to-protobuf --split-input-file %s \ +// RUN: | FileCheck %s + +// RUN: substrait-translate -substrait-to-protobuf %s \ +// RUN: --split-input-file --output-split-marker="# -----" \ +// RUN: | substrait-translate -protobuf-to-substrait \ +// RUN: --split-input-file="# -----" --output-split-marker="// ""-----" \ +// RUN: | substrait-translate -substrait-to-protobuf \ +// RUN: --split-input-file --output-split-marker="# -----" \ +// RUN: | FileCheck %s + +// Check complete op with all regions and attributes. 
+ +// CHECK: extension_uris { +// CHECK-NEXT: uri: "http://some.url/with/extensions.yml" +// CHECK-NEXT: } +// CHECK-NEXT: extensions { +// CHECK-NEXT: extension_function { +// CHECK-NEXT: name: "somefunc" +// CHECK-NEXT: } +// CHECK: relations { +// CHECK-NEXT: rel { +// CHECK-NEXT: aggregate { +// CHECK: input { +// CHECK: groupings { +// CHECK-NEXT: grouping_expressions { +// CHECK-NEXT: literal { +// CHECK-NEXT: boolean: false +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: groupings { +// CHECK-NEXT: grouping_expressions { +// CHECK-NEXT: literal { +// CHECK-NEXT: boolean: false +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: grouping_expressions { +// CHECK-NEXT: literal { +// CHECK-NEXT: boolean: true +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: groupings { +// CHECK-NEXT: grouping_expressions { +// CHECK-NEXT: literal { +// CHECK-NEXT: boolean: true +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: groupings { +// CHECK-NEXT: } +// CHECK-NEXT: measures { +// CHECK-NEXT: measure { +// CHECK-NEXT: output_type { +// CHECK-NEXT: i32 { +// CHECK-NEXT: nullability: NULLABILITY_REQUIRED +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: arguments { +// CHECK-NEXT: value { +// CHECK-NEXT: selection { +// CHECK-NOT: measure +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: measures { +// CHECK-NEXT: measure { +// CHECK-NEXT: output_type { +// CHECK-NEXT: bool { +// CHECK-NEXT: nullability: NULLABILITY_REQUIRED +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: arguments { +// CHECK-NEXT: value { +// CHECK-NEXT: literal { +// CHECK-NEXT: boolean: false +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK: version + +substrait.plan version 0 : 42 : 1 { + extension_uri @extension at "http://some.url/with/extensions.yml" + extension_function @function at 
@extension["somefunc"] + relation { + %0 = named_table @t1 as ["a"] : tuple + %1 = aggregate %0 : tuple -> tuple + groupings { + ^bb0(%arg : tuple): + %2 = literal 0 : si1 + %3 = literal -1 : si1 + yield %2, %3 : si1, si1 + } + grouping_sets [[0], [0, 1], [1], []] + measures { + ^bb0(%arg : tuple): + %2 = field_reference %arg[0] : tuple + %3 = literal 0 : si1 + %4 = call @function(%2) aggregate : (si32) -> si32 + %5 = call @function(%3) aggregate unspecified : (si1) -> si1 + yield %4, %5 : si32, si1 + } + yield %1 : tuple + } +} + +// ----- + +// Check op without measures. + +// CHECK: relations { +// CHECK-NEXT: rel { +// CHECK-NEXT: aggregate { +// CHECK: input { +// CHECK: groupings { +// CHECK-NOT: measures +// CHECK: version + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a"] : tuple + %1 = aggregate %0 : tuple -> tuple + groupings { + ^bb0(%arg : tuple): + %2 = literal 0 : si1 + yield %2 : si1 + } + grouping_sets [[0]] + yield %1 : tuple + } +} + +// ----- + +// Check op special invocation modes. + +// CHECK: extension_uris { +// CHECK: relations { +// CHECK-NEXT: rel { +// CHECK-NEXT: aggregate { +// CHECK: measures { +// CHECK-NEXT: measure { +// CHECK-NOT: measure +// CHECK: invocation: AGGREGATION_INVOCATION_ALL +// CHECK: measure { +// CHECK-NOT: measure +// CHECK: invocation: AGGREGATION_INVOCATION_DISTINCT +// CHECK: version + +substrait.plan version 0 : 42 : 1 { + extension_uri @extension at "http://some.url/with/extensions.yml" + extension_function @function at @extension["somefunc"] + relation { + %0 = named_table @t1 as ["a"] : tuple + %1 = aggregate %0 : tuple -> tuple + measures { + ^bb0(%arg : tuple): + %2 = field_reference %arg[0] : tuple + %3 = literal 0 : si1 + %4 = call @function(%2) aggregate all : (si32) -> si32 + %5 = call @function(%3) aggregate distinct : (si1) -> si1 + yield %4, %5 : si32, si1 + } + yield %1 : tuple + } +} + +// ----- + +// Check op without `grouping` and no grouping sets. 
+ +// CHECK: extension_uris { +// CHECK: relations { +// CHECK-NEXT: rel { +// CHECK-NEXT: aggregate { +// CHECK-NOT: groupings +// CHECK: version + +substrait.plan version 0 : 42 : 1 { + extension_uri @extension at "http://some.url/with/extensions.yml" + extension_function @function at @extension["somefunc"] + relation { + %0 = named_table @t1 as ["a"] : tuple + %1 = aggregate %0 : tuple -> tuple + grouping_sets [] + measures { + ^bb0(%arg : tuple): + %2 = field_reference %arg[0] : tuple + %4 = call @function(%2) aggregate : (si32) -> si32 + yield %4 : si32 + } + yield %1 : tuple + } +} + +// ----- + +// Check op without `grouping` and (implicit) empty grouping set. + +// CHECK: extension_uris { +// CHECK: relations { +// CHECK-NEXT: rel { +// CHECK-NEXT: aggregate { +// CHECK: groupings { +// CHECK: version + +substrait.plan version 0 : 42 : 1 { + extension_uri @extension at "http://some.url/with/extensions.yml" + extension_function @function at @extension["somefunc"] + relation { + %0 = named_table @t1 as ["a"] : tuple + %1 = aggregate %0 : tuple -> tuple + measures { + ^bb0(%arg : tuple): + %2 = field_reference %arg[0] : tuple + %4 = call @function(%2) aggregate : (si32) -> si32 + yield %4 : si32 + } + yield %1 : tuple + } +} diff --git a/test/Target/SubstraitPB/Import/aggregate.textpb b/test/Target/SubstraitPB/Import/aggregate.textpb new file mode 100644 index 00000000..2b2759b1 --- /dev/null +++ b/test/Target/SubstraitPB/Import/aggregate.textpb @@ -0,0 +1,476 @@ +# RUN: substrait-translate -protobuf-to-substrait %s \ +# RUN: --split-input-file="# ""-----" \ +# RUN: | FileCheck %s + +# RUN: substrait-translate -protobuf-to-substrait %s \ +# RUN: --split-input-file="# ""-----" --output-split-marker="// -----" \ +# RUN: | substrait-translate -substrait-to-protobuf \ +# RUN: --split-input-file --output-split-marker="# ""-----" \ +# RUN: | substrait-translate -protobuf-to-substrait \ +# RUN: --split-input-file="# ""-----" --output-split-marker="// -----" \ +# RUN: 
| FileCheck %s + +# CHECK-LABEL: substrait.plan +# CHECK-NEXT: extension_uri @[[URI:.*]] at "http://some.url/with/extensions.yml" +# CHECK-NEXT: extension_function @[[F1:.*]] at @[[URI]]["somefunc"] +# CHECK-NEXT: relation +# CHECK-NEXT: %[[V0:.*]] = named_table +# CHECK-NEXT: %[[V1:.*]] = aggregate %[[V0]] : tuple -> tuple +# CHECK-NEXT: groupings { +# CHECK-NEXT: (%[[ARG0:.*]]: tuple): +# CHECK-DAG: %[[V2:.*]] = literal -1 : si1 +# CHECK-DAG: %[[V3:.*]] = literal 0 : si1 +# CHECK-NEXT: yield %[[V3]], %[[V2]] : si1, si1 +# CHECK-NEXT: } +# CHECK-NEXT: grouping_sets {{\[}}[0], [0, 1], [1], []] +# CHECK-NEXT: measures { +# CHECK-NEXT: (%[[ARG1:.*]]: tuple): +# CHECK-DAG: %[[V4:.*]] = field_reference %[[ARG0]][0] : tuple +# CHECK-DAG: %[[V5:.*]] = call @[[F1]](%[[V4]]) aggregate : (si32) -> si32 +# CHECK-DAG: %[[V6:.*]] = literal 0 : si1 +# CHECK-DAG: %[[V7:.*]] = call @[[F1]](%[[V6]]) aggregate : (si1) -> si1 +# CHECK-NEXT: yield %[[V5]], %[[V7]] : si32, si1 +# CHECK-NEXT: } +# CHECK-NEXT: yield %[[V1]] : tuple + +extension_uris { + uri: "http://some.url/with/extensions.yml" +} +extensions { + extension_function { + name: "somefunc" + } +} +relations { + rel { + aggregate { + common { + direct { + } + } + input { + read { + common { + direct { + } + } + base_schema { + names: "a" + struct { + types { + i32 { + nullability: NULLABILITY_REQUIRED + } + } + nullability: NULLABILITY_REQUIRED + } + } + named_table { + names: "t1" + } + } + } + groupings { + grouping_expressions { + literal { + boolean: false + } + } + } + groupings { + grouping_expressions { + literal { + boolean: false + } + } + grouping_expressions { + literal { + boolean: true + } + } + } + groupings { + grouping_expressions { + literal { + boolean: true + } + } + } + groupings { + } + measures { + measure { + output_type { + i32 { + nullability: NULLABILITY_REQUIRED + } + } + arguments { + value { + selection { + direct_reference { + struct_field { + } + } + root_reference { + } + } + } + } + } + } + 
measures { + measure { + output_type { + bool { + nullability: NULLABILITY_REQUIRED + } + } + arguments { + value { + literal { + boolean: false + } + } + } + } + } + } + } +} +version { + minor_number: 42 + patch_number: 1 +} + +# ----- + + +# CHECK-LABEL: substrait.plan +# CHECK-NEXT: relation +# CHECK-NEXT: %[[V0:.*]] = named_table +# CHECK-NEXT: %[[V1:.*]] = aggregate %[[V0]] : tuple -> tuple +# CHECK-NEXT: groupings { +# CHECK-NEXT: (%[[ARG0:.*]]: tuple): +# CHECK-DAG: %[[V2:.*]] = literal 0 : si1 +# CHECK-NEXT: yield %[[V2]] : si1 +# CHECK-NEXT: } +# CHECK-NEXT: yield %[[V1]] : tuple + +relations { + rel { + aggregate { + common { + direct { + } + } + input { + read { + common { + direct { + } + } + base_schema { + names: "a" + struct { + types { + i32 { + nullability: NULLABILITY_REQUIRED + } + } + nullability: NULLABILITY_REQUIRED + } + } + named_table { + names: "t1" + } + } + } + groupings { + grouping_expressions { + literal { + boolean: false + } + } + } + } + } +} +version { + minor_number: 42 + patch_number: 1 +} + +# ----- + + +# CHECK-LABEL: substrait.plan +# CHECK-NEXT: extension_uri @[[URI:.*]] at "http://some.url/with/extensions.yml" +# CHECK-NEXT: extension_function @[[F1:.*]] at @[[URI]]["somefunc"] +# CHECK-NEXT: relation +# CHECK-NEXT: %[[V0:.*]] = named_table +# CHECK-NEXT: %[[V1:.*]] = aggregate %[[V0]] : tuple -> tuple +# CHECK-NEXT: measures { +# CHECK-NEXT: (%[[ARG1:.*]]: tuple): +# CHECK-DAG: %[[V2:.*]] = field_reference %[[ARG0]][0] : tuple +# CHECK-DAG: %[[V3:.*]] = call @[[F1]](%[[V2]]) aggregate all : (si32) -> si32 +# CHECK-DAG: %[[V4:.*]] = literal 0 : si1 +# CHECK-DAG: %[[V5:.*]] = call @[[F1]](%[[V4]]) aggregate distinct : (si1) -> si1 +# CHECK-NEXT: yield %[[V3]], %[[V5]] : si32, si1 +# CHECK-NEXT: } +# CHECK-NEXT: yield %[[V1]] : tuple + +extension_uris { + uri: "http://some.url/with/extensions.yml" +} +extensions { + extension_function { + name: "somefunc" + } +} +relations { + rel { + aggregate { + common { + direct { + } + 
} + input { + read { + common { + direct { + } + } + base_schema { + names: "a" + struct { + types { + i32 { + nullability: NULLABILITY_REQUIRED + } + } + nullability: NULLABILITY_REQUIRED + } + } + named_table { + names: "t1" + } + } + } + groupings { + } + measures { + measure { + output_type { + i32 { + nullability: NULLABILITY_REQUIRED + } + } + invocation: AGGREGATION_INVOCATION_ALL + arguments { + value { + selection { + direct_reference { + struct_field { + } + } + root_reference { + } + } + } + } + } + } + measures { + measure { + output_type { + bool { + nullability: NULLABILITY_REQUIRED + } + } + invocation: AGGREGATION_INVOCATION_DISTINCT + arguments { + value { + literal { + boolean: false + } + } + } + } + } + } + } +} +version { + minor_number: 42 + patch_number: 1 +} + +# ----- + + +# CHECK-LABEL: substrait.plan +# CHECK-NEXT: extension_uri @[[URI:.*]] at "http://some.url/with/extensions.yml" +# CHECK-NEXT: extension_function @[[F1:.*]] at @[[URI]]["somefunc"] +# CHECK-NEXT: relation +# CHECK-NEXT: %[[V0:.*]] = named_table +# CHECK-NEXT: %[[V1:.*]] = aggregate %[[V0]] : tuple -> tuple +# CHECK-NEXT: grouping_sets [] +# CHECK-NEXT: measures { +# CHECK-NEXT: (%[[ARG1:.*]]: tuple): +# CHECK-DAG: %[[V2:.*]] = field_reference %[[ARG0]][0] : tuple +# CHECK-DAG: %[[V3:.*]] = call @[[F1]](%[[V2]]) aggregate : (si32) -> si32 +# CHECK-NEXT: yield %[[V3]] : si32 +# CHECK-NEXT: } +# CHECK-NEXT: yield %[[V1]] : tuple + +extension_uris { + uri: "http://some.url/with/extensions.yml" +} +extensions { + extension_function { + name: "somefunc" + } +} +relations { + rel { + aggregate { + common { + direct { + } + } + input { + read { + common { + direct { + } + } + base_schema { + names: "a" + struct { + types { + i32 { + nullability: NULLABILITY_REQUIRED + } + } + nullability: NULLABILITY_REQUIRED + } + } + named_table { + names: "t1" + } + } + } + measures { + measure { + output_type { + i32 { + nullability: NULLABILITY_REQUIRED + } + } + arguments { + value { + 
selection { + direct_reference { + struct_field { + } + } + root_reference { + } + } + } + } + } + } + } + } +} +version { + minor_number: 42 + patch_number: 1 +} + +# ----- + + +# CHECK-LABEL: substrait.plan +# CHECK-NEXT: extension_uri @[[URI:.*]] at "http://some.url/with/extensions.yml" +# CHECK-NEXT: extension_function @[[F1:.*]] at @[[URI]]["somefunc"] +# CHECK-NEXT: relation +# CHECK-NEXT: %[[V0:.*]] = named_table +# CHECK-NEXT: %[[V1:.*]] = aggregate %[[V0]] : tuple -> tuple +# CHECK-NEXT: measures { +# CHECK-NEXT: (%[[ARG1:.*]]: tuple): +# CHECK-DAG: %[[V2:.*]] = field_reference %[[ARG0]][0] : tuple +# CHECK-DAG: %[[V3:.*]] = call @[[F1]](%[[V2]]) aggregate : (si32) -> si32 +# CHECK-NEXT: yield %[[V3]] : si32 +# CHECK-NEXT: } +# CHECK-NEXT: yield %[[V1]] : tuple + +extension_uris { + uri: "http://some.url/with/extensions.yml" +} +extensions { + extension_function { + name: "somefunc" + } +} +relations { + rel { + aggregate { + common { + direct { + } + } + input { + read { + common { + direct { + } + } + base_schema { + names: "a" + struct { + types { + i32 { + nullability: NULLABILITY_REQUIRED + } + } + nullability: NULLABILITY_REQUIRED + } + } + named_table { + names: "t1" + } + } + } + groupings { + } + measures { + measure { + output_type { + i32 { + nullability: NULLABILITY_REQUIRED + } + } + arguments { + value { + selection { + direct_reference { + struct_field { + } + } + root_reference { + } + } + } + } + } + } + } + } +} +version { + minor_number: 42 + patch_number: 1 +}