Remove flax dependency from TVM #19

Merged
merged 1 commit on Aug 19, 2024
34 changes: 21 additions & 13 deletions python/tvm/contrib/pybuda_compile.py
@@ -40,7 +40,6 @@
from jax.tools.jax_to_ir import tf_wrap_with_input_names
import collections
from transformers.utils.generic import ModelOutput
-from transformers.modeling_flax_utils import FlaxPreTrainedModel
from tvm.contrib.pybuda_utils import (
    extract_framework_model_outputs,
    extract_flatten_inputs,
@@ -752,12 +751,16 @@ def flatten_params(params, parent_key="", sep=".")

        return dict(items)

-    if isinstance(jaxmodel, FlaxPreTrainedModel):
-        model_params = jaxmodel.params
-    else:
-        model_params = {}
-        if hasattr(jaxmodel, 'params'):
-            model_params = jaxmodel.variables['params']._dict
+    # if isinstance(jaxmodel, FlaxPreTrainedModel):
+    #     model_params = jaxmodel.params
+    # else:
+    #     model_params = {}
+    #     if hasattr(jaxmodel, 'params'):
+    #         model_params = jaxmodel.variables['params']._dict
+
+    model_params = {}
+    if hasattr(jaxmodel, 'params'):
+        model_params = jaxmodel.variables['params']._dict

    weight_names = list(flatten_params(model_params).keys())
    json_graphs = extract_graphs(partitioned_mod, buda_params, flattened_input_names, weight_names, param_name_lookup, graph_hash=m.hexdigest())
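The hunk above replaces the isinstance(jaxmodel, FlaxPreTrainedModel) check with plain duck typing, which is what lets the flax and transformers.modeling_flax_utils imports go away. A minimal sketch of the pattern, assuming only that a model may expose Flax-style bound variables (DummyModel is hypothetical, for illustration only):

def get_model_params(model):
    """Fetch Flax-style parameters via duck typing; no flax import required."""
    model_params = {}
    if hasattr(model, 'params'):
        # Bound flax.linen modules keep their parameters under variables['params'];
        # ._dict unwraps the FrozenDict into a plain dict (an internal attribute).
        model_params = model.variables['params']._dict
    return model_params

class DummyModel:  # hypothetical object with no Flax parameters
    pass

print(get_model_params(DummyModel()))  # -> {}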
@@ -1006,12 +1009,17 @@ def flatten_params(params, parent_key="", sep=".")

        return dict(items)

-    if isinstance(module, FlaxPreTrainedModel):
-        module_params = module.params
-    else:
-        module_params = {}
-        if hasattr(module, 'params'):
-            module_params = module.variables['params']._dict
+    # if isinstance(module, FlaxPreTrainedModel):
+    #     module_params = module.params
+    # else:
+    #     module_params = {}
+    #     if hasattr(module, 'params'):
+    #         module_params = module.variables['params']._dict
+
+    module_params = {}
+    if hasattr(module, 'params'):
+        module_params = module.variables['params']._dict
+
    module_params = flatten_params(module_params)

    weights = {}
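Both hunks still call flatten_params, which collapses Flax's nested parameter dict into dotted keys before the weight names are read off. A small self-contained sketch of that behavior, mirroring the flatten helper visible in pybuda_utils.py below:

from collections.abc import MutableMapping

def flatten_params(params, parent_key="", sep="."):
    # Recursively join nested dict keys with `sep` into flat dotted names.
    items = []
    for k, v in params.items():
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, MutableMapping):
            items.extend(flatten_params(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)

nested = {"dense": {"kernel": [[1.0]], "bias": [0.0]}}
print(flatten_params(nested))  # -> {'dense.kernel': [[1.0]], 'dense.bias': [0.0]}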
124 changes: 61 additions & 63 deletions python/tvm/contrib/pybuda_utils.py
@@ -4,13 +4,11 @@
from collections import OrderedDict
from collections.abc import MutableMapping

-import flax
import torch
import numpy as np
import tensorflow as tf
import onnxruntime as ort
from transformers.utils.generic import ModelOutput as HFModelOutput
-from transformers.modeling_flax_utils import FlaxPreTrainedModel
from transformers.modeling_outputs import ModelOutput

import tvm
@@ -256,67 +254,67 @@ def construct_tvm_ir(framework: str, model, tvm_mod, params, compiler_cfg: Compi
            tvm.relay.build_module.bind_params_by_name(tvm_mod["main"], propped_params)
        )

-    elif framework == "jax":
-
-        def flatten(d, parent_key="", sep="."):
-            items = []
-            for k, v in d.items():
-                new_key = parent_key + sep + k if parent_key else k
-                if isinstance(v, MutableMapping):
-                    items.extend(flatten(v, new_key, sep=sep).items())
-                else:
-                    items.append((new_key, v))
-            return dict(items)
-
-        # TODO: Destupidify this! (Maybe we can sort by a substring of the weight names to make this more efficient)
-        found_weights = []
-        param_name_lookup = {}
-        non_weight_params = {} # Some parameters (like causal mask) are not weights
-
-        if isinstance(model, FlaxPreTrainedModel):
-            model_params = model.params
-        elif isinstance(model, flax.linen.Module):
-            model_params = {}
-            if hasattr(model, 'params'):
-                model_params = model.variables['params']._dict
-        else:
-            raise RuntimeError("Unknown Jax module instance.")
-
-        model_params = flatten(model_params)
-        for (bad_name, value) in params.items():
-            weight_found = False
-            for name, jax_value in model_params.items():
-                if name not in found_weights and np.array_equal(jax_value.to_py(), value.numpy()):
-                    param_name_lookup[bad_name] = name
-                    weight_found = True
-                    found_weights.append(name)
-                    break
-            if not weight_found:
-                param_name_lookup[bad_name] = bad_name
-                non_weight_params[bad_name] = value
-
-        if not compiler_cfg.enable_tvm_constant_prop:
-            tvm_mod = tvm.IRModule.from_expr(
-                tvm.relay.build_module.bind_params_by_name(tvm_mod["main"], non_weight_params)
-            )
-        else:
-            if len(compiler_cfg.tvm_constnat_prop_mask):
-                propped_params = {
-                    k: v
-                    for k, v, in params.items()
-                    if any(
-                        [
-                            mask in param_name_lookup[k]
-                            for mask in compiler_cfg.tvm_constnat_prop_mask
-                        ]
-                    )
-                }
-                propped_params.update(non_weight_params)
-            else:
-                propped_params = params
-            tvm_mod = tvm.IRModule.from_expr(
-                tvm.relay.build_module.bind_params_by_name(tvm_mod["main"], propped_params)
-            )
+    # elif framework == "jax":
+
+    #     def flatten(d, parent_key="", sep="."):
+    #         items = []
+    #         for k, v in d.items():
+    #             new_key = parent_key + sep + k if parent_key else k
+    #             if isinstance(v, MutableMapping):
+    #                 items.extend(flatten(v, new_key, sep=sep).items())
+    #             else:
+    #                 items.append((new_key, v))
+    #         return dict(items)
+
+    #     # TODO: Destupidify this! (Maybe we can sort by a substring of the weight names to make this more efficient)
+    #     found_weights = []
+    #     param_name_lookup = {}
+    #     non_weight_params = {} # Some parameters (like causal mask) are not weights
+
+    #     if isinstance(model, FlaxPreTrainedModel):
+    #         model_params = model.params
+    #     elif isinstance(model, flax.linen.Module):
+    #         model_params = {}
+    #         if hasattr(model, 'params'):
+    #             model_params = model.variables['params']._dict
+    #     else:
+    #         raise RuntimeError("Unknown Jax module instance.")
+
+    #     model_params = flatten(model_params)
+    #     for (bad_name, value) in params.items():
+    #         weight_found = False
+    #         for name, jax_value in model_params.items():
+    #             if name not in found_weights and np.array_equal(jax_value.to_py(), value.numpy()):
+    #                 param_name_lookup[bad_name] = name
+    #                 weight_found = True
+    #                 found_weights.append(name)
+    #                 break
+    #         if not weight_found:
+    #             param_name_lookup[bad_name] = bad_name
+    #             non_weight_params[bad_name] = value
+
+    #     if not compiler_cfg.enable_tvm_constant_prop:
+    #         tvm_mod = tvm.IRModule.from_expr(
+    #             tvm.relay.build_module.bind_params_by_name(tvm_mod["main"], non_weight_params)
+    #         )
+    #     else:
+    #         if len(compiler_cfg.tvm_constnat_prop_mask):
+    #             propped_params = {
+    #                 k: v
+    #                 for k, v, in params.items()
+    #                 if any(
+    #                     [
+    #                         mask in param_name_lookup[k]
+    #                         for mask in compiler_cfg.tvm_constnat_prop_mask
+    #                     ]
+    #                 )
+    #             }
+    #             propped_params.update(non_weight_params)
+    #         else:
+    #             propped_params = params
+    #         tvm_mod = tvm.IRModule.from_expr(
+    #             tvm.relay.build_module.bind_params_by_name(tvm_mod["main"], propped_params)
+    #         )
    else:
        raise RuntimeError("Unsupported framework type: {}".format(framework))
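With the entire "jax" branch commented out, a caller that still passes framework="jax" now falls through to this final else and raises. A simplified sketch of the dispatch shape after this change (the supported set shown here is assumed for illustration, not the file's full list):

def dispatch_framework(framework):
    # After this PR, "jax" is no longer handled here.
    if framework in ("pytorch", "tensorflow", "onnx"):  # assumed supported set
        return "handled"
    raise RuntimeError("Unsupported framework type: {}".format(framework))

try:
    dispatch_framework("jax")
except RuntimeError as err:
    print(err)  # Unsupported framework type: jax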
