From 694b9d03b4754d59b82b029ab20ef4dbd237d342 Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Mon, 2 Dec 2024 11:19:50 +0000 Subject: [PATCH] Added support for scalars --- src/common/api_impl.cc | 8 ++++++-- src/common/api_impl.h | 5 +++++ src/common/module_builder.cc | 33 +++++++++++++++++++++++++++++++++ src/common/module_builder.h | 9 +++++++++ tests/infrastructure.py | 7 ------- 5 files changed, 53 insertions(+), 9 deletions(-) diff --git a/src/common/api_impl.cc b/src/common/api_impl.cc index f4b66b78..33db6cfc 100644 --- a/src/common/api_impl.cc +++ b/src/common/api_impl.cc @@ -1047,9 +1047,13 @@ LoadedExecutableInstance::Execute(PJRT_LoadedExecutable_Execute_Args *args) { (void)event; for (size_t i = 0; i < output_specs.size(); ++i) { + bool is_scalar = client_.get_module_builder()->is_output_scalar(i); + // PJRT expects an empty shape for scalars. + std::vector output_shape = + is_scalar ? std::vector() : output_specs[i].shape; auto result_buffer = std::make_unique( - *this->addressable_devices_[dev_index], rt_outputs[i], - output_specs[i].shape, output_specs[i].stride); + *this->addressable_devices_[dev_index], rt_outputs[i], output_shape, + output_specs[i].stride); result_buffer->setType( convertElementTypeToBufferType(output_specs[i].dataType)); DLOG_F(INFO, "Runtime output id: %d", result_buffer->unique_id()); diff --git a/src/common/api_impl.h b/src/common/api_impl.h index 988fdda9..8284c4b9 100644 --- a/src/common/api_impl.h +++ b/src/common/api_impl.h @@ -364,6 +364,11 @@ class ClientInstance { // Advances the timeline, returning (current, next) time point values. std::tuple AdvanceTimeline(); + // Returns the module builder used for this ClientInstance. 
+ const ModuleBuilder *get_module_builder() const { + return module_builder_.get(); + } + protected: std::string cached_platform_name_; std::string cached_platform_version_; diff --git a/src/common/module_builder.cc b/src/common/module_builder.cc index d0999514..182b38b0 100644 --- a/src/common/module_builder.cc +++ b/src/common/module_builder.cc @@ -17,6 +17,7 @@ #include "mlir/Dialect/Func/Extensions/AllExtensions.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MLProgram/IR/MLProgram.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Parser/Parser.h" @@ -65,6 +66,36 @@ ModuleBuilder::ModuleBuilder() m_context->appendDialectRegistry(registry); } +static bool isScalarType(mlir::Type type) { + if (mlir::isa(type) || mlir::isa(type)) { + return true; + } + if (auto tensorType = mlir::dyn_cast(type)) { + return tensorType.getRank() == 0; + } + return false; +} + +void ModuleBuilder::collectOutputTypes(mlir::ModuleOp &&module) { + m_is_output_scalar.clear(); + for (auto &op : module.getOps()) { + if (auto funcOp = mlir::cast(op)) { + // We care only for return ops of public functions, as those are the ones + // that will produce results in the flatbuffer.
+ if (funcOp.isPublic()) { + funcOp.walk([&](mlir::Operation *op) { + if (mlir::func::ReturnOp return_op = + mlir::dyn_cast(op)) { + for (auto operand : op->getOperands()) { + m_is_output_scalar.push_back(isScalarType(operand.getType())); + } + } + }); + } + } + } +} + tt_pjrt_status ModuleBuilder::buildModule(const std::string_view &code, const std::string_view &format) { DLOG_F(LOG_DEBUG, "ModuleBuilder::buildModule"); @@ -134,6 +165,8 @@ void ModuleBuilder::convertFromVHLOToSHLO( return; } + collectOutputTypes(mlir_module.get()); + DLOG_F(LOG_DEBUG, "SHLO Module:"); print_module(mlir_module); } diff --git a/src/common/module_builder.h b/src/common/module_builder.h index 70c753ab..18a020f1 100644 --- a/src/common/module_builder.h +++ b/src/common/module_builder.h @@ -33,6 +33,8 @@ class ModuleBuilder { size_t getNumOutputs() const { return m_num_outputs; }; + bool is_output_scalar(int index) const { return m_is_output_scalar[index]; } + private: // Creates VHLO module from the input program code. mlir::OwningOpRef @@ -51,6 +53,10 @@ class ModuleBuilder { void createFlatbufferBinary(const mlir::OwningOpRef &mlir_module); + // Fills the m_is_output_scalar array with whether each output type is a + // scalar or not. + void collectOutputTypes(mlir::ModuleOp &&module); + // Prints module to console for debug purposes. static void print_module(mlir::OwningOpRef &mlir_module); @@ -68,6 +74,9 @@ class ModuleBuilder { // Holds status of the last builder action. tt_pjrt_status m_status; + + // For every output, holds whether its type is a scalar or not. + std::vector m_is_output_scalar; }; } // namespace tt::pjrt diff --git a/tests/infrastructure.py b/tests/infrastructure.py index 42c16c97..23cf91b0 100644 --- a/tests/infrastructure.py +++ b/tests/infrastructure.py @@ -40,13 +40,6 @@ def compare_tensor_to_golden( ): ret = True - # TODO (issue #81): Remove these reshapes once the PJRT can handle scalars.
- if tensor.ndim == 0: - tensor = tensor.reshape((1,)) - if golden.ndim == 0: - with run_on_cpu(): - golden = golden.reshape((1,)) - if tensor.device != golden.device: tensor = jax.device_put(tensor, golden.device)