Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: factor out advanced_extension logic and add it to other ops #65

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions include/substrait-mlir/Dialect/Substrait/IR/Substrait.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,38 @@
#include "mlir/IR/SymbolTable.h" // IWYU: keep
#include "mlir/Interfaces/InferTypeOpInterface.h" // IWYU: keep

#include "substrait-mlir/Dialect/Substrait/IR/SubstraitEnums.h.inc" // IWYU: export
//===----------------------------------------------------------------------===//
// Substrait dialect
//===----------------------------------------------------------------------===//

#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpsDialect.h.inc" // IWYU: export

#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpInterfaces.h.inc" // IWYU: export
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitTypeInterfaces.h.inc" // IWYU: export
//===----------------------------------------------------------------------===//
// Substrait enums
//===----------------------------------------------------------------------===//

#include "substrait-mlir/Dialect/Substrait/IR/SubstraitEnums.h.inc" // IWYU: export

//===----------------------------------------------------------------------===//
// Substrait types
//===----------------------------------------------------------------------===//

#include "substrait-mlir/Dialect/Substrait/IR/SubstraitTypeInterfaces.h.inc" // IWYU: export
#define GET_TYPEDEF_CLASSES
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpsTypes.h.inc" // IWYU: export

//===----------------------------------------------------------------------===//
// Substrait attributes
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpsAttrs.h.inc" // IWYU: export

//===----------------------------------------------------------------------===//
// Substrait ops
//===----------------------------------------------------------------------===//

#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpInterfaces.h.inc" // IWYU: export
#define GET_OP_CLASSES
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOps.h.inc" // IWYU: export

Expand Down
20 changes: 20 additions & 0 deletions include/substrait-mlir/Dialect/Substrait/IR/SubstraitInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,26 @@ def Substrait_ExpressionOpInterface : OpInterface<"ExpressionOpInterface"> {
let cppNamespace = "::mlir::substrait";
}

def Substrait_ExtensibleOpInterface : OpInterface<"ExtensibleOpInterface"> {
let description = [{
Interface for ops with the `advanced_extension` attribute. Several relations
and other message types of the Substrait specification have a field with the
same name (or the variant `advanced_extensions`, which has the same meaning)
and the interface enables handling all of them transparently.
}];
let cppNamespace = "::mlir::substrait";
let methods = [
InterfaceMethod<
"Get the `advanced_extension` attribute",
"std::optional<::mlir::substrait::AdvancedExtensionAttr>",
"getAdvancedExtension">,
InterfaceMethod<
"Get the `advanced_extension` attribute",
"void", "setAdvancedExtensionAttr",
(ins "::mlir::substrait::AdvancedExtensionAttr":$attr)>,
];
}

