Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement support for the PlanVersion message type #69

Merged
merged 1 commit into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions include/substrait-mlir/Dialect/Substrait/IR/SubstraitAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ class Substrait_StaticallyTypedAttr<string name, string attrMnemonic,
}];
}

/// `StringAttr` parameter that is the empty string by default.
///
/// The default value is materialized as `mlir::StringAttr::get($_ctxt, "")`,
/// i.e., a non-null attribute that holds the empty string. It is used for the
/// `git_hash` and `producer` parameters of `Substrait_VersionAttr`.
def Substrait_DefaultEmptyStringParameter
: DefaultValuedParameter<"StringAttr", [{mlir::StringAttr::get($_ctxt, "")}]>;

//===----------------------------------------------------------------------===//
// Substrait attributes
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -94,6 +98,40 @@ def Substrait_TimestampTzAttr
let assemblyFormat = [{ `<` $value `` `us` `>` }];
}

def Substrait_VersionAttr : Substrait_Attr<"Version", "version"> {
  let summary = "Substrait version";
  let description = [{
    Represents the `Version` message type.
  }];
  // The three version numbers are mandatory. `git_hash` and `producer` are
  // default-valued: they hold the empty string unless specified otherwise.
  let parameters = (ins
    "uint32_t":$major_number,
    "uint32_t":$minor_number,
    "uint32_t":$patch_number,
    Substrait_DefaultEmptyStringParameter:$git_hash,
    Substrait_DefaultEmptyStringParameter:$producer
  );
  // Printed as `major:minor:patch`, optionally followed by `git_hash "..."`
  // and/or `producer "..."`; the optional groups are only used for values that
  // differ from the default.
  // TODO(ingomueller): make this even nicer with custom printer/parser
  let assemblyFormat = [{
    `` $major_number `` `:` `` $minor_number `` `:` `` $patch_number
    (`git_hash` $git_hash^)? (`producer` $producer^)?
  }];
  let builders = [
    // Builds a version that only consists of the three version numbers. The
    // string parameters are set to their default value (the empty string)
    // rather than to a null `StringAttr`: the parameters are default-valued,
    // not nullable, so a null attribute would differ from the default, would
    // make the printer enter the optional groups with nothing printable, and
    // would not compare equal to the same version parsed from its textual
    // form.
    AttrBuilder<(ins "uint32_t":$major, "uint32_t":$minor, "uint32_t":$patch), [{
      auto emptyAttr = ::mlir::StringAttr::get($_ctxt, "");
      return $_get($_ctxt, major, minor, patch,
                   /*git_hash=*/emptyAttr,
                   /*producer=*/emptyAttr);
    }]>,
    // Builds a version from the three version numbers and plain strings for
    // the git hash and the producer.
    AttrBuilder<(ins "uint32_t":$major, "uint32_t":$minor, "uint32_t":$patch,
                     "::llvm::StringRef":$git_hash,
                     "::llvm::StringRef":$producer), [{
      auto gitHashAttr = ::mlir::StringAttr::get($_ctxt, git_hash);
      auto producerAttr = ::mlir::StringAttr::get($_ctxt, producer);
      return $_get($_ctxt, major, minor, patch, gitHashAttr, producerAttr);
    }]>,
  ];
}

//===----------------------------------------------------------------------===//
// Helpers and constraints
//===----------------------------------------------------------------------===//
Expand Down
17 changes: 15 additions & 2 deletions include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ def Substrait_ExtensionTypeVariationOp :
}

//===----------------------------------------------------------------------===//
// Plan
// Top-level ops
//===----------------------------------------------------------------------===//
// The definitions in this section are related to the top-level `Plan` message.
// The definitions in this section are related to the top-level messages.
// See https://substrait.io/serialization/binary_serialization/ and
// https://github.com/substrait-io/substrait/blob/main/proto/substrait/plan.proto.
//===----------------------------------------------------------------------===//
Expand All @@ -149,6 +149,19 @@ def PlanBodyOp : AnyOf<[
IsOp<"::mlir::substrait::ExtensionTypeVariationOp">,
]>;

