Skip to content

Commit

Permalink
feat: implement extension_table op
Browse files Browse the repository at this point in the history
This PR implements the `extension_table` op, the MLIR equivalent of the
`ReadRel.ExtensionTable` message. Since that message uses the
`google.protobuf.Any` message type, the PR also slightly extends and
improves the handling of that message in the dialect. Finally, because
`extension_table` is the second op corresponding to a `ReadRel` case
(after `named_table`), the PR makes some effort to factor out common
logic between the two, namely, how the named structs representing the
schema of the op are handled. However, there is still some opportunity
for further factoring that this PR does not pursue, such as defining a
`ReadRelInterface`.

Signed-off-by: Ingo Müller <[email protected]>
  • Loading branch information
ingomueller-net committed Feb 6, 2025
1 parent 76ab7c1 commit 8ba4a2d
Show file tree
Hide file tree
Showing 9 changed files with 284 additions and 39 deletions.
3 changes: 3 additions & 0 deletions include/substrait-mlir/Dialect/Substrait/IR/SubstraitAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def Substrait_DefaultEmptyStringParameter
// Substrait attributes
//===----------------------------------------------------------------------===//

/// Attribute used for `google.protobuf.Any` messages: a string attribute whose
/// type is `Substrait_AnyType`.
def Substrait_AnyAttr : TypedStrAttr<Substrait_AnyType>;

def Substrait_AdvancedExtensionAttr
: Substrait_Attr<"AdvancedExtension", "advanced_extension"> {
let summary = "Represents the `AdvancedExtenssion` message of Substrait";
Expand Down
23 changes: 23 additions & 0 deletions include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,29 @@ def Substrait_EmitOp : Substrait_RelOp<"emit", [
}];
}

def Substrait_ExtensionTableOp : Substrait_RelOp<"extension_table"> {
  let summary = "Extension table operation (i.e., a `ReadRel` case)";
  let description = [{
    Represents an `ExtensionTable` message together with the `ReadRel` and
    `Rel` messages that contain it.

    The `detail` attribute corresponds to the `detail` field of the
    `ExtensionTable` message, which is a `google.protobuf.Any` message. The
    `field_names` attribute and the result type together correspond to the
    `base_schema` field of the enclosing `ReadRel` message.

    Example:

    ```mlir
    %0 = extension_table
           "some detail" : !substrait.any<"some url">
           as ["a"] : tuple<si32>
    ```
  }];
  // `field_names` is listed first, so builders take it before `detail`, even
  // though the assembly format prints `detail` first.
  let arguments = (ins
    StringArrayAttr:$field_names,
    Substrait_AnyAttr:$detail
  );
  let results = (outs Substrait_Relation:$result);
  let assemblyFormat = "$detail `as` $field_names attr-dict `:` type($result)";
  // The verifier checks the field names against the result type.
  let hasVerifier = true;
}

def Substrait_FetchOp : Substrait_RelOp<"fetch", [
SameOperandsAndResultType
]> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def Substrait_AtomicTypes {
}