def Substrait_RelOpInterface : OpInterface<"RelOpInterface"> {
let description = [{
Interface for any relational operation in a Substrait plan. This corresponds
Expand Down
18 changes: 14 additions & 4 deletions include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def PlanBodyOp : AnyOf<[

def Substrait_PlanOp : Substrait_Op<"plan", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getDefaultDialect"]>,
DeclareOpInterfaceMethods<Substrait_ExtensibleOpInterface>,
NoTerminator, NoRegionArguments, SingleBlock, SymbolTable
]> {
let summary = "Represents a Substrait plan";
Expand Down Expand Up @@ -178,9 +179,13 @@ def Substrait_PlanOp : Substrait_Op<"plan", [
let builders = [
OpBuilder<(ins "uint32_t":$major, "uint32_t":$minor, "uint32_t":$patch), [{
build($_builder, $_state, major, minor, patch,
/*git_hash=*/StringAttr(), /*producer*/StringAttr(),
/*git_hash=*/"", /*producer*/"");
}]>,
OpBuilder<(ins "uint32_t":$major, "uint32_t":$minor, "uint32_t":$patch,
"std::string":$git_hash, "std::string":$producer), [{
build($_builder, $_state, major, minor, patch, git_hash, producer,
/*advanced_extension=*/AdvancedExtensionAttr());
}]>
}]>,
];
let extraClassDefinition = [{
/// Implement OpAsmOpInterface.
Expand Down Expand Up @@ -527,6 +532,7 @@ def Substrait_NamedTableOp : Substrait_RelOp<"named_table", [

def Substrait_ProjectOp : Substrait_RelOp<"project", [
SingleBlockImplicitTerminator<"::mlir::substrait::YieldOp">,
DeclareOpInterfaceMethods<Substrait_ExtensibleOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getDefaultDialect"]>
]> {
let summary = "Project operation";
Expand All @@ -551,14 +557,18 @@ def Substrait_ProjectOp : Substrait_RelOp<"project", [
}
```
}];
let arguments = (ins Substrait_Relation:$input);
let arguments = (ins
Substrait_Relation:$input,
OptionalAttr<Substrait_AdvancedExtensionAttr>:$advanced_extension
);
let regions = (region AnyRegion:$expressions);
let results = (outs Substrait_Relation:$result);
// TODO(ingomueller): We could elide/shorten the block argument from the
// assembly by writing custom printers/parsers similar to
// `scf.for` etc.
let assemblyFormat = [{
$input attr-dict `:` type($input) `->` type($result) $expressions
$input (`advanced_extension` `` $advanced_extension^)?
attr-dict `:` type($input) `->` type($result) $expressions
}];
let hasRegionVerifier = 1;
let hasFolder = 1;
Expand Down
58 changes: 38 additions & 20 deletions lib/Target/SubstraitPB/Export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

using namespace mlir;
using namespace mlir::substrait;
using namespace mlir::substrait::protobuf_utils;
using namespace ::substrait;
using namespace ::substrait::proto;

Expand Down Expand Up @@ -60,6 +61,8 @@ class SubstraitExporter {
DECLARE_EXPORT_FUNC(RelOpInterface, Rel)
DECLARE_EXPORT_FUNC(SetOp, Rel)

template <typename MessageType>
void exportAdvancedExtension(ExtensibleOpInterface op, MessageType &message);
std::unique_ptr<pb::Any> exportAny(StringAttr attr);
FailureOr<std::unique_ptr<pb::Message>> exportOperation(Operation *op);
FailureOr<std::unique_ptr<proto::Type>> exportType(Location loc,
Expand Down Expand Up @@ -91,6 +94,36 @@ class SubstraitExporter {
std::unique_ptr<SymbolTable> symbolTable; // Symbol table cache.
};

template <typename MessageType>
void SubstraitExporter::exportAdvancedExtension(ExtensibleOpInterface op,
MessageType &message) {
if (!op.getAdvancedExtension())
return;

// Build the base `AdvancedExtension` message.
AdvancedExtensionAttr extensionAttr = op.getAdvancedExtension().value();
auto extension = std::make_unique<extensions::AdvancedExtension>();

StringAttr optimizationAttr = extensionAttr.getOptimization();
StringAttr enhancementAttr = extensionAttr.getEnhancement();

// Set `optimization` field if present.
if (optimizationAttr) {
std::unique_ptr<pb::Any> optimization = exportAny(optimizationAttr);
extension->set_allocated_optimization(optimization.release());
}

// Set `enhancement` field if present.
if (enhancementAttr) {
std::unique_ptr<pb::Any> enhancement = exportAny(enhancementAttr);
extension->set_allocated_enhancement(enhancement.release());
}

// Set the `advanced_extension` field in the provided message.
using Trait = advanced_extension_trait<MessageType>;
Trait::set_allocated_advanced_extension(message, extension.release());
}

std::unique_ptr<pb::Any> SubstraitExporter::exportAny(StringAttr attr) {
auto any = std::make_unique<pb::Any>();
auto anyType = mlir::cast<AnyType>(attr.getType());
Expand Down Expand Up @@ -874,26 +907,8 @@ FailureOr<std::unique_ptr<Plan>> SubstraitExporter::exportOperation(PlanOp op) {
version->set_git_hash(op.getGitHash().str());
plan->set_allocated_version(version.release());

// Build `AdvancedExtension` message.
if (op.getAdvancedExtension()) {
AdvancedExtensionAttr extensionAttr = op.getAdvancedExtension().value();
auto extension = std::make_unique<extensions::AdvancedExtension>();

StringAttr optimizationAttr = extensionAttr.getOptimization();
StringAttr enhancementAttr = extensionAttr.getEnhancement();

if (optimizationAttr) {
std::unique_ptr<pb::Any> optimization = exportAny(optimizationAttr);
extension->set_allocated_optimization(optimization.release());
}

if (enhancementAttr) {
std::unique_ptr<pb::Any> enhancement = exportAny(enhancementAttr);
extension->set_allocated_enhancement(enhancement.release());
}

plan->set_allocated_advanced_extensions(extension.release());
}
// Attach the `AdvancedExtension` message if the attribute exists.
exportAdvancedExtension(op, *plan);

// Add `extension_uris` to plan.
{
Expand Down Expand Up @@ -1024,6 +1039,9 @@ SubstraitExporter::exportOperation(ProjectOp op) {
*projectRel->add_expressions() = *expression.value();
}

// Attach the `AdvancedExtension` message if the attribute exists.
exportAdvancedExtension(op, *projectRel);

// Build `Rel` message.
auto rel = std::make_unique<Rel>();
rel->set_allocated_project(projectRel.release());
Expand Down
71 changes: 48 additions & 23 deletions lib/Target/SubstraitPB/Import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

using namespace mlir;
using namespace mlir::substrait;
using namespace mlir::substrait::protobuf_utils;
using namespace ::substrait;
using namespace ::substrait::proto;

Expand Down Expand Up @@ -64,6 +65,46 @@ DECLARE_IMPORT_FUNC(ReadRel, Rel, RelOpInterface)
DECLARE_IMPORT_FUNC(Rel, Rel, RelOpInterface)
DECLARE_IMPORT_FUNC(ScalarFunction, Expression::ScalarFunction, CallOp)

/// If present, imports the `advanced_extension` or `advanced_extensions` field
/// from the given message and sets the obtained attribute on the given op.
template <typename MessageType>
void importAdvancedExtension(ImplicitLocOpBuilder builder,
ExtensibleOpInterface op,
const MessageType &message);

template <typename MessageType>
void importAdvancedExtension(ImplicitLocOpBuilder builder,
ExtensibleOpInterface op,
const MessageType &message) {
using Trait = advanced_extension_trait<MessageType>;
if (!Trait::has_advanced_extension(message))
return;

// Get the `advanced_extension(s)` field.
const extensions::AdvancedExtension &advancedExtension =
Trait::advanced_extension(message);

// Import `optimization` field if present.
StringAttr optimizationAttr;
if (advancedExtension.has_optimization()) {
const pb::Any &optimization = advancedExtension.optimization();
optimizationAttr = importAny(builder, optimization).value();
}

// Import `enhancement` field if present.
StringAttr enhancementAttr;
if (advancedExtension.has_enhancement()) {
const pb::Any &enhancement = advancedExtension.enhancement();
enhancementAttr = importAny(builder, enhancement).value();
}

// Build attribute and set it on the op.
MLIRContext *context = builder.getContext();
auto advancedExtensionAttr =
AdvancedExtensionAttr::get(context, optimizationAttr, enhancementAttr);
op.setAdvancedExtensionAttr(advancedExtensionAttr);
}

FailureOr<StringAttr> importAny(ImplicitLocOpBuilder builder,
const pb::Any &message) {
MLIRContext *context = builder.getContext();
Expand Down Expand Up @@ -492,34 +533,15 @@ static FailureOr<PlanOp> importPlan(ImplicitLocOpBuilder builder,
// Import version.
const Version &version = message.version();

// Import advanced extension.
AdvancedExtensionAttr advancedExtensionAttr;
if (message.has_advanced_extensions()) {
const extensions::AdvancedExtension &advancedExtension =
message.advanced_extensions();

StringAttr optimizationAttr;
if (advancedExtension.has_optimization()) {
const pb::Any &optimization = advancedExtension.optimization();
optimizationAttr = importAny(builder, optimization).value();
}

StringAttr enhancementAttr;
if (advancedExtension.has_enhancement()) {
const pb::Any &enhancement = advancedExtension.enhancement();
enhancementAttr = importAny(builder, enhancement).value();
}

advancedExtensionAttr =
AdvancedExtensionAttr::get(context, optimizationAttr, enhancementAttr);
}

// Build `PlanOp`.
auto planOp = builder.create<PlanOp>(
version.major_number(), version.minor_number(), version.patch_number(),
version.git_hash(), version.producer(), advancedExtensionAttr);
version.git_hash(), version.producer());
planOp.getBody().push_back(new Block());

// Import advanced extension if it is present.
importAdvancedExtension(builder, planOp, message);

OpBuilder::InsertionGuard insertGuard(builder);
builder.setInsertionPointToEnd(&planOp.getBody().front());

Expand Down Expand Up @@ -680,6 +702,9 @@ static mlir::FailureOr<ProjectOp> importProjectRel(ImplicitLocOpBuilder builder,
builder.create<ProjectOp>(resultType, inputOp.value()->getResult(0));
projectOp.getExpressions().push_back(conditionBlock.release());

// Import advanced extension if it is present.
importAdvancedExtension(builder, projectOp, projectRel);

return projectOp;
}

Expand Down
55 changes: 55 additions & 0 deletions lib/Target/SubstraitPB/ProtobufUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#ifndef LIB_TARGET_SUBSTRAITPB_PROTOBUFUTILS_H
#define LIB_TARGET_SUBSTRAITPB_PROTOBUFUTILS_H

#include <type_traits>

#include "mlir/IR/Location.h"

namespace substrait::proto {
Expand All @@ -28,6 +30,59 @@ getCommon(const ::substrait::proto::Rel &rel, Location loc);
FailureOr<::substrait::proto::RelCommon *>
getMutableCommon(::substrait::proto::Rel *rel, Location loc);

/// SFINAE-based template that checks if the given (message) type has an field
/// called `advanced_extension`: the `value` member is `true` iff it has. This
/// is useful to deal with the two different names, `advanced_extension` and
/// `advanced_extensions`, that are used for the same thing across different
/// message types in the Substrait spec.
template <typename T>
class has_advanced_extensions {
template <typename C>
static std::true_type test(decltype(&C::advanced_extensions));
template <typename C>
static std::false_type test(...);

public:
static constexpr bool value = decltype(test<T>(0))::value;
};

/// Trait class for accessing the `advanced_extension` field. The default
/// instances is automatically used for message types that call this field
/// `advanced_extension`; the specialization below is automatically used for
/// message types that call it `advanced_extensions`.
template <typename T, typename = void>
struct advanced_extension_trait {
static auto has_advanced_extension(const T &message) {
return message.has_advanced_extension();
}
static auto advanced_extension(const T &message) {
return message.advanced_extension();
}
template <typename S>
static auto set_allocated_advanced_extension(T &message,
S &&advanced_extensions) {
message.set_allocated_advanced_extension(
std::forward<S>(advanced_extensions));
}
};

template <typename T>
struct advanced_extension_trait<
T, std::enable_if_t<has_advanced_extensions<T>::value>> {
static auto has_advanced_extension(const T &message) {
return message.has_advanced_extensions();
}
static auto advanced_extension(const T &message) {
return message.advanced_extensions();
}
template <typename S>
static auto set_allocated_advanced_extension(T &message,
S &&advanced_extensions) {
message.set_allocated_advanced_extensions(
std::forward<S>(advanced_extensions));
}
};

} // namespace mlir::substrait::protobuf_utils

#endif // LIB_TARGET_SUBSTRAITPB_PROTOBUFUTILS_H
Loading