Skip to content

Commit

Permalink
generate and pass binary json before passing binary to python
Browse files Browse the repository at this point in the history
  • Loading branch information
LPanosTT committed Jan 31, 2025
1 parent d27990d commit 751cca7
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
14 changes: 8 additions & 6 deletions tt_torch/csrc/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,16 +168,18 @@ std::string compile_stable_hlo_to_ttir(std::string_view code) {
return ret;
}

std::tuple<py::bytes, std::string>
std::tuple<py::bytes, std::string, std::string>
compile_ttir_to_bytestream(std::string_view code) {
auto [binary, ttnn] = tt::torch::compileTTIRToTTNN(code);
auto [binary_ptr, ttnn] = tt::torch::compileTTIRToTTNN(code);
auto size = ::flatbuffers::GetSizePrefixedBufferLength(
static_cast<const uint8_t *>(binary->get()));
static_cast<const uint8_t *>(binary_ptr->get()));
tt::runtime::Binary binary = tt::runtime::Binary(*binary_ptr);

std::string data_str(static_cast<const char *>(binary->get()), size);
delete binary;
std::string json = binary.asJson();
std::string data_str(static_cast<const char *>(binary_ptr->get()), size);
delete binary_ptr;

return std::make_tuple(py::bytes(data_str), ttnn);
return std::make_tuple(py::bytes(data_str), ttnn, json);
}

py::bytes compile_stablehlo_to_bytestream(std::string_view code) {
Expand Down
8 changes: 4 additions & 4 deletions tt_torch/dynamo/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def compile_process(receiver, sender, ttir_event, ttnn_event):
ttir = tt_mlir.compile_stable_hlo_to_ttir(asm)
sender.put({"ttir": ttir})
ttir_event.wait()
binary, ttnn = tt_mlir.compile_ttir_to_bytestream(ttir)
sender.put({"binary": binary, "ttnn": ttnn})
binary, ttnn, json = tt_mlir.compile_ttir_to_bytestream(ttir)
sender.put({"binary": binary, "ttnn": ttnn, "binary_json": json})
ttnn_event.wait()
sys.exit(0)

Expand Down Expand Up @@ -275,7 +275,7 @@ def compile_op(self, node, *inputs, **kwargs):
if "binary" in result:
binary = result["binary"]
op.binary = binary
op.json = tt_mlir.bytestream_to_json(binary)
op.json = result["binary_json"]
op.add_ttnn_graph(result["ttnn"])
ttnn_event.set()
op.compilation_status = OpCompilationStatus.CONVERTED_TO_TTNN
Expand Down Expand Up @@ -549,7 +549,7 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs, compiler_config):
if compiler_config.enable_intermediate_verification:
executor.register_intermediate_callback(verify_golden_callback)

binary, ttnn = tt_mlir.compile_ttir_to_bytestream(ttir)
binary, ttnn, _ = tt_mlir.compile_ttir_to_bytestream(ttir)

if dump_intermediates:
print("TTNN module", file=sys.stderr)
Expand Down

0 comments on commit 751cca7

Please sign in to comment.