Skip to content

Commit

Permalink
WIP: Elide attribute type with new StaticallyTypedAttrInterface
Browse files Browse the repository at this point in the history
Signed-off-by: Ingo Müller <[email protected]>
  • Loading branch information
ingomueller-net committed Jan 17, 2025
1 parent d3b5324 commit d3c98d6
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 29 deletions.
2 changes: 2 additions & 0 deletions include/substrait-mlir/Dialect/Substrait/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ add_public_tablegen_target(MLIRSubstraitAttrsIncGen)
add_dependencies(MLIRSubstraitDialect MLIRSubstraitAttrsIncGen)

set(LLVM_TARGET_DEFINITIONS SubstraitInterfaces.td)
mlir_tablegen(SubstraitAttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(SubstraitAttrInterfaces.cpp.inc -gen-attr-interface-defs)
mlir_tablegen(SubstraitOpInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(SubstraitOpInterfaces.cpp.inc -gen-op-interface-defs)
mlir_tablegen(SubstraitTypeInterfaces.h.inc -gen-type-interface-decls)
Expand Down
1 change: 1 addition & 0 deletions include/substrait-mlir/Dialect/Substrait/IR/Substrait.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpsDialect.h.inc" // IWYU: export

#include "substrait-mlir/Dialect/Substrait/IR/SubstraitAttrInterfaces.h.inc" // IWYU: export
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpInterfaces.h.inc" // IWYU: export
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitTypeInterfaces.h.inc" // IWYU: export

Expand Down
22 changes: 22 additions & 0 deletions include/substrait-mlir/Dialect/Substrait/IR/SubstraitInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,28 @@

include "mlir/IR/OpBase.td"

def TypeInferableAttrInterface : AttrInterface<"TypeInferableAttrInterface"> {
let cppNamespace = "::mlir::substrait";

let description = [{
This interface is used for attributes that have a type that can be inferred
from the instance of the attribute. It is similar to the built-in
`TypedAttrInterface` in that that type is understood to represent the type
of the data contained in the attribute. However, it is different in that
`TypedAttrInterface` is typically used for cases where the type is a
parameter of the attribute such that there can be attribute instances with
the same value but different types. With this interface, the type must be
inferable from the value such that two instances with the same value always
have the same type. Crucially, this allows to elide the type in the assembly
format of the attribute.
}];

let methods = [InterfaceMethod<
"Get the attribute's type",
"::mlir::Type", "getType"
>];
}

def Substrait_ExpressionOpInterface : OpInterface<"ExpressionOpInterface"> {
let description = [{
Interface for any expression in a Substrait plan. This corresponds to an
Expand Down
28 changes: 22 additions & 6 deletions include/substrait-mlir/Dialect/Substrait/IR/SubstraitTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define SUBSTRAIT_DIALECT_SUBSTRAIT_IR_SUBSTRAITTYPES

include "substrait-mlir/Dialect/Substrait/IR/SubstraitDialect.td"
include "substrait-mlir/Dialect/Substrait/IR/SubstraitInterfaces.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
Expand All @@ -21,9 +22,24 @@ class Substrait_Type<string name, string typeMnemonic, list<Trait> traits = []>
}

// Base class for Substrait dialect attribute types.
class Substrait_Attr<string name, string typeMnemonic, list<Trait> traits = []>
class Substrait_Attr<string name, string attrMnemonic, list<Trait> traits = []>
: AttrDef<Substrait_Dialect, name, traits> {
let mnemonic = typeMnemonic;
let mnemonic = attrMnemonic;
}

// Base class for Substrait dialect attribute types that have a statically known
// value type.
class Substrait_StaticallyTypedAttr<
string name, string attrMnemonic, string typeName, list<Trait> traits = []>
: Substrait_Attr<name, attrMnemonic, traits # [
DeclareAttrInterfaceMethods<TypeInferableAttrInterface>
]> {
let extraClassDeclaration = [{
/// Implement TypeInferableAttrInterface.
::mlir::Type getType() {
return ::mlir::substrait::}] # typeName # [{::get(getContext());
}
}];
}

def Substrait_StringType : Substrait_Type<"String", "string"> {
Expand All @@ -47,8 +63,8 @@ def Substrait_TimestampType : Substrait_Type<"Timestamp", "timestamp"> {
}];
}

