Skip to content

Commit

Permalink
feat: elide attribute type with new TypeInferableAttrInterface
Browse files Browse the repository at this point in the history
This allows to omit the type of an attribute from the assembly if it can
be inferred from the attribute value. For example, the type is redundant
in `#substrait.timestamp<100us> : !substrait.timestamp`. The built-in
`TypedAttrInterface`, however, forces the appearance of the type in the
assembly. The new interface is almost identical but does not enforce it.
The PR also makes the two timestamp, the date, and the time attributes
implement that interface.

Signed-off-by: Ingo Müller <[email protected]>
  • Loading branch information
ingomueller-net committed Jan 29, 2025
1 parent 7c55b1b commit 60dcdaa
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 47 deletions.
2 changes: 2 additions & 0 deletions include/substrait-mlir/Dialect/Substrait/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ add_dependencies(MLIRSubstraitDialect MLIRSubstraitAttrsIncGen)

# Add interfaces.
set(LLVM_TARGET_DEFINITIONS SubstraitInterfaces.td)
mlir_tablegen(SubstraitAttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(SubstraitAttrInterfaces.cpp.inc -gen-attr-interface-defs)
mlir_tablegen(SubstraitOpInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(SubstraitOpInterfaces.cpp.inc -gen-op-interface-defs)
mlir_tablegen(SubstraitTypeInterfaces.h.inc -gen-type-interface-decls)
Expand Down
10 changes: 10 additions & 0 deletions include/substrait-mlir/Dialect/Substrait/IR/Substrait.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpsDialect.h.inc" // IWYU: export

#include "substrait-mlir/Dialect/Substrait/IR/SubstraitAttrInterfaces.h.inc" // IWYU: export
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpInterfaces.h.inc" // IWYU: export
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitTypeInterfaces.h.inc" // IWYU: export

Expand All @@ -31,4 +32,13 @@
#define GET_OP_CLASSES
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOps.h.inc" // IWYU: export

namespace mlir::substrait {

/// Returns the `Type` of the attribute through the `TypedAttrInterface` or the
/// `TypeInferableAttrInterface`. Returns an empty `Type` if the given attribute
/// does not implement one of the two interfaces.
Type getAttrType(Attribute attr);

} // namespace mlir::substrait

#endif // SUBSTRAIT_MLIR_DIALECT_SUBSTRAIT_IR_SUBSTRAIT_H
55 changes: 25 additions & 30 deletions include/substrait-mlir/Dialect/Substrait/IR/SubstraitAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,24 @@ include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"

// Base class for Substrait dialect attribute types.
class Substrait_Attr<string name, string typeMnemonic, list<Trait> traits = []>
class Substrait_Attr<string name, string attrMnemonic, list<Trait> traits = []>
: AttrDef<Substrait_Dialect, name, traits> {
let mnemonic = typeMnemonic;
let mnemonic = attrMnemonic;
}

// Base class for Substrait dialect attribute types that have a statically known
// value type.
class Substrait_StaticallyTypedAttr<
string name, string attrMnemonic, string typeName, list<Trait> traits = []>
: Substrait_Attr<name, attrMnemonic, traits # [
DeclareAttrInterfaceMethods<TypeInferableAttrInterface>
]> {
let extraClassDeclaration = [{
/// Implement TypeInferableAttrInterface.
::mlir::Type getType() {
return ::mlir::substrait::}] # typeName # [{::get(getContext());
}
}];
}

def Substrait_AdvancedExtensionAttr
Expand All @@ -34,64 +49,44 @@ def Substrait_AdvancedExtensionAttr
let genVerifyDecl = 1;
}

