Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: elide attribute type with new StaticallyTypedAttrInterface #59

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/substrait-mlir/Dialect/Substrait/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ add_dependencies(MLIRSubstraitDialect MLIRSubstraitAttrsIncGen)

# Add interfaces.
set(LLVM_TARGET_DEFINITIONS SubstraitInterfaces.td)
mlir_tablegen(SubstraitAttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(SubstraitAttrInterfaces.cpp.inc -gen-attr-interface-defs)
mlir_tablegen(SubstraitOpInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(SubstraitOpInterfaces.cpp.inc -gen-op-interface-defs)
mlir_tablegen(SubstraitTypeInterfaces.h.inc -gen-type-interface-decls)
Expand Down
10 changes: 10 additions & 0 deletions include/substrait-mlir/Dialect/Substrait/IR/Substrait.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpsDialect.h.inc" // IWYU: export

#include "substrait-mlir/Dialect/Substrait/IR/SubstraitAttrInterfaces.h.inc" // IWYU: export
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpInterfaces.h.inc" // IWYU: export
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitTypeInterfaces.h.inc" // IWYU: export

Expand All @@ -31,4 +32,13 @@
#define GET_OP_CLASSES
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOps.h.inc" // IWYU: export

namespace mlir::substrait {

/// Returns the `Type` of the attribute through the `TypedAttrInterface` or the
/// `TypeInferableAttrInterface`. Returns an empty `Type` if the given attribute
/// does not implement one of the two interfaces.
Type getAttrType(Attribute attr);

} // namespace mlir::substrait

#endif // SUBSTRAIT_MLIR_DIALECT_SUBSTRAIT_IR_SUBSTRAIT_H
56 changes: 26 additions & 30 deletions include/substrait-mlir/Dialect/Substrait/IR/SubstraitAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,24 @@ include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"

// Base class for Substrait dialect attribute types.
class Substrait_Attr<string name, string typeMnemonic, list<Trait> traits = []>
class Substrait_Attr<string name, string attrMnemonic, list<Trait> traits = []>
: AttrDef<Substrait_Dialect, name, traits> {
let mnemonic = typeMnemonic;
let mnemonic = attrMnemonic;
}

// Base class for Substrait dialect attribute types that have a statically known
// value type.
class Substrait_StaticallyTypedAttr<string name, string attrMnemonic,
string typeName, list<Trait> traits = []>
: Substrait_Attr<
name, attrMnemonic,
traits#[DeclareAttrInterfaceMethods<TypeInferableAttrInterface>]> {
let extraClassDeclaration = [{
/// Implement TypeInferableAttrInterface.
::mlir::Type getType() {
return ::mlir::substrait::}]#typeName#[{::get(getContext());
}
}];
}

def Substrait_AdvancedExtensionAttr
Expand All @@ -34,64 +49,45 @@ def Substrait_AdvancedExtensionAttr
let genVerifyDecl = 1;
}

