From 62c94e9f55a37d579a79d5bd11437577b7a8563f Mon Sep 17 00:00:00 2001 From: jserbedzija Date: Thu, 15 Aug 2024 10:05:54 +0000 Subject: [PATCH] Remove flax dependency from TVM --- python/tvm/contrib/pybuda_compile.py | 34 +++++--- python/tvm/contrib/pybuda_utils.py | 124 +++++++++++++-------------- 2 files changed, 82 insertions(+), 76 deletions(-) diff --git a/python/tvm/contrib/pybuda_compile.py b/python/tvm/contrib/pybuda_compile.py index 260867fa4d..961d99e9fc 100644 --- a/python/tvm/contrib/pybuda_compile.py +++ b/python/tvm/contrib/pybuda_compile.py @@ -40,7 +40,6 @@ from jax.tools.jax_to_ir import tf_wrap_with_input_names import collections from transformers.utils.generic import ModelOutput -from transformers.modeling_flax_utils import FlaxPreTrainedModel from tvm.contrib.pybuda_utils import ( extract_framework_model_outputs, extract_flatten_inputs, @@ -752,12 +751,16 @@ def flatten_params(params, parent_key="", sep="."): return dict(items) - if isinstance(jaxmodel, FlaxPreTrainedModel): - model_params = jaxmodel.params - else: - model_params = {} - if hasattr(jaxmodel, 'params'): - model_params = jaxmodel.variables['params']._dict + # if isinstance(jaxmodel, FlaxPreTrainedModel): + # model_params = jaxmodel.params + # else: + # model_params = {} + # if hasattr(jaxmodel, 'params'): + # model_params = jaxmodel.variables['params']._dict + + model_params = {} + if hasattr(jaxmodel, 'params'): + model_params = jaxmodel.variables['params']._dict weight_names = list(flatten_params(model_params).keys()) json_graphs = extract_graphs(partitioned_mod, buda_params, flattened_input_names,weight_names, param_name_lookup, graph_hash=m.hexdigest()) @@ -1006,12 +1009,17 @@ def flatten_params(params, parent_key="", sep="."): return dict(items) - if isinstance(module, FlaxPreTrainedModel): - module_params = module.params - else: - module_params = {} - if hasattr(module, 'params'): - module_params = module.variables['params']._dict + # if isinstance(module, FlaxPreTrainedModel): + # module_params = module.params + # else: + # module_params = {} + # if hasattr(module, 'params'): + # module_params = module.variables['params']._dict + + module_params = {} + if hasattr(module, 'params'): + module_params = module.variables['params']._dict + module_params = flatten_params(module_params) weights = {} diff --git a/python/tvm/contrib/pybuda_utils.py b/python/tvm/contrib/pybuda_utils.py index b5c961eafa..32587c0083 100644 --- a/python/tvm/contrib/pybuda_utils.py +++ b/python/tvm/contrib/pybuda_utils.py @@ -4,13 +4,11 @@ from collections import OrderedDict from collections.abc import MutableMapping -import flax import torch import numpy as np import tensorflow as tf import onnxruntime as ort from transformers.utils.generic import ModelOutput as HFModelOutput -from transformers.modeling_flax_utils import FlaxPreTrainedModel from transformers.modeling_outputs import ModelOutput import tvm @@ -256,67 +254,67 @@ def construct_tvm_ir(framework: str, model, tvm_mod, params, compiler_cfg: Compi tvm.relay.build_module.bind_params_by_name(tvm_mod["main"], propped_params) ) - elif framework == "jax": - - def flatten(d, parent_key="", sep="."): - items = [] - for k, v in d.items(): - new_key = parent_key + sep + k if parent_key else k - if isinstance(v, MutableMapping): - items.extend(flatten(v, new_key, sep=sep).items()) - else: - items.append((new_key, v)) - return dict(items) - - # TODO: Destupidify this! (Maybe we can sort by a substring of the weight names to make this more efficient) - found_weights = [] - param_name_lookup = {} - non_weight_params = {} # Some parameters (like causal mask) are not weights - - if isinstance(model, FlaxPreTrainedModel): - model_params = model.params - elif isinstance(model, flax.linen.Module): - model_params = {} - if hasattr(model, 'params'): - model_params = model.variables['params']._dict - else: - raise RuntimeError("Unknown Jax module instance.") - - model_params = flatten(model_params) - for (bad_name, value) in params.items(): - weight_found = False - for name, jax_value in model_params.items(): - if name not in found_weights and np.array_equal(jax_value.to_py(), value.numpy()): - param_name_lookup[bad_name] = name - weight_found = True - found_weights.append(name) - break - if not weight_found: - param_name_lookup[bad_name] = bad_name - non_weight_params[bad_name] = value - - if not compiler_cfg.enable_tvm_constant_prop: - tvm_mod = tvm.IRModule.from_expr( - tvm.relay.build_module.bind_params_by_name(tvm_mod["main"], non_weight_params) - ) - else: - if len(compiler_cfg.tvm_constnat_prop_mask): - propped_params = { - k: v - for k, v, in params.items() - if any( - [ - mask in param_name_lookup[k] - for mask in compiler_cfg.tvm_constnat_prop_mask - ] - ) - } - propped_params.update(non_weight_params) - else: - propped_params = params - tvm_mod = tvm.IRModule.from_expr( - tvm.relay.build_module.bind_params_by_name(tvm_mod["main"], propped_params) - ) + # elif framework == "jax": + + # def flatten(d, parent_key="", sep="."): + # items = [] + # for k, v in d.items(): + # new_key = parent_key + sep + k if parent_key else k + # if isinstance(v, MutableMapping): + # items.extend(flatten(v, new_key, sep=sep).items()) + # else: + # items.append((new_key, v)) + # return dict(items) + + # # TODO: Destupidify this! (Maybe we can sort by a substring of the weight names to make this more efficient) + # found_weights = [] + # param_name_lookup = {} + # non_weight_params = {} # Some parameters (like causal mask) are not weights + + # if isinstance(model, FlaxPreTrainedModel): + # model_params = model.params + # elif isinstance(model, flax.linen.Module): + # model_params = {} + # if hasattr(model, 'params'): + # model_params = model.variables['params']._dict + # else: + # raise RuntimeError("Unknown Jax module instance.") + + # model_params = flatten(model_params) + # for (bad_name, value) in params.items(): + # weight_found = False + # for name, jax_value in model_params.items(): + # if name not in found_weights and np.array_equal(jax_value.to_py(), value.numpy()): + # param_name_lookup[bad_name] = name + # weight_found = True + # found_weights.append(name) + # break + # if not weight_found: + # param_name_lookup[bad_name] = bad_name + # non_weight_params[bad_name] = value + + # if not compiler_cfg.enable_tvm_constant_prop: + # tvm_mod = tvm.IRModule.from_expr( + # tvm.relay.build_module.bind_params_by_name(tvm_mod["main"], non_weight_params) + # ) + # else: + # if len(compiler_cfg.tvm_constnat_prop_mask): + # propped_params = { + # k: v + # for k, v, in params.items() + # if any( + # [ + # mask in param_name_lookup[k] + # for mask in compiler_cfg.tvm_constnat_prop_mask + # ] + # ) + # } + # propped_params.update(non_weight_params) + # else: + # propped_params = params + # tvm_mod = tvm.IRModule.from_expr( + # tvm.relay.build_module.bind_params_by_name(tvm_mod["main"], propped_params) + # ) else: raise RuntimeError("Unsupported framework type: {}".format(framework))