def Substrait_TimestampAttr : Substrait_Attr<"Timestamp", "timestamp",
[DeclareAttrInterfaceMethods<TypedAttrInterface>]> {
def Substrait_TimestampAttr
: Substrait_StaticallyTypedAttr<"Timestamp", "timestamp", "TimestampType"> {
let summary = "Substrait timestamp (excluding timezone) type";
let description = [{
This type represents a substrait timestamp (excluding timezone) attribute type.
Expand All @@ -64,8 +80,8 @@ def Substrait_TimestampTzType : Substrait_Type<"TimestampTz", "timestamp_tz"> {
}];
}

def Substrait_TimestampTzAttr : Substrait_Attr<"TimestampTz", "timestamp_tz",
[DeclareAttrInterfaceMethods<TypedAttrInterface>]> {
def Substrait_TimestampTzAttr : Substrait_StaticallyTypedAttr<
"TimestampTz", "timestamp_tz", "TimestampTzType"> {
let summary = "Substrait timestamp (including timezone) type";
let description = [{
This type represents a substrait timestamp (including timezone) attribute type.
Expand Down
31 changes: 11 additions & 20 deletions lib/Dialect/Substrait/IR/Substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ void SubstraitDialect::initialize() {
// Substrait interfaces
//===----------------------------------------------------------------------===//

#include "substrait-mlir/Dialect/Substrait/IR/SubstraitAttrInterfaces.cpp.inc"
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitOpInterfaces.cpp.inc"
#include "substrait-mlir/Dialect/Substrait/IR/SubstraitTypeInterfaces.cpp.inc"

Expand Down Expand Up @@ -82,20 +83,6 @@ void printCountAsAll(OpAsmPrinter &printer, Operation *op, IntegerAttr count) {
printer << count.getValue();
}

//===----------------------------------------------------------------------===//
// Substrait attributes
//===----------------------------------------------------------------------===//

/// Implement the getType method for custom type `TimestampAttr`.
::mlir::Type TimestampAttr::getType() const {
return TimestampType::get(getContext());
}

/// Implement the getType method for custom type `TimestampTzAttr`.
::mlir::Type TimestampTzAttr::getType() const {
return TimestampTzType::get(getContext());
}

//===----------------------------------------------------------------------===//
// Substrait operations
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -297,15 +284,19 @@ LiteralOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
OpaqueProperties properties, RegionRange regions,
llvm::SmallVectorImpl<Type> &inferredReturnTypes) {
auto *typedProperties = properties.as<Properties *>();
Attribute valueAttr = typedProperties->getValue();

auto attr = llvm::dyn_cast<TypedAttr>(typedProperties->getValue());
if (!attr)
return emitOptionalError(loc, "unsuited attribute for literal value: ",
typedProperties->getValue());
Type resultType;
if (auto attr = llvm::dyn_cast<TypedAttr>(valueAttr))
resultType = attr.getType();
if (auto attr = llvm::dyn_cast<TypeInferableAttrInterface>(valueAttr))
resultType = attr.getType();

Type resultType = attr.getType();
inferredReturnTypes.emplace_back(resultType);
if (!resultType)
emitOptionalError(loc, "unsuited attribute for literal value: ",
typedProperties->getValue());

inferredReturnTypes.emplace_back(resultType);
return success();
}

Expand Down
6 changes: 3 additions & 3 deletions test/Dialect/Substrait/literal.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
// CHECK: %[[V0:.*]] = named_table
// CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : tuple<si1> -> tuple<si1, !substrait.timestamp, !substrait.timestamp_tz> {
// CHECK-NEXT: ^[[BB0:.*]](%[[ARG0:.*]]: tuple<si1>):
// CHECK-NEXT: %[[V2:.*]] = literal #substrait.timestamp<10000000000us> : !substrait.timestamp
// CHECK-NEXT: %[[V3:.*]] = literal #substrait.timestamp_tz<10000000000us> : !substrait.timestamp_tz
// CHECK-NEXT: %[[V2:.*]] = literal #substrait.timestamp<10000000000us>{{$}}
// CHECK-NEXT: %[[V3:.*]] = literal #substrait.timestamp_tz<10000000000us>{{$}}
// CHECK-NEXT: yield %[[V2]], %[[V3]] : !substrait.timestamp, !substrait.timestamp_tz
// CHECK-NEXT: }
// CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.timestamp, !substrait.timestamp_tz>
Expand All @@ -17,7 +17,7 @@ substrait.plan version 0 : 42 : 1 {
%0 = named_table @t1 as ["a"] : tuple<si1>
%1 = project %0 : tuple<si1> -> tuple<si1, !substrait.timestamp, !substrait.timestamp_tz> {
^bb0(%arg : tuple<si1>):
%timestamp = literal #substrait.timestamp<10000000000us>
%timestamp = literal #substrait.timestamp<10000000000us>
%timestamp_tz = literal #substrait.timestamp_tz<10000000000us>
yield %timestamp, %timestamp_tz : !substrait.timestamp, !substrait.timestamp_tz
}
Expand Down

0 comments on commit d3c98d6

Please sign in to comment.