op-by-op SHLO compiler #203

Open
AleksKnezevic opened this issue Jan 15, 2025 · 3 comments

@AleksKnezevic
Contributor

Please create some infrastructure to compile a StableHLO graph op-by-op through tt-mlir.

ddilbazTT marked this as a duplicate of #196 on Jan 16, 2025
@ddilbazTT
Contributor

The way to install stablehlo is:

pip install --pre torch-mlir torchvision
pip install stablehlo -f https://github.com/openxla/stablehlo/releases/expanded_assets/dev-wheels

The way to use stablehlo is:

import mlir.ir as ir
import mlir.dialects.stablehlo as stablehlo

with ir.Context() as ctx:
    stablehlo.register_dialect(ctx)
    module = ir.Module.parse(MODULE_STRING)
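
For reference, a minimal MODULE_STRING to feed the snippet above could look like this (illustrative; any valid StableHLO text works):

# Illustrative StableHLO module text: a single elementwise add.
MODULE_STRING = """
module {
  func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<2x2xf32>
    return %0 : tensor<2x2xf32>
  }
}
"""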

@ddilbazTT
Contributor

ddilbazTT commented Jan 16, 2025

Hi @AleksKnezevic, this is the basic implementation so far:

# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

### PRE-REQUISITES
#   pip install --pre torch-mlir torchvision
#   pip install stablehlo -f https://github.com/openxla/stablehlo/releases/expanded_assets/dev-wheels

import mlir.ir as ir
import mlir.dialects.stablehlo as stablehlo
import tt_mlir

# Maps each op's result name (e.g. "%0") to a standalone single-op module string.
modules = {}


def get_module_from_str(module_str: str) -> ir.Module:
    with ir.Context() as ctx:
        stablehlo.register_dialect(ctx)
        module = ir.Module.parse(module_str)
    return module


def lower_stablehlo_to_ttnn(stablehlo_ir: str):
    module = get_module_from_str(stablehlo_ir)
    try:
        ttir = tt_mlir.compile_stable_hlo_to_ttir(module.operation.get_asm())
        print("ttir done")
        binary, ttnn = tt_mlir.compile_ttir_to_bytestream(ttir)
        print("ttnn done")
        return ttnn
    except Exception as e:
        print("Error: ", e)
        return None


def get_ops_in_module(module: ir.Module):
    # Iterate through all functions in the module
    for func_op in module.body.operations:
        for block in func_op.regions[0].blocks:
            for op in block.operations:
                # Skip function boilerplate and ops we can't wrap yet
                if op.name.startswith(("func.", "return")):
                    continue
                if op.name in ("stablehlo.pad", "stablehlo.reduce_window"):
                    continue
                # Record each operand's SSA name and type for the wrapper signature
                inputs = {}
                for operand in op.operands:
                    inputs[operand.get_name()] = str(operand.type)
                args_str = ", ".join(f"{key}: {typ}" for key, typ in inputs.items())
                result_type = str(op.result.type)  # assuming there is only one result
                result_name = str(op.result.get_name())
                # Wrap the single op in its own module so it can be compiled in isolation
                new_module_str = f"""module {{ \n\tfunc.func @main({args_str}) -> {result_type} {{ \n\t\t{str(op)} \n\t\treturn {result_name} : {result_type} \n\t}} \n}}"""
                modules[result_name] = new_module_str
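
A minimal driver sketch tying the pieces above together (SHLO_TEXT is a placeholder for the full StableHLO module text):

full_module = get_module_from_str(SHLO_TEXT)
get_ops_in_module(full_module)  # populates `modules` with single-op wrappers

for result_name, single_op_module in modules.items():
    print(f"Compiling op producing {result_name}")
    lower_stablehlo_to_ttnn(single_op_module)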

I had a few questions:

  1. Is the shlo op-by-op compiler going to be standalone (similar to onnx_compile) or integrated into tt_torch/dynamo/backend.py / tt_torch/tools/utils.py?
  2. Based on the previous question, how should we organize the code? Should we use the existing classes (Op, CompilerConfig, etc.), and how should we use them? These might need modifications to be used by the shlo compiler.
  3. Do we want to compile shlo ops using a different backend if they don't compile using ttir/ttnn? Or will we use this compiler only for debugging shlo graphs? It would be great to know what output we want to see.

@AleksKnezevic
Contributor Author

This looks great so far @ddilbazTT!

Is the shlo op-by-op compiler going to be standalone (similar to onnx_compile) or integrated into tt_torch/dynamo/backend.py / tt_torch/tools/utils.py?

We should refactor backend.py into multiple files. There should be a core part which just does the straight compilation of the full module; this can remain in backend.py. We can pull any common code across all the platforms into a utils module, and there should be a separate op-by-op flow for torch-fx and shlo. Can you give the refactoring a go, and I can review? Since this is a rather large change, it might be good to send a diff once you have an initial thought on how it could look, before getting everything working.
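
One hypothetical shape for that split (file names are illustrative, nothing is decided):

tt_torch/dynamo/backend.py        # straight full-module compilation
tt_torch/dynamo/torch_backend.py  # torch-fx op-by-op flow
tt_torch/dynamo/shlo_backend.py   # stablehlo op-by-op flow
tt_torch/tools/utils.py           # shared helpers (Op, CompilerConfig, ...)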

Based on the previous question, how should we organize the code? Should we use the existing classes (Op, CompilerConfig, etc.), and how should we use them? These might need modifications to be used by the shlo compiler.

I think we can use the existing classes and just have a separate compile depth for shlo_op_by_op.
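
For illustration, that could be as small as one new enum member. This is a hypothetical sketch; the actual CompileDepth enum in tt_torch (and its member names) may differ:

from enum import Enum

# Hypothetical sketch; the real enum and its members live in tt_torch.
class CompileDepth(Enum):
    TORCH_FX = 1
    STABLEHLO = 2
    EXECUTE_OP_BY_OP = 3   # existing torch-fx op-by-op flow (assumed name)
    SHLO_OP_BY_OP = 4      # new: compile each StableHLO op individually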

Do we want to compile shlo ops using a different backend if they don't compile using ttir/ttnn? Or will we use this compiler only for debugging shlo graphs? It would be great to know what output we want to see.

The goal here is to bring up a shlo graph in ttnn, so if an op doesn't compile, we mark it as such and move on to the next op.
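
A sketch of that mark-and-continue behavior, building on the driver loop above (status values are illustrative):

op_status = {}
for result_name, single_op_module in modules.items():
    ttnn = lower_stablehlo_to_ttnn(single_op_module)
    # Record the outcome and keep going rather than aborting on the first failure
    op_status[result_name] = "compiled" if ttnn is not None else "failed"
print(op_status)  # e.g. {"%0": "compiled", "%1": "failed"}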
