diff --git a/tt_torch/csrc/bindings.cpp b/tt_torch/csrc/bindings.cpp index c97d889..b3d8164 100644 --- a/tt_torch/csrc/bindings.cpp +++ b/tt_torch/csrc/bindings.cpp @@ -168,16 +168,18 @@ std::string compile_stable_hlo_to_ttir(std::string_view code) { return ret; } -std::tuple +std::tuple compile_ttir_to_bytestream(std::string_view code) { - auto [binary, ttnn] = tt::torch::compileTTIRToTTNN(code); + auto [binary_ptr, ttnn] = tt::torch::compileTTIRToTTNN(code); auto size = ::flatbuffers::GetSizePrefixedBufferLength( - static_cast(binary->get())); + static_cast(binary_ptr->get())); + tt::runtime::Binary binary = tt::runtime::Binary(*binary_ptr); - std::string data_str(static_cast(binary->get()), size); - delete binary; + std::string json = binary.asJson(); + std::string data_str(static_cast(binary_ptr->get()), size); + delete binary_ptr; - return std::make_tuple(py::bytes(data_str), ttnn); + return std::make_tuple(py::bytes(data_str), ttnn, json); } py::bytes compile_stablehlo_to_bytestream(std::string_view code) { diff --git a/tt_torch/dynamo/backend.py b/tt_torch/dynamo/backend.py index ab90cae..7e14b98 100644 --- a/tt_torch/dynamo/backend.py +++ b/tt_torch/dynamo/backend.py @@ -79,8 +79,8 @@ def compile_process(receiver, sender, ttir_event, ttnn_event): ttir = tt_mlir.compile_stable_hlo_to_ttir(asm) sender.put({"ttir": ttir}) ttir_event.wait() - binary, ttnn = tt_mlir.compile_ttir_to_bytestream(ttir) - sender.put({"binary": binary, "ttnn": ttnn}) + binary, ttnn, json = tt_mlir.compile_ttir_to_bytestream(ttir) + sender.put({"binary": binary, "ttnn": ttnn, "binary_json": json}) ttnn_event.wait() sys.exit(0) @@ -275,7 +275,7 @@ def compile_op(self, node, *inputs, **kwargs): if "binary" in result: binary = result["binary"] op.binary = binary - op.json = tt_mlir.bytestream_to_json(binary) + op.json = result["binary_json"] op.add_ttnn_graph(result["ttnn"]) ttnn_event.set() op.compilation_status = OpCompilationStatus.CONVERTED_TO_TTNN @@ -549,7 +549,7 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs, compiler_config): if compiler_config.enable_intermediate_verification: executor.register_intermediate_callback(verify_golden_callback) - binary, ttnn = tt_mlir.compile_ttir_to_bytestream(ttir) + binary, ttnn, _ = tt_mlir.compile_ttir_to_bytestream(ttir) if dump_intermediates: print("TTNN module", file=sys.stderr)