diff --git a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td index 0a5651d8..573524f7 100644 --- a/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td +++ b/include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td @@ -166,21 +166,31 @@ def Substrait_PlanOp : Substrait_Op<"plan", [ UI32Attr:$patch_number, DefaultValuedAttr:$git_hash, DefaultValuedAttr:$producer, - OptionalAttr:$advanced_extension + OptionalAttr:$advanced_extension, + OptionalAttr:$expected_type_urls ); let regions = (region RegionOf:$body); let assemblyFormat = [{ `version` $major_number `:` $minor_number `:` $patch_number (`git_hash` $git_hash^)? (`producer` $producer^)? (`advanced_extension` `` $advanced_extension^)? + (`expected_type_urls` `` $expected_type_urls^)? attr-dict-with-keyword $body }]; let builders = [ OpBuilder<(ins "uint32_t":$major, "uint32_t":$minor, "uint32_t":$patch), [{ build($_builder, $_state, major, minor, patch, /*git_hash=*/StringAttr(), /*producer*/StringAttr(), - /*advanced_extension=*/AdvancedExtensionAttr()); - }]> + /*advanced_extension=*/AdvancedExtensionAttr(), + /*expected_type_urls=*/ArrayAttr()); + }]>, + OpBuilder< + (ins "uint32_t":$major, "uint32_t":$minor, "uint32_t":$patch, + "::llvm::StringRef":$git_hash, "::llvm::StringRef":$producer, + "::mlir::substrait::AdvancedExtensionAttr":$advanced_extension), [{ + build($_builder, $_state, major, minor, patch, git_hash, producer, + advanced_extension, /*expected_type_urls=*/ArrayAttr()); + }]>, ]; let extraClassDefinition = [{ /// Implement OpAsmOpInterface. diff --git a/lib/Target/SubstraitPB/Export.cpp b/lib/Target/SubstraitPB/Export.cpp index 93dab7ba..9f25a995 100644 --- a/lib/Target/SubstraitPB/Export.cpp +++ b/lib/Target/SubstraitPB/Export.cpp @@ -895,6 +895,13 @@ FailureOr> SubstraitExporter::exportOperation(PlanOp op) { plan->set_allocated_advanced_extensions(extension.release()); } + // Add `expected_type_urls` to plan if present. + if (op.getExpectedTypeUrls()) { + ArrayAttr expected_type_urls = op.getExpectedTypeUrls().value(); + for (auto expected_type_url : expected_type_urls.getAsRange()) + plan->add_expected_type_urls(expected_type_url.str()); + } + // Add `extension_uris` to plan. { AnchorUniquer anchorUniquer("extension_uri.", anchorsByOp); diff --git a/lib/Target/SubstraitPB/Import.cpp b/lib/Target/SubstraitPB/Import.cpp index 6e40950b..59d942c1 100644 --- a/lib/Target/SubstraitPB/Import.cpp +++ b/lib/Target/SubstraitPB/Import.cpp @@ -520,6 +520,15 @@ static FailureOr importPlan(ImplicitLocOpBuilder builder, version.git_hash(), version.producer(), advancedExtensionAttr); planOp.getBody().push_back(new Block()); + // Import `expected_type_urls` if present. + SmallVector expected_type_urls; + for (const std::string &expected_type_url : message.expected_type_urls()) { + expected_type_urls.push_back(StringAttr::get(context, expected_type_url)); + } + if (!expected_type_urls.empty()) { + planOp.setExpectedTypeUrlsAttr(ArrayAttr::get(context, expected_type_urls)); + } + OpBuilder::InsertionGuard insertGuard(builder); builder.setInsertionPointToEnd(&planOp.getBody().front()); diff --git a/test/Dialect/Substrait/plan.mlir b/test/Dialect/Substrait/plan.mlir index ae043bad..990a40a8 100644 --- a/test/Dialect/Substrait/plan.mlir +++ b/test/Dialect/Substrait/plan.mlir @@ -100,3 +100,15 @@ substrait.plan version 0 : 42 : 1 optimization = "protobuf message" : !substrait.any<"http://some.url/with/type.proto"> enhancement = "other protobuf message" : !substrait.any<"http://other.url/with/type.proto"> {} + +// ----- + +// CHECK: substrait.plan +// CHECK-SAME: expected_type_urls +// CHECK-SAME: ["http://some.url/with/type.proto", "http://other.url/with/type.proto"] +// CHECK-NEXT: } + +substrait.plan version 0 : 42 : 1 + expected_type_urls + ["http://some.url/with/type.proto", "http://other.url/with/type.proto"] +{} diff --git a/test/Target/SubstraitPB/Export/plan.mlir b/test/Target/SubstraitPB/Export/plan.mlir index 9397b0f1..4a03f53b 100644 --- a/test/Target/SubstraitPB/Export/plan.mlir +++ b/test/Target/SubstraitPB/Export/plan.mlir @@ -182,3 +182,15 @@ substrait.plan version 0 : 42 : 1 advanced_extension enhancement = "other protobuf message" : !substrait.any<"http://other.url/with/type.proto"> {} + +// ----- + +// CHECK: expected_type_urls: "http://some.url/with/type.proto" +// CHECK-NEXT: expected_type_urls: "http://other.url/with/type.proto" +// CHECK-NEXT: version + + +substrait.plan version 0 : 42 : 1 + expected_type_urls + ["http://some.url/with/type.proto", "http://other.url/with/type.proto"] +{} diff --git a/test/Target/SubstraitPB/Import/plan.textpb b/test/Target/SubstraitPB/Import/plan.textpb index 6fe649f0..3d741dc5 100644 --- a/test/Target/SubstraitPB/Import/plan.textpb +++ b/test/Target/SubstraitPB/Import/plan.textpb @@ -246,3 +246,18 @@ version { minor_number: 42 patch_number: 1 } + +# ----- + +# CHECK-LABEL: substrait.plan +# CHECK-SAME: expected_type_urls +# CHECK-SAME: ["http://some.url/with/type.proto", +# CHECK-SAME: "http://other.url/with/type.proto"] +# CHECK-NEXT: } + +expected_type_urls: "http://some.url/with/type.proto" +expected_type_urls: "http://other.url/with/type.proto" +version { + minor_number: 42 + patch_number: 1 +}