def Substrait_DateAttr : Substrait_Attr<"Date", "date",
[TypedAttrInterface]> {
def Substrait_DateAttr
: Substrait_StaticallyTypedAttr<"Date", "date", "DateType"> {
let summary = "Substrait date type";
let description = [{
This type represents a substrait date attribute type.
}];
let parameters = (ins "int32_t":$value);
let assemblyFormat = [{ `<` $value `>` }];
let extraClassDeclaration = [{
::mlir::Type getType() const {
return DateType::get(getContext());
}
}];
}

def Substrait_TimeAttr : Substrait_Attr<"Time", "time",
[TypedAttrInterface]> {
def Substrait_TimeAttr :
Substrait_StaticallyTypedAttr<"Time", "time", "TimeType"> {
let summary = "Substrait time type";
let description = [{
This type represents a substrait time attribute type.
}];
let parameters = (ins "int64_t":$value);
let assemblyFormat = [{ `<` $value `` `us` `>` }];
let extraClassDeclaration = [{
::mlir::Type getType() const {
return TimeType::get(getContext());
}
}];
}

def Substrait_TimestampAttr : Substrait_Attr<"Timestamp", "timestamp",
[TypedAttrInterface]> {
def Substrait_TimestampAttr
: Substrait_StaticallyTypedAttr<"Timestamp", "timestamp", "TimestampType"> {
let summary = "Substrait timezone-unaware timestamp type";
let description = [{
This type represents a substrait timezone-unaware timestamp attribute type.
}];
let parameters = (ins "int64_t":$value);
let assemblyFormat = [{ `<` $value `` `us` `>` }];
let extraClassDeclaration = [{
::mlir::Type getType() const {
return TimestampType::get(getContext());
}
}];
}

def Substrait_TimestampTzAttr : Substrait_Attr<"TimestampTz", "timestamp_tz",
[TypedAttrInterface]> {
def Substrait_TimestampTzAttr : Substrait_StaticallyTypedAttr<
"TimestampTz", "timestamp_tz", "TimestampTzType"> {
let summary = "Substrait timezone-aware timestamp type";
let description = [{
This type represents a substrait timezone-aware timestamp attribute type.
}];
let parameters = (ins "int64_t":$value);
let assemblyFormat = [{ `<` $value `` `us` `>` }];
let extraClassDeclaration = [{
::mlir::Type getType() const {
return TimestampTzType::get(getContext());
}
}];
}

/// Attributes of currently supported atomic types, listed in order of substrait
Expand Down
20 changes: 20 additions & 0 deletions include/substrait-mlir/Dialect/Substrait/IR/SubstraitInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,26 @@

include "mlir/IR/OpBase.td"

def TypeInferableAttrInterface : AttrInterface<"TypeInferableAttrInterface"> {
let cppNamespace = "::mlir::substrait";
let description = [{
This interface is used for attributes that have a type that can be inferred
from the instance of the attribute. It is similar to the built-in
`TypedAttrInterface` in that that type is understood to represent the type
of the data contained in the attribute. However, it is different in that
`TypedAttrInterface` is typically used for cases where the type is a
parameter of the attribute such that there can be attribute instances with
the same value but different types. With this interface, the type must be
inferable from the value such that two instances with the same value always
have the same type. Crucially, this allows to elide the type in the assembly
format of the attribute.
}];
let methods = [InterfaceMethod<
"Get the attribute's type",
"::mlir::Type", "getType"
>];
}

def Substrait_ExpressionOpInterface : OpInterface<"ExpressionOpInterface"> {
let description = [{
Interface for any expression in a Substrait plan. This corresponds to an
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define SUBSTRAIT_DIALECT_SUBSTRAIT_IR_SUBSTRAITTYPES

include "substrait-mlir/Dialect/Substrait/IR/SubstraitDialect.td"
include "substrait-mlir/Dialect/Substrait/IR/SubstraitInterfaces.td"
include "mlir/IR/CommonTypeConstraints.td"

// Base class for Substrait dialect types.
Expand Down
29 changes: 22 additions & 7 deletions lib/Dialect/Substrait/IR/Substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@

#include "substrait-mlir/Dialect/Substrait/IR/Substrait.h"

#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep

using namespace mlir;
using namespace mlir::substrait;
Expand All @@ -38,6 +37,22 @@ void SubstraitDialect::initialize() {
>();
}

//===----------------------------------------------------------------------===//
// Free functions
//===----------------------------------------------------------------------===//

namespace mlir::substrait {

Type getAttrType(Attribute attr) {
if (auto typedAttr = mlir::dyn_cast<TypedAttr>(attr))
return typedAttr.getType();
if (auto typedAttr = mlir::dyn_cast<TypeInferableAttrInterface>(attr))
return typedAttr.getType();
return Type();
}

} // namespace mlir::substrait

//===----------------------------------------------------------------------===//
// Substrait attributes
//===----------------------------------------------------------------------===//
Expand All @@ -62,6 +77,7 @@ LogicalResult AdvancedExtensionAttr::verify(
// Substrait interfaces
//===----------------------------------------------------------------------===//

#include "substrait-mlir/Dialect/Substrait/IR/SubstraitAttrInterfaces.cpp.inc"
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpInterfaces.cpp.inc"
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitTypeInterfaces.cpp.inc"

Expand Down Expand Up @@ -297,15 +313,14 @@ LiteralOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
OpaqueProperties properties, RegionRange regions,
llvm::SmallVectorImpl<Type> &inferredReturnTypes) {
auto *typedProperties = properties.as<Properties *>();
Attribute valueAttr = typedProperties->getValue();

auto attr = llvm::dyn_cast<TypedAttr>(typedProperties->getValue());
if (!attr)
Type resultType = getAttrType(valueAttr);
if (!resultType)
return emitOptionalError(loc, "unsuited attribute for literal value: ",
typedProperties->getValue());

Type resultType = attr.getType();
inferredReturnTypes.emplace_back(resultType);

return success();
}

Expand Down
4 changes: 2 additions & 2 deletions lib/Target/SubstraitPB/Export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -626,8 +626,8 @@ SubstraitExporter::exportOperation(FilterOp op) {
FailureOr<std::unique_ptr<Expression>>
SubstraitExporter::exportOperation(LiteralOp op) {
// Build `Literal` message depending on type.
auto value = llvm::cast<TypedAttr>(op.getValue());
mlir::Type literalType = value.getType();
Attribute value = op.getValue();
mlir::Type literalType = getAttrType(value);
auto literal = std::make_unique<Expression::Literal>();

// `IntegerType`s.
Expand Down
16 changes: 8 additions & 8 deletions test/Dialect/Substrait/literal.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
// CHECK: %[[V0:.*]] = named_table
// CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : tuple<si1> -> tuple<si1, !substrait.time> {
// CHECK-NEXT: ^[[BB0:.*]](%[[ARG0:.*]]: tuple<si1>):
// CHECK-NEXT: %[[V2:.*]] = literal #substrait.time<200000000us> : !substrait.time
// CHECK-NEXT: %[[V2:.*]] = literal #substrait.time<200000000us>{{$}}
// CHECK-NEXT: yield %[[V2]] : !substrait.time
// CHECK-NEXT: }
// CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.time>
// CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.time>

substrait.plan version 0 : 42 : 1 {
relation {
Expand All @@ -19,7 +19,7 @@ substrait.plan version 0 : 42 : 1 {
%time = literal #substrait.time<200000000us> : !substrait.time
yield %time : !substrait.time
}
yield %1 : tuple<si1, !substrait.time>
yield %1 : tuple<si1, !substrait.time>
}
}

Expand All @@ -30,7 +30,7 @@ substrait.plan version 0 : 42 : 1 {
// CHECK: %[[V0:.*]] = named_table
// CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : tuple<si1> -> tuple<si1, !substrait.date> {
// CHECK-NEXT: ^[[BB0:.*]](%[[ARG0:.*]]: tuple<si1>):
// CHECK-NEXT: %[[V2:.*]] = literal #substrait.date<200000000> : !substrait.date
// CHECK-NEXT: %[[V2:.*]] = literal #substrait.date<200000000>{{$}}
// CHECK-NEXT: yield %[[V2]] : !substrait.date
// CHECK-NEXT: }
// CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.date>
Expand All @@ -43,7 +43,7 @@ substrait.plan version 0 : 42 : 1 {
%date = literal #substrait.date<200000000> : !substrait.date
yield %date : !substrait.date
}
yield %1 : tuple<si1, !substrait.date>
yield %1 : tuple<si1, !substrait.date>
}
}

Expand All @@ -54,8 +54,8 @@ substrait.plan version 0 : 42 : 1 {
// CHECK: %[[V0:.*]] = named_table
// CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : tuple<si1> -> tuple<si1, !substrait.timestamp, !substrait.timestamp_tz> {
// CHECK-NEXT: ^[[BB0:.*]](%[[ARG0:.*]]: tuple<si1>):
// CHECK-NEXT: %[[V2:.*]] = literal #substrait.timestamp<10000000000us>
// CHECK-NEXT: %[[V3:.*]] = literal #substrait.timestamp_tz<10000000000us>
// CHECK-NEXT: %[[V2:.*]] = literal #substrait.timestamp<10000000000us>{{$}}
// CHECK-NEXT: %[[V3:.*]] = literal #substrait.timestamp_tz<10000000000us>{{$}}
// CHECK-NEXT: yield %[[V2]], %[[V3]] : !substrait.timestamp, !substrait.timestamp_tz
// CHECK-NEXT: }
// CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.timestamp, !substrait.timestamp_tz>
Expand All @@ -65,7 +65,7 @@ substrait.plan version 0 : 42 : 1 {
%0 = named_table @t1 as ["a"] : tuple<si1>
%1 = project %0 : tuple<si1> -> tuple<si1, !substrait.timestamp, !substrait.timestamp_tz> {
^bb0(%arg : tuple<si1>):
%timestamp = literal #substrait.timestamp<10000000000us>
%timestamp = literal #substrait.timestamp<10000000000us>
%timestamp_tz = literal #substrait.timestamp_tz<10000000000us>
yield %timestamp, %timestamp_tz : !substrait.timestamp, !substrait.timestamp_tz
}
Expand Down

0 comments on commit 60dcdaa

Please sign in to comment.