def Substrait_AnyType : Substrait_Type<"Any", "any"> {
let summary = "Represents the `type_url` of a `google.protobuf.Any` message";
let summary = "type of a 'google.protobuf.Any' protobuf message";
let description = [{
This type represents the `type_url` fields of a `google.protobuf.Any`
message. These messages consist of an opaque byte array and a string holding
Expand Down
6 changes: 6 additions & 0 deletions lib/Dialect/Substrait/IR/Substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,12 @@ EmitOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
return success();
}

/// Verifies this op by handing its field names and its result tuple type to
/// `verifyNamedStruct`, whose result is returned unchanged.
LogicalResult ExtensionTableOp::verify() {
  // The result type is cast to `TupleType` unconditionally.
  auto relationType = llvm::cast<TupleType>(getResult().getType());
  llvm::ArrayRef<Attribute> names = getFieldNames().getValue();
  return verifyNamedStruct(getOperation(), names, relationType);
}

LogicalResult FieldReferenceOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> loc, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
Expand Down
90 changes: 72 additions & 18 deletions lib/Target/SubstraitPB/Export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class SubstraitExporter {
DECLARE_EXPORT_FUNC(CrossOp, Rel)
DECLARE_EXPORT_FUNC(EmitOp, Rel)
DECLARE_EXPORT_FUNC(ExpressionOpInterface, Expression)
DECLARE_EXPORT_FUNC(ExtensionTableOp, Rel)
DECLARE_EXPORT_FUNC(FieldReferenceOp, Expression)
DECLARE_EXPORT_FUNC(FetchOp, Rel)
DECLARE_EXPORT_FUNC(FilterOp, Rel)
Expand Down Expand Up @@ -81,6 +82,8 @@ class SubstraitExporter {
FailureOr<std::unique_ptr<Expression>> exportCallOpWindow(CallOp op);

std::unique_ptr<pb::Any> exportAny(StringAttr attr);
FailureOr<std::unique_ptr<NamedStruct>>
exportNamedStruct(Location loc, ArrayAttr fieldNames, TupleType relationType);
FailureOr<std::unique_ptr<pb::Message>> exportOperation(Operation *op);
FailureOr<std::unique_ptr<proto::Type>> exportType(Location loc,
mlir::Type mlirType);
Expand Down Expand Up @@ -631,6 +634,42 @@ SubstraitExporter::exportOperation(ExpressionOpInterface op) {
[](auto op) { return op->emitOpError("not supported for export"); });
}

/// Exports an `extension_table` op as a `Rel` message that wraps a `ReadRel`
/// message whose read type is an `ExtensionTable`.
///
/// Returns failure if the field names and result type cannot be exported as a
/// `NamedStruct`.
FailureOr<std::unique_ptr<Rel>>
SubstraitExporter::exportOperation(ExtensionTableOp op) {
  // TODO(ingomueller): factor out common logic of `ReadRel`.
  // Export field names and result type into `base_schema`. This is the only
  // step here that reports failure, so it runs before anything else is built.
  auto relationType = llvm::cast<TupleType>(op.getResult().getType());
  FailureOr<std::unique_ptr<NamedStruct>> schema =
      exportNamedStruct(op.getLoc(), op.getFieldNames(), relationType);
  if (failed(schema))
    return failure();

  // `ExtensionTable` message carrying the exported `detail` attribute.
  auto tableMsg = std::make_unique<ReadRel::ExtensionTable>();
  tableMsg->set_allocated_detail(exportAny(op.getDetailAttr()).release());

  // `RelCommon` message with an empty `Direct` message attached.
  auto common = std::make_unique<RelCommon>();
  common->set_allocated_direct(std::make_unique<RelCommon::Direct>().release());

  // `ReadRel` message; each `set_allocated_*` call takes over the pointer that
  // is released to it.
  auto read = std::make_unique<ReadRel>();
  read->set_allocated_common(common.release());
  read->set_allocated_extension_table(tableMsg.release());
  read->set_allocated_base_schema(schema->release());

  // Outermost `Rel` message.
  auto result = std::make_unique<Rel>();
  result->set_allocated_read(read.release());
  return result;
}

FailureOr<std::unique_ptr<Expression>>
SubstraitExporter::exportOperation(FieldReferenceOp op) {
using FieldReference = Expression::FieldReference;
Expand Down Expand Up @@ -864,6 +903,31 @@ SubstraitExporter::exportOperation(ModuleOp op) {
return failure();
}

/// Exports a list of field names and a tuple type as a `NamedStruct` message.
///
/// Each element type of `relationType` is exported with `exportType` and
/// appended to a `Struct` message with nullability `REQUIRED`; each entry of
/// `fieldNames` is appended to the `names` of the resulting `NamedStruct`.
/// Every entry of `fieldNames` is cast to `StringAttr`.
///
/// Returns failure if any of the field types cannot be exported.
FailureOr<std::unique_ptr<NamedStruct>>
SubstraitExporter::exportNamedStruct(Location loc, ArrayAttr fieldNames,
                                     TupleType relationType) {

  // Build `Struct` message.
  auto struct_ = std::make_unique<proto::Type::Struct>();
  struct_->set_nullability(
      Type_Nullability::Type_Nullability_NULLABILITY_REQUIRED);
  for (mlir::Type fieldType : relationType.getTypes()) {
    FailureOr<std::unique_ptr<proto::Type>> type = exportType(loc, fieldType);
    if (failed(type))
      return failure();
    // Move the exported message itself into the repeated field. Applying
    // `std::move` to the owning pointer and then dereferencing it yields an
    // lvalue and therefore copies the whole message.
    *struct_->add_types() = std::move(**type);
  }

  // Build `NamedStruct` message.
  auto namedStruct = std::make_unique<NamedStruct>();
  namedStruct->set_allocated_struct_(struct_.release());
  for (Attribute attr : fieldNames) {
    namedStruct->add_names(mlir::cast<StringAttr>(attr).getValue().str());
  }

  return namedStruct;
}

FailureOr<std::unique_ptr<Rel>>
SubstraitExporter::exportOperation(NamedTableOp op) {
Location loc = op.getLoc();
Expand All @@ -880,29 +944,18 @@ SubstraitExporter::exportOperation(NamedTableOp op) {
auto direct = std::make_unique<RelCommon::Direct>();
relCommon->set_allocated_direct(direct.release());

// Build `Struct` message.
auto struct_ = std::make_unique<proto::Type::Struct>();
struct_->set_nullability(
Type_Nullability::Type_Nullability_NULLABILITY_REQUIRED);
// TODO(ingomueller): factor out common logic of `ReadRel`.
// Export field names and result type into `base_schema`.
auto tupleType = llvm::cast<TupleType>(op.getResult().getType());
for (mlir::Type fieldType : tupleType.getTypes()) {
FailureOr<std::unique_ptr<proto::Type>> type = exportType(loc, fieldType);
if (failed(type))
return (failure());
*struct_->add_types() = *std::move(type.value());
}

// Build `NamedStruct` message.
auto namedStruct = std::make_unique<NamedStruct>();
namedStruct->set_allocated_struct_(struct_.release());
for (Attribute attr : op.getFieldNames()) {
namedStruct->add_names(mlir::cast<StringAttr>(attr).getValue().str());
}
FailureOr<std::unique_ptr<NamedStruct>> baseSchema =
exportNamedStruct(loc, op.getFieldNames(), tupleType);
if (failed(baseSchema))
return failure();

// Build `ReadRel` message.
auto readRel = std::make_unique<ReadRel>();
readRel->set_allocated_common(relCommon.release());
readRel->set_allocated_base_schema(namedStruct.release());
readRel->set_allocated_base_schema(baseSchema->release());
readRel->set_allocated_named_table(namedTable.release());

// Build `Rel` message.
Expand Down Expand Up @@ -1313,6 +1366,7 @@ SubstraitExporter::exportOperation(RelOpInterface op) {
AggregateOp,
CrossOp,
EmitOp,
ExtensionTableOp,
FetchOp,
FieldReferenceOp,
FilterOp,
Expand Down
87 changes: 67 additions & 20 deletions lib/Target/SubstraitPB/Import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ namespace pb = google::protobuf;

namespace {

using ImportedNamedStruct = std::tuple<ArrayAttr, TupleType>;

// Copied from
// https://github.com/llvm/llvm-project/blob/dea33c/mlir/lib/Transforms/CSE.cpp.
struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
Expand Down Expand Up @@ -80,10 +82,12 @@ DECLARE_IMPORT_FUNC(FetchRel, Rel, FetchOp)
DECLARE_IMPORT_FUNC(FilterRel, Rel, FilterOp)
DECLARE_IMPORT_FUNC(SetRel, Rel, SetOp)
DECLARE_IMPORT_FUNC(Expression, Expression, ExpressionOpInterface)
DECLARE_IMPORT_FUNC(ExtensionTable, Rel, ExtensionTableOp)
DECLARE_IMPORT_FUNC(FieldReference, Expression::FieldReference,
FieldReferenceOp)
DECLARE_IMPORT_FUNC(JoinRel, Rel, JoinOp)
DECLARE_IMPORT_FUNC(Literal, Expression::Literal, LiteralOp)
DECLARE_IMPORT_FUNC(NamedStruct, NamedStruct, ImportedNamedStruct)
DECLARE_IMPORT_FUNC(NamedTable, Rel, NamedTableOp)
DECLARE_IMPORT_FUNC(PlanRel, PlanRel, PlanRelOp)
DECLARE_IMPORT_FUNC(ProjectRel, Rel, ProjectOp)
Expand Down Expand Up @@ -442,6 +446,31 @@ importExpression(ImplicitLocOpBuilder builder, const Expression &message) {
}
}

/// Imports the `ReadRel` message inside `message` as an `extension_table` op,
/// assuming that its read type is an `ExtensionTable`.
///
/// The field names and result type of the op are taken from the `base_schema`
/// of the `ReadRel` message; the `detail` attribute is taken from the `detail`
/// field of the `ExtensionTable` message.
///
/// Returns failure if either the base schema or the `detail` message cannot be
/// imported.
static mlir::FailureOr<ExtensionTableOp>
importExtensionTable(ImplicitLocOpBuilder builder, const Rel &message) {
  const ReadRel &readRel = message.read();
  const ReadRel::ExtensionTable &extensionTable = readRel.extension_table();

  // TODO(ingomueller): factor out common logic of `ReadRel`.
  // Import base schema and extract result names and types.
  const NamedStruct &baseSchema = readRel.base_schema();
  FailureOr<ImportedNamedStruct> importedNamedStruct =
      importNamedStruct(builder, baseSchema);
  if (failed(importedNamedStruct))
    return failure();
  auto [fieldNamesAttr, resultType] = importedNamedStruct.value();

  // Import `detail` attribute. Propagate a failed import to the caller rather
  // than unwrapping the result unchecked.
  const pb::Any &detail = extensionTable.detail();
  auto detailAttr = importAny(builder, detail);
  if (failed(detailAttr))
    return failure();

  // Assemble final op.
  auto extensionTableOp = builder.create<ExtensionTableOp>(
      resultType, fieldNamesAttr, detailAttr.value());

  return extensionTableOp;
}

static mlir::FailureOr<FieldReferenceOp>
importFieldReference(ImplicitLocOpBuilder builder,
const Expression::FieldReference &message) {
Expand Down Expand Up @@ -648,6 +677,34 @@ static mlir::FailureOr<FilterOp> importFilterRel(ImplicitLocOpBuilder builder,
return filterOp;
}

/// Imports a `NamedStruct` message as a pair of field names and tuple type.
///
/// The `names` of the message become an `ArrayAttr` of `StringAttr`s and the
/// `types` of its `struct_` field become the element types of a `TupleType`.
///
/// Returns failure if any of the field types cannot be imported.
static mlir::FailureOr<ImportedNamedStruct>
importNamedStruct(ImplicitLocOpBuilder builder, const NamedStruct &message) {
  MLIRContext *ctx = builder.getContext();

  // Convert the field types of the schema; stop at the first one that fails.
  const proto::Type::Struct &structMsg = message.struct_();
  llvm::SmallVector<mlir::Type> fieldTypes;
  fieldTypes.reserve(structMsg.types_size());
  for (const proto::Type &protoType : structMsg.types()) {
    FailureOr<mlir::Type> converted = importType(ctx, protoType);
    if (failed(converted))
      return failure();
    fieldTypes.push_back(*converted);
  }

  // Convert the field names of the schema into string attributes.
  llvm::SmallVector<Attribute> nameAttrs;
  nameAttrs.reserve(message.names_size());
  for (const std::string &fieldName : message.names())
    nameAttrs.push_back(StringAttr::get(ctx, fieldName));

  return ImportedNamedStruct{ArrayAttr::get(ctx, nameAttrs),
                             TupleType::get(ctx, fieldTypes)};
}

static mlir::FailureOr<NamedTableOp>
importNamedTable(ImplicitLocOpBuilder builder, const Rel &message) {
const ReadRel &readRel = message.read();
Expand All @@ -667,27 +724,14 @@ importNamedTable(ImplicitLocOpBuilder builder, const Rel &message) {
auto tableName =
SymbolRefAttr::get(context, tableNameRootRef, tableNameNestedRefs);

// Assemble field names from schema.
// TODO(ingomueller): factor out common logic of `ReadRel`.
// Import base schema and extract result names and types.
const NamedStruct &baseSchema = readRel.base_schema();
llvm::SmallVector<Attribute> fieldNames;
fieldNames.reserve(baseSchema.names_size());
for (const std::string &name : baseSchema.names()) {
auto attr = StringAttr::get(context, name);
fieldNames.push_back(attr);
}
auto fieldNamesAttr = ArrayAttr::get(context, fieldNames);

// Assemble field names from schema.
const proto::Type::Struct &struct_ = baseSchema.struct_();
llvm::SmallVector<mlir::Type> resultTypes;
resultTypes.reserve(struct_.types_size());
for (const proto::Type &type : struct_.types()) {
FailureOr<mlir::Type> mlirType = importType(context, type);
if (failed(mlirType))
return failure();
resultTypes.push_back(mlirType.value());
}
auto resultType = TupleType::get(context, resultTypes);
FailureOr<ImportedNamedStruct> importedNamedStruct =
importNamedStruct(builder, baseSchema);
if (failed(importedNamedStruct))
return failure();
auto [fieldNamesAttr, resultType] = importedNamedStruct.value();

// Assemble final op.
auto namedTableOp =
Expand Down Expand Up @@ -910,6 +954,9 @@ importReadRel(ImplicitLocOpBuilder builder, const Rel &message) {
const ReadRel &readRel = message.read();
ReadRel::ReadTypeCase readType = readRel.read_type_case();
switch (readType) {
case ReadRel::ReadTypeCase::kExtensionTable: {
return importExtensionTable(builder, message);
}
case ReadRel::ReadTypeCase::kNamedTable: {
return importNamedTable(builder, message);
}
Expand Down
18 changes: 18 additions & 0 deletions test/Dialect/Substrait/extension-table.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Checks that an `extension_table` op with a `detail` attribute, a single
// field name, and a single-field result type is printed by `substrait-opt` in
// the same form in which it is written below.

// RUN: substrait-opt -split-input-file %s \
// RUN: | FileCheck %s

// CHECK-LABEL: substrait.plan
// CHECK: relation
// CHECK: %[[V0:.*]] = extension_table
// CHECK-SAME: "some detail" : !substrait.any<"some url">
// CHECK-SAME: as ["a"] : tuple<si32>
// CHECK-NEXT: yield %[[V0]] : tuple<si32>

substrait.plan version 0 : 42 : 1 {
relation {
%0 = extension_table
"some detail" : !substrait.any<"some url">
as ["a"] : tuple<si32>
yield %0 : tuple<si32>
}
}
Loading

0 comments on commit 8ba4a2d

Please sign in to comment.