Remove flax dependency from TVM (#19)
jserbedzijaTT authored Aug 19, 2024
2 parents 3d4d1b0 + 62c94e9 commit 1c1ee57
Showing 2 changed files with 82 additions and 76 deletions.
python/tvm/contrib/pybuda_compile.py (21 additions, 13 deletions)
@@ -40,7 +40,6 @@
 from jax.tools.jax_to_ir import tf_wrap_with_input_names
 import collections
 from transformers.utils.generic import ModelOutput
-from transformers.modeling_flax_utils import FlaxPreTrainedModel
 from tvm.contrib.pybuda_utils import (
     extract_framework_model_outputs,
     extract_flatten_inputs,
@@ -752,12 +751,16 @@ def flatten_params(params, parent_key="", sep="."):

         return dict(items)

-    if isinstance(jaxmodel, FlaxPreTrainedModel):
-        model_params = jaxmodel.params
-    else:
-        model_params = {}
-        if hasattr(jaxmodel, 'params'):
-            model_params = jaxmodel.variables['params']._dict
+    # if isinstance(jaxmodel, FlaxPreTrainedModel):
+    #     model_params = jaxmodel.params
+    # else:
+    #     model_params = {}
+    #     if hasattr(jaxmodel, 'params'):
+    #         model_params = jaxmodel.variables['params']._dict
+
+    model_params = {}
+    if hasattr(jaxmodel, 'params'):
+        model_params = jaxmodel.variables['params']._dict

     weight_names = list(flatten_params(model_params).keys())
     json_graphs = extract_graphs(partitioned_mod, buda_params, flattened_input_names, weight_names, param_name_lookup, graph_hash=m.hexdigest())
@@ -1006,12 +1009,17 @@ def flatten_params(params, parent_key="", sep="."):

         return dict(items)

-    if isinstance(module, FlaxPreTrainedModel):
-        module_params = module.params
-    else:
-        module_params = {}
-        if hasattr(module, 'params'):
-            module_params = module.variables['params']._dict
+    # if isinstance(module, FlaxPreTrainedModel):
+    #     module_params = module.params
+    # else:
+    #     module_params = {}
+    #     if hasattr(module, 'params'):
+    #         module_params = module.variables['params']._dict
+
+    module_params = {}
+    if hasattr(module, 'params'):
+        module_params = module.variables['params']._dict
+
     module_params = flatten_params(module_params)

     weights = {}
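The change to pybuda_compile.py above drops the FlaxPreTrainedModel isinstance check entirely: any model object that exposes a params attribute is now read through its variables['params'] collection, so this file no longer needs flax or the transformers Flax utilities. A minimal sketch of that duck-typed lookup follows; the get_model_params helper and DummyFlaxLike class are illustrative only, and the real code reads the FrozenDict's internal _dict rather than a plain dict:

def get_model_params(model):
    # Duck-typed replacement for the old isinstance(model, FlaxPreTrainedModel)
    # check: any object with a `params` attribute is treated as Flax-style and
    # its parameter tree is read from `variables['params']`; everything else
    # yields an empty dict.
    model_params = {}
    if hasattr(model, 'params'):
        model_params = model.variables['params']
    return model_params

# Illustrative stand-in for a bound Flax-style module:
class DummyFlaxLike:
    params = None  # only the presence of this attribute matters to the check
    variables = {'params': {'dense': {'kernel': [[1.0, 2.0]], 'bias': [0.0]}}}

print(get_model_params(DummyFlaxLike()))  # nested params dict
print(get_model_params(object()))         # {}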
python/tvm/contrib/pybuda_utils.py (61 additions, 63 deletions)
@@ -4,13 +4,11 @@
 from collections import OrderedDict
 from collections.abc import MutableMapping

-import flax
 import torch
 import numpy as np
 import tensorflow as tf
 import onnxruntime as ort
 from transformers.utils.generic import ModelOutput as HFModelOutput
-from transformers.modeling_flax_utils import FlaxPreTrainedModel
 from transformers.modeling_outputs import ModelOutput

 import tvm
@@ -256,67 +254,67 @@ def construct_tvm_ir(framework: str, model, tvm_mod, params, compiler_cfg: Compi
                 tvm.relay.build_module.bind_params_by_name(tvm_mod["main"], propped_params)
             )

-    elif framework == "jax":
-
-        def flatten(d, parent_key="", sep="."):
-            items = []
-            for k, v in d.items():
-                new_key = parent_key + sep + k if parent_key else k
-                if isinstance(v, MutableMapping):
-                    items.extend(flatten(v, new_key, sep=sep).items())
-                else:
-                    items.append((new_key, v))
-            return dict(items)
-
-        # TODO: Destupidify this! (Maybe we can sort by a substring of the weight names to make this more efficient)
-        found_weights = []
-        param_name_lookup = {}
-        non_weight_params = {} # Some parameters (like causal mask) are not weights
-
-        if isinstance(model, FlaxPreTrainedModel):
-            model_params = model.params
-        elif isinstance(model, flax.linen.Module):
-            model_params = {}
-            if hasattr(model, 'params'):
-                model_params = model.variables['params']._dict
-        else:
-            raise RuntimeError("Unknown Jax module instance.")
-
-        model_params = flatten(model_params)
-        for (bad_name, value) in params.items():
-            weight_found = False
-            for name, jax_value in model_params.items():
-                if name not in found_weights and np.array_equal(jax_value.to_py(), value.numpy()):
-                    param_name_lookup[bad_name] = name
-                    weight_found = True
-                    found_weights.append(name)
-                    break
-            if not weight_found:
-                param_name_lookup[bad_name] = bad_name
-                non_weight_params[bad_name] = value
-
-        if not compiler_cfg.enable_tvm_constant_prop:
-            tvm_mod = tvm.IRModule.from_expr(
-                tvm.relay.build_module.bind_params_by_name(tvm_mod["main"], non_weight_params)
-            )
-        else:
-            if len(compiler_cfg.tvm_constnat_prop_mask):
-                propped_params = {
-                    k: v
-                    for k, v, in params.items()
-                    if any(
-                        [
-                            mask in param_name_lookup[k]
-                            for mask in compiler_cfg.tvm_constnat_prop_mask
-                        ]
-                    )
-                }
-                propped_params.update(non_weight_params)
-            else:
-                propped_params = params
-            tvm_mod = tvm.IRModule.from_expr(
-                tvm.relay.build_module.bind_params_by_name(tvm_mod["main"], propped_params)
-            )
+    # elif framework == "jax":
+
+    #     def flatten(d, parent_key="", sep="."):
+    #         items = []
+    #         for k, v in d.items():
+    #             new_key = parent_key + sep + k if parent_key else k
+    #             if isinstance(v, MutableMapping):
+    #                 items.extend(flatten(v, new_key, sep=sep).items())
+    #             else:
+    #                 items.append((new_key, v))
+    #         return dict(items)
+
+    #     # TODO: Destupidify this! (Maybe we can sort by a substring of the weight names to make this more efficient)
+    #     found_weights = []
+    #     param_name_lookup = {}
+    #     non_weight_params = {} # Some parameters (like causal mask) are not weights
+
+    #     if isinstance(model, FlaxPreTrainedModel):
+    #         model_params = model.params
+    #     elif isinstance(model, flax.linen.Module):
+    #         model_params = {}
+    #         if hasattr(model, 'params'):
+    #             model_params = model.variables['params']._dict
+    #     else:
+    #         raise RuntimeError("Unknown Jax module instance.")
+
+    #     model_params = flatten(model_params)
+    #     for (bad_name, value) in params.items():
+    #         weight_found = False
+    #         for name, jax_value in model_params.items():
+    #             if name not in found_weights and np.array_equal(jax_value.to_py(), value.numpy()):
+    #                 param_name_lookup[bad_name] = name
+    #                 weight_found = True
+    #                 found_weights.append(name)
+    #                 break
+    #         if not weight_found:
+    #             param_name_lookup[bad_name] = bad_name
+    #             non_weight_params[bad_name] = value
+
+    #     if not compiler_cfg.enable_tvm_constant_prop:
+    #         tvm_mod = tvm.IRModule.from_expr(
+    #             tvm.relay.build_module.bind_params_by_name(tvm_mod["main"], non_weight_params)
+    #         )
+    #     else:
+    #         if len(compiler_cfg.tvm_constnat_prop_mask):
+    #             propped_params = {
+    #                 k: v
+    #                 for k, v, in params.items()
+    #                 if any(
+    #                     [
+    #                         mask in param_name_lookup[k]
+    #                         for mask in compiler_cfg.tvm_constnat_prop_mask
+    #                     ]
+    #                 )
+    #             }
+    #             propped_params.update(non_weight_params)
+    #         else:
+    #             propped_params = params
+    #         tvm_mod = tvm.IRModule.from_expr(
+    #             tvm.relay.build_module.bind_params_by_name(tvm_mod["main"], propped_params)
+    #         )
     else:
         raise RuntimeError("Unsupported framework type: {}".format(framework))

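For context, the jax branch commented out above did two things: it flattened the framework's nested parameter tree into dot-joined names, and it mapped TVM-generated parameter names back to framework weight names by comparing tensor contents. A standalone sketch of that value-based matching step is below; build_name_lookup and its argument names are illustrative, not identifiers from the repository:

import numpy as np

def build_name_lookup(tvm_params, framework_params):
    # Map each TVM-generated ("bad") name to the framework weight whose values
    # match it; parameters with no match (e.g. a causal mask) keep their
    # original name and are collected separately as non-weight parameters.
    found = set()
    lookup, non_weight = {}, {}
    for bad_name, value in tvm_params.items():
        for name, fw_value in framework_params.items():
            if name not in found and np.array_equal(np.asarray(fw_value), np.asarray(value)):
                lookup[bad_name] = name
                found.add(name)
                break
        else:  # inner loop found no matching framework weight
            lookup[bad_name] = bad_name
            non_weight[bad_name] = value
    return lookup, non_weight

# Example usage with toy arrays:
lookup, non_weight = build_name_lookup(
    {"p0": np.ones((2, 2)), "p1": np.tril(np.ones((4, 4)))},
    {"dense.kernel": np.ones((2, 2))},
)
# lookup == {"p0": "dense.kernel", "p1": "p1"}; non_weight holds the mask-like "p1".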