def Substrait_DateAttr : Substrait_Attr<"Date", "date",
[TypedAttrInterface]> {
def Substrait_DateAttr
: Substrait_StaticallyTypedAttr<"Date", "date", "DateType"> {
let summary = "Substrait date type";
let description = [{
This type represents a substrait date attribute type.
}];
let parameters = (ins "int32_t":$value);
let assemblyFormat = [{ `<` $value `>` }];
let extraClassDeclaration = [{
::mlir::Type getType() const {
return DateType::get(getContext());
}
}];
}

def Substrait_TimeAttr : Substrait_Attr<"Time", "time",
[TypedAttrInterface]> {
def Substrait_TimeAttr
: Substrait_StaticallyTypedAttr<"Time", "time", "TimeType"> {
let summary = "Substrait time type";
let description = [{
This type represents a substrait time attribute type.
}];
let parameters = (ins "int64_t":$value);
let assemblyFormat = [{ `<` $value `` `us` `>` }];
let extraClassDeclaration = [{
::mlir::Type getType() const {
return TimeType::get(getContext());
}
}];
}

def Substrait_TimestampAttr : Substrait_Attr<"Timestamp", "timestamp",
[TypedAttrInterface]> {
def Substrait_TimestampAttr
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: colon on this line to match the others.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

: Substrait_StaticallyTypedAttr<"Timestamp", "timestamp", "TimestampType"> {
let summary = "Substrait timezone-unaware timestamp type";
let description = [{
This type represents a substrait timezone-unaware timestamp attribute type.
}];
let parameters = (ins "int64_t":$value);
let assemblyFormat = [{ `<` $value `` `us` `>` }];
let extraClassDeclaration = [{
::mlir::Type getType() const {
return TimestampType::get(getContext());
}
}];
}

def Substrait_TimestampTzAttr : Substrait_Attr<"TimestampTz", "timestamp_tz",
[TypedAttrInterface]> {
def Substrait_TimestampTzAttr
: Substrait_StaticallyTypedAttr<"TimestampTz", "timestamp_tz",
"TimestampTzType"> {
let summary = "Substrait timezone-aware timestamp type";
let description = [{
This type represents a substrait timezone-aware timestamp attribute type.
}];
let parameters = (ins "int64_t":$value);
let assemblyFormat = [{ `<` $value `` `us` `>` }];
let extraClassDeclaration = [{
::mlir::Type getType() const {
return TimestampTzType::get(getContext());
}
}];
}

/// Attributes of currently supported atomic types, listed in order of substrait
Expand Down
20 changes: 20 additions & 0 deletions include/substrait-mlir/Dialect/Substrait/IR/SubstraitInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,26 @@

include "mlir/IR/OpBase.td"

def TypeInferableAttrInterface : AttrInterface<"TypeInferableAttrInterface"> {
let cppNamespace = "::mlir::substrait";
let description = [{
This interface is used for attributes that have a type that can be inferred
from the instance of the attribute. It is similar to the built-in
`TypedAttrInterface` in that that type is understood to represent the type
of the data contained in the attribute. However, it is different in that
`TypedAttrInterface` is typically used for cases where the type is a
parameter of the attribute such that there can be attribute instances with
the same value but different types. With this interface, the type must be
inferable from the value such that two instances with the same value always
have the same type. Crucially, this allows to elide the type in the assembly
format of the attribute.
}];
let methods = [InterfaceMethod<
"Get the attribute's type",
"::mlir::Type", "getType"
>];
}

def Substrait_ExpressionOpInterface : OpInterface<"ExpressionOpInterface"> {
let description = [{
Interface for any expression in a Substrait plan. This corresponds to an
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define SUBSTRAIT_DIALECT_SUBSTRAIT_IR_SUBSTRAITTYPES

include "substrait-mlir/Dialect/Substrait/IR/SubstraitDialect.td"
include "substrait-mlir/Dialect/Substrait/IR/SubstraitInterfaces.td"
include "mlir/IR/CommonTypeConstraints.td"

// Base class for Substrait dialect types.
Expand Down
29 changes: 22 additions & 7 deletions lib/Dialect/Substrait/IR/Substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@

#include "substrait-mlir/Dialect/Substrait/IR/Substrait.h"

#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep

using namespace mlir;
using namespace mlir::substrait;
Expand All @@ -38,6 +37,22 @@ void SubstraitDialect::initialize() {
>();
}

//===----------------------------------------------------------------------===//
// Free functions
//===----------------------------------------------------------------------===//

namespace mlir::substrait {

Type getAttrType(Attribute attr) {
if (auto typedAttr = mlir::dyn_cast<TypedAttr>(attr))
return typedAttr.getType();
if (auto typedAttr = mlir::dyn_cast<TypeInferableAttrInterface>(attr))
return typedAttr.getType();
return Type();
}

} // namespace mlir::substrait

//===----------------------------------------------------------------------===//
// Substrait attributes
//===----------------------------------------------------------------------===//
Expand All @@ -62,6 +77,7 @@ LogicalResult AdvancedExtensionAttr::verify(
// Substrait interfaces
//===----------------------------------------------------------------------===//

#include "substrait-mlir/Dialect/Substrait/IR/SubstraitAttrInterfaces.cpp.inc"
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpInterfaces.cpp.inc"
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitTypeInterfaces.cpp.inc"

Expand Down Expand Up @@ -297,15 +313,14 @@ LiteralOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
OpaqueProperties properties, RegionRange regions,
llvm::SmallVectorImpl<Type> &inferredReturnTypes) {
auto *typedProperties = properties.as<Properties *>();
Attribute valueAttr = typedProperties->getValue();

auto attr = llvm::dyn_cast<TypedAttr>(typedProperties->getValue());
if (!attr)
Type resultType = getAttrType(valueAttr);
if (!resultType)
return emitOptionalError(loc, "unsuited attribute for literal value: ",
typedProperties->getValue());

Type resultType = attr.getType();
inferredReturnTypes.emplace_back(resultType);

return success();
}

Expand Down
4 changes: 2 additions & 2 deletions lib/Target/SubstraitPB/Export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -626,8 +626,8 @@ SubstraitExporter::exportOperation(FilterOp op) {
FailureOr<std::unique_ptr<Expression>>
SubstraitExporter::exportOperation(LiteralOp op) {
// Build `Literal` message depending on type.
auto value = llvm::cast<TypedAttr>(op.getValue());
mlir::Type literalType = value.getType();
Attribute value = op.getValue();
mlir::Type literalType = getAttrType(value);
auto literal = std::make_unique<Expression::Literal>();

// `IntegerType`s.
Expand Down
16 changes: 8 additions & 8 deletions test/Dialect/Substrait/literal.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
// CHECK: %[[V0:.*]] = named_table
// CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : tuple<si1> -> tuple<si1, !substrait.time> {
// CHECK-NEXT: ^[[BB0:.*]](%[[ARG0:.*]]: tuple<si1>):
// CHECK-NEXT: %[[V2:.*]] = literal #substrait.time<200000000us> : !substrait.time
// CHECK-NEXT: %[[V2:.*]] = literal #substrait.time<200000000us>{{$}}
// CHECK-NEXT: yield %[[V2]] : !substrait.time
// CHECK-NEXT: }
// CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.time>
// CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.time>

substrait.plan version 0 : 42 : 1 {
relation {
Expand All @@ -19,7 +19,7 @@ substrait.plan version 0 : 42 : 1 {
%time = literal #substrait.time<200000000us> : !substrait.time
yield %time : !substrait.time
}
yield %1 : tuple<si1, !substrait.time>
yield %1 : tuple<si1, !substrait.time>
}
}

Expand All @@ -30,7 +30,7 @@ substrait.plan version 0 : 42 : 1 {
// CHECK: %[[V0:.*]] = named_table
// CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : tuple<si1> -> tuple<si1, !substrait.date> {
// CHECK-NEXT: ^[[BB0:.*]](%[[ARG0:.*]]: tuple<si1>):
// CHECK-NEXT: %[[V2:.*]] = literal #substrait.date<200000000> : !substrait.date
// CHECK-NEXT: %[[V2:.*]] = literal #substrait.date<200000000>{{$}}
// CHECK-NEXT: yield %[[V2]] : !substrait.date
// CHECK-NEXT: }
// CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.date>
Expand All @@ -43,7 +43,7 @@ substrait.plan version 0 : 42 : 1 {
%date = literal #substrait.date<200000000> : !substrait.date
yield %date : !substrait.date
}
yield %1 : tuple<si1, !substrait.date>
yield %1 : tuple<si1, !substrait.date>
}
}

Expand All @@ -54,8 +54,8 @@ substrait.plan version 0 : 42 : 1 {
// CHECK: %[[V0:.*]] = named_table
// CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : tuple<si1> -> tuple<si1, !substrait.timestamp, !substrait.timestamp_tz> {
// CHECK-NEXT: ^[[BB0:.*]](%[[ARG0:.*]]: tuple<si1>):
// CHECK-NEXT: %[[V2:.*]] = literal #substrait.timestamp<10000000000us>
// CHECK-NEXT: %[[V3:.*]] = literal #substrait.timestamp_tz<10000000000us>
// CHECK-NEXT: %[[V2:.*]] = literal #substrait.timestamp<10000000000us>{{$}}
// CHECK-NEXT: %[[V3:.*]] = literal #substrait.timestamp_tz<10000000000us>{{$}}
// CHECK-NEXT: yield %[[V2]], %[[V3]] : !substrait.timestamp, !substrait.timestamp_tz
// CHECK-NEXT: }
// CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.timestamp, !substrait.timestamp_tz>
Expand All @@ -65,7 +65,7 @@ substrait.plan version 0 : 42 : 1 {
%0 = named_table @t1 as ["a"] : tuple<si1>
%1 = project %0 : tuple<si1> -> tuple<si1, !substrait.timestamp, !substrait.timestamp_tz> {
^bb0(%arg : tuple<si1>):
%timestamp = literal #substrait.timestamp<10000000000us>
%timestamp = literal #substrait.timestamp<10000000000us>
%timestamp_tz = literal #substrait.timestamp_tz<10000000000us>
yield %timestamp, %timestamp_tz : !substrait.timestamp, !substrait.timestamp_tz
}
Expand Down
8 changes: 4 additions & 4 deletions test/Target/SubstraitPB/Import/literal.textpb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# CHECK: %[[V0:.*]] = named_table
# CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : tuple<si1> -> tuple<si1, !substrait.time> {
# CHECK-NEXT: ^[[BB0:.*]](%[[ARG0:.*]]: tuple<si1>):
# CHECK-NEXT: %[[V2:.*]] = literal #substrait.time<200000000us> : !substrait.time
# CHECK-NEXT: %[[V2:.*]] = literal #substrait.time<200000000us>
# CHECK-NEXT: yield %[[V2]] : !substrait.time
# CHECK-NEXT: }
# CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.time>
Expand Down Expand Up @@ -69,7 +69,7 @@ version {
# CHECK: %[[V0:.*]] = named_table
# CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : tuple<si1> -> tuple<si1, !substrait.date> {
# CHECK-NEXT: ^[[BB0:.*]](%[[ARG0:.*]]: tuple<si1>):
# CHECK-NEXT: %[[V2:.*]] = literal #substrait.date<200000000> : !substrait.date
# CHECK-NEXT: %[[V2:.*]] = literal #substrait.date<200000000>
# CHECK-NEXT: yield %[[V2]] : !substrait.date
# CHECK-NEXT: }
# CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.date>
Expand Down Expand Up @@ -123,8 +123,8 @@ version {
# CHECK: %[[V0:.*]] = named_table
# CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : tuple<si1> -> tuple<si1, !substrait.timestamp, !substrait.timestamp_tz> {
# CHECK-NEXT: ^[[BB0:.*]](%[[ARG0:.*]]: tuple<si1>):
# CHECK-NEXT: %[[V2:.*]] = literal #substrait.timestamp<10000000000us>
# CHECK-NEXT: %[[V3:.*]] = literal #substrait.timestamp_tz<10000000000us>
# CHECK-NEXT: %[[V2:.*]] = literal #substrait.timestamp<10000000000us>
# CHECK-NEXT: %[[V3:.*]] = literal #substrait.timestamp_tz<10000000000us>
# CHECK-NEXT: yield %[[V2]], %[[V3]] : !substrait.timestamp, !substrait.timestamp_tz
# CHECK-NEXT: }
# CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.timestamp, !substrait.timestamp_tz>
Expand Down