// Top-level op that holds nothing but a `Substrait_VersionAttr`. It has no
// operands, results, or regions.
def Substrait_PlanVersionOp : Substrait_Op<"plan_version"> {
let summary = "Represents a stand-alone plan version";
let description = [{
This op represents the `PlanVersion` message type of Substrait. It carries
the version information as an attribute, so it also subsumes the `Version`
message type.
}];
let arguments = (ins
Substrait_VersionAttr:$version
);
// The version attribute is printed in its own assembly format directly after
// the op name, for example:
//   substrait.plan_version 0:42:1 git_hash "hash" producer "producer"
let assemblyFormat = "$version attr-dict";
}

def Substrait_PlanOp : Substrait_Op<"plan", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getDefaultDialect"]>,
DeclareOpInterfaceMethods<Substrait_ExtensibleOpInterface>,
Expand Down
8 changes: 6 additions & 2 deletions include/substrait-mlir/Target/SubstraitPB/Import.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@ class OwningOpRef;
namespace substrait {

OwningOpRef<ModuleOp>
translateProtobufToSubstrait(llvm::StringRef input, MLIRContext *context,
substrait::ImportExportOptions options = {});
translateProtobufToSubstraitPlan(llvm::StringRef input, MLIRContext *context,
substrait::ImportExportOptions options = {});

/// Deserializes `input` as a `PlanVersion` protobuf message, using the
/// serialization format selected by `options.serdeFormat`, and imports it into
/// a newly created module. Returns a null `OwningOpRef` if the input cannot be
/// deserialized or the import fails.
OwningOpRef<ModuleOp> translateProtobufToSubstraitPlanVersion(
llvm::StringRef input, MLIRContext *context,
substrait::ImportExportOptions options = {});

} // namespace substrait
} // namespace mlir
Expand Down
2 changes: 1 addition & 1 deletion lib/CAPI/Dialects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ MlirModule mlirSubstraitImportPlan(MlirContext context, MlirStringRef input,
ImportExportOptions options;
options.serdeFormat = convertSerdeFormat(format);
OwningOpRef<ModuleOp> owning =
translateProtobufToSubstrait(unwrap(input), unwrap(context), options);
translateProtobufToSubstraitPlan(unwrap(input), unwrap(context), options);
if (!owning)
return MlirModule{nullptr};
return MlirModule{owning.release().getOperation()};
Expand Down
33 changes: 28 additions & 5 deletions lib/Target/SubstraitPB/Export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,10 @@ class SubstraitExporter {
DECLARE_EXPORT_FUNC(FilterOp, Rel)
DECLARE_EXPORT_FUNC(JoinOp, Rel)
DECLARE_EXPORT_FUNC(LiteralOp, Expression)
DECLARE_EXPORT_FUNC(ModuleOp, Plan)
DECLARE_EXPORT_FUNC(ModuleOp, pb::Message)
DECLARE_EXPORT_FUNC(NamedTableOp, Rel)
DECLARE_EXPORT_FUNC(PlanOp, Plan)
DECLARE_EXPORT_FUNC(PlanVersionOp, PlanVersion)
DECLARE_EXPORT_FUNC(ProjectOp, Rel)
DECLARE_EXPORT_FUNC(RelOpInterface, Rel)
DECLARE_EXPORT_FUNC(SetOp, Rel)
Expand Down Expand Up @@ -842,7 +843,7 @@ SubstraitExporter::exportOperation(LiteralOp op) {
return expression;
}

FailureOr<std::unique_ptr<Plan>>
FailureOr<std::unique_ptr<pb::Message>>
SubstraitExporter::exportOperation(ModuleOp op) {
if (!op->getAttrs().empty()) {
op->emitOpError("has attributes");
Expand All @@ -855,8 +856,9 @@ SubstraitExporter::exportOperation(ModuleOp op) {
return failure();
}

if (auto plan = llvm::dyn_cast<PlanOp>(&*body.op_begin()))
return exportOperation(plan);
Operation *innerOp = &*body.op_begin();
if (llvm::isa<PlanOp, PlanVersionOp>(innerOp))
return exportOperation(innerOp);

op->emitOpError("contains an op that is not a 'substrait.plan'");
return failure();
Expand Down Expand Up @@ -1111,6 +1113,27 @@ FailureOr<std::unique_ptr<Plan>> SubstraitExporter::exportOperation(PlanOp op) {
return std::move(plan);
}

/// Exports a `PlanVersionOp` as a `PlanVersion` message whose nested `Version`
/// message is filled from the op's version attribute.
FailureOr<std::unique_ptr<PlanVersion>>
SubstraitExporter::exportOperation(PlanVersionOp op) {
  VersionAttr attr = op.getVersion();

  // Fill in the nested `Version` message. The numeric fields are always set;
  // the string fields are only set if the corresponding attribute is non-null.
  auto versionMessage = std::make_unique<Version>();
  versionMessage->set_major_number(attr.getMajorNumber());
  versionMessage->set_minor_number(attr.getMinorNumber());
  versionMessage->set_patch_number(attr.getPatchNumber());
  if (auto producer = attr.getProducer())
    versionMessage->set_producer(producer.str());
  if (auto gitHash = attr.getGitHash())
    versionMessage->set_git_hash(gitHash.str());

  // Wrap it into the `PlanVersion` message, which takes over ownership of the
  // nested message.
  auto result = std::make_unique<PlanVersion>();
  result->set_allocated_version(versionMessage.release());

  return result;
}

FailureOr<std::unique_ptr<Rel>>
SubstraitExporter::exportOperation(ProjectOp op) {
// Build `RelCommon` message.
Expand Down Expand Up @@ -1309,7 +1332,7 @@ FailureOr<std::unique_ptr<pb::Message>>
SubstraitExporter::exportOperation(Operation *op) {
return llvm::TypeSwitch<Operation *, FailureOr<std::unique_ptr<pb::Message>>>(
op)
.Case<ModuleOp, PlanOp>(
.Case<ModuleOp, PlanOp, PlanVersionOp>(
[&](auto op) -> FailureOr<std::unique_ptr<pb::Message>> {
auto typedMessage = exportOperation(op);
if (failed(typedMessage))
Expand Down
78 changes: 55 additions & 23 deletions lib/Target/SubstraitPB/Import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,13 @@ DECLARE_IMPORT_FUNC(FieldReference, Expression::FieldReference,
DECLARE_IMPORT_FUNC(JoinRel, Rel, JoinOp)
DECLARE_IMPORT_FUNC(Literal, Expression::Literal, LiteralOp)
DECLARE_IMPORT_FUNC(NamedTable, Rel, NamedTableOp)
DECLARE_IMPORT_FUNC(Plan, Plan, PlanOp)
DECLARE_IMPORT_FUNC(PlanRel, PlanRel, PlanRelOp)
DECLARE_IMPORT_FUNC(ProjectRel, Rel, ProjectOp)
DECLARE_IMPORT_FUNC(ReadRel, Rel, RelOpInterface)
DECLARE_IMPORT_FUNC(Rel, Rel, RelOpInterface)
DECLARE_IMPORT_FUNC(ScalarFunction, Expression::ScalarFunction, CallOp)
DECLARE_IMPORT_FUNC(TopLevel, Plan, PlanOp)
DECLARE_IMPORT_FUNC(TopLevel, PlanVersion, PlanVersionOp)

/// If present, imports the `advanced_extension` or `advanced_extensions` field
/// from the given message and sets the obtained attribute on the given op.
Expand Down Expand Up @@ -695,8 +696,8 @@ importNamedTable(ImplicitLocOpBuilder builder, const Rel &message) {
return namedTableOp;
}

static FailureOr<PlanOp> importPlan(ImplicitLocOpBuilder builder,
const Plan &message) {
static FailureOr<PlanOp> importTopLevel(ImplicitLocOpBuilder builder,
const Plan &message) {
using extensions::SimpleExtensionDeclaration;
using extensions::SimpleExtensionURI;
using ExtensionFunction = SimpleExtensionDeclaration::ExtensionFunction;
Expand Down Expand Up @@ -838,6 +839,15 @@ static FailureOr<PlanRelOp> importPlanRel(ImplicitLocOpBuilder builder,
return planRelOp;
}

/// Imports a `PlanVersion` message by converting its nested `Version` message
/// into a `VersionAttr` and creating a `PlanVersionOp` that carries it.
static FailureOr<PlanVersionOp> importTopLevel(ImplicitLocOpBuilder builder,
                                               const PlanVersion &message) {
  const Version &pbVersion = message.version();

  // The string fields are passed as plain strings; the attribute builder
  // turns them into string attributes.
  VersionAttr attr = VersionAttr::get(
      builder.getContext(), pbVersion.major_number(), pbVersion.minor_number(),
      pbVersion.patch_number(), pbVersion.git_hash(), pbVersion.producer());

  return builder.create<PlanVersionOp>(attr);
}

static mlir::FailureOr<ProjectOp> importProjectRel(ImplicitLocOpBuilder builder,
const Rel &message) {
const ProjectRel &projectRel = message.project();
Expand Down Expand Up @@ -1028,53 +1038,75 @@ FailureOr<CallOp> importFunctionCommon(ImplicitLocOpBuilder builder,
return {callOp};
}

} // namespace

namespace mlir {
namespace substrait {

OwningOpRef<ModuleOp>
translateProtobufToSubstrait(llvm::StringRef input, MLIRContext *context,
ImportExportOptions options) {
template <typename MessageType>
OwningOpRef<ModuleOp> translateProtobufToSubstraitTopLevel(
llvm::StringRef input, MLIRContext *context, ImportExportOptions options,
MessageType &message) {
Location loc = UnknownLoc::get(context);
auto plan = std::make_unique<Plan>();

// Parse from serialized form into desired protobuf `MessageType`.
switch (options.serdeFormat) {
case substrait::SerdeFormat::kText:
if (!pb::TextFormat::ParseFromString(input.str(), plan.get())) {
emitError(loc) << "could not parse string as 'Plan' message.";
case SerdeFormat::kText:
if (!pb::TextFormat::ParseFromString(input.str(), &message)) {
emitError(loc) << "could not parse string as '" << message.GetTypeName()
<< "' message.";
return {};
}
break;
case substrait::SerdeFormat::kBinary:
if (!plan->ParseFromString(input.str())) {
emitError(loc) << "could not deserialize input as 'Plan' message.";
case SerdeFormat::kBinary:
if (!message.ParseFromString(input.str())) {
emitError(loc) << "could not deserialize input as '"
<< message.GetTypeName() << "' message.";
return {};
}
break;
case substrait::SerdeFormat::kJson:
case substrait::SerdeFormat::kPrettyJson: {
case SerdeFormat::kJson:
case SerdeFormat::kPrettyJson: {
pb::util::Status status =
pb::util::JsonStringToMessage(input.str(), plan.get());
pb::util::JsonStringToMessage(input.str(), &message);
if (!status.ok()) {
emitError(loc) << "could not deserialize JSON as 'Plan' message:\n"
emitError(loc) << "could not deserialize JSON as '"
<< message.GetTypeName() << "' message:\n"
<< status.message().as_string();
return {};
}
}
}

// Set up infra for importing.
context->loadDialect<SubstraitDialect>();

ImplicitLocOpBuilder builder(loc, context);
auto module = builder.create<ModuleOp>(loc);
auto moduleRef = OwningOpRef<ModuleOp>(module);
builder.setInsertionPointToEnd(&module.getBodyRegion().back());

if (failed(importPlan(builder, *plan)))
// Import protobuf message into corresponding MLIR op.
if (failed(importTopLevel(builder, message)))
return {};

return moduleRef;
}

} // namespace

namespace mlir {
namespace substrait {

/// Entry point for importing a serialized `Plan` message: provides the message
/// instance to deserialize into and defers to the generic top-level import.
OwningOpRef<ModuleOp>
translateProtobufToSubstraitPlan(llvm::StringRef input, MLIRContext *context,
                                 ImportExportOptions options) {
  Plan message;
  return translateProtobufToSubstraitTopLevel(input, context, options,
                                              message);
}

/// Entry point for importing a serialized `PlanVersion` message: provides the
/// message instance to deserialize into and defers to the generic top-level
/// import.
OwningOpRef<ModuleOp> translateProtobufToSubstraitPlanVersion(
    llvm::StringRef input, MLIRContext *context, ImportExportOptions options) {
  PlanVersion message;
  return translateProtobufToSubstraitTopLevel(input, context, options,
                                              message);
}

} // namespace substrait
} // namespace mlir
17 changes: 17 additions & 0 deletions test/Dialect/Substrait/plan-version.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: substrait-opt -split-input-file %s \
// RUN: | FileCheck %s

// CHECK: substrait.plan_version 0:42:1 git_hash "hash" producer "producer"
substrait.plan_version 0:42:1 git_hash "hash" producer "producer"

// CHECK-NEXT: substrait.plan_version 1:2:3 producer "other producer"{{$}}
substrait.plan_version 1:2:3 producer "other producer"

// CHECK-NEXT: substrait.plan_version 1:33:7 git_hash "other hash"{{$}}
substrait.plan_version 1:33:7 git_hash "other hash"

// CHECK-NEXT: substrait.plan_version 3:2:1{{$}}
substrait.plan_version 3:2:1

// CHECK-NEXT: substrait.plan_version 6:6:6{{$}}
substrait.plan_version 6:6:6 git_hash "" producer ""
60 changes: 60 additions & 0 deletions test/Target/SubstraitPB/Export/plan-version.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// RUN: substrait-translate -substrait-to-protobuf --split-input-file %s \
// RUN: | FileCheck %s

// RUN: substrait-translate -substrait-to-protobuf %s \
// RUN: --split-input-file --output-split-marker="# -----" \
// RUN: | substrait-translate -protobuf-to-substrait-plan-version \
// RUN: --split-input-file="# -----" --output-split-marker="// ""-----" \
// RUN: | substrait-translate -substrait-to-protobuf \
// RUN: --split-input-file --output-split-marker="# -----" \
// RUN: | FileCheck %s

// CHECK-LABEL: version {
// CHECK-DAG: minor_number: 42
// CHECK-DAG: patch_number: 1
// CHECK-DAG: git_hash: "hash"
// CHECK-DAG: producer: "producer"
// CHECK-NEXT: }

substrait.plan_version 0:42:1 git_hash "hash" producer "producer"

// -----

// CHECK-LABEL: version {
// CHECK-DAG: major_number: 1
// CHECK-DAG: minor_number: 2
// CHECK-DAG: patch_number: 3
// CHECK-DAG: producer: "other producer"
// CHECK-NEXT: }

substrait.plan_version 1:2:3 producer "other producer"

// -----

// CHECK-LABEL: version {
// CHECK-DAG: major_number: 1
// CHECK-DAG: minor_number: 33
// CHECK-DAG: patch_number: 7
// CHECK-DAG: git_hash: "other hash"
// CHECK-NEXT: }

substrait.plan_version 1:33:7 git_hash "other hash"

// -----

// CHECK-LABEL: version {
// CHECK-DAG: major_number: 3
// CHECK-DAG: minor_number: 2
// CHECK-DAG: patch_number: 1
// CHECK-NEXT: }

substrait.plan_version 3:2:1

// -----

// CHECK-LABEL: version {
// CHECK-NOT: git_hash
// CHECK-NOT: producer
// CHECK: }

substrait.plan_version 1:2:3 git_hash "" producer ""
Loading