Skip to content

Commit

Permalink
Add enable_tvm_constant_prop compiler flag
Browse files Browse the repository at this point in the history
  • Loading branch information
meenakshiramanathan1 committed Jan 9, 2025
1 parent e405246 commit 07386f3
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions python/tvm/contrib/forge_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,23 +186,21 @@ def construct_tvm_ir(framework: str, model, tvm_mod, params, compiler_cfg: Compi
if framework == "pytorch":
param_name_lookup = {}

if not compiler_cfg.enable_tvm_constant_prop:
tvm_mod = tvm.IRModule.from_expr(
tvm.relay.build_module.bind_params_by_name(tvm_mod["main"], {})
)
else:
if len(compiler_cfg.tvm_constnat_prop_mask):
propped_params = {
if params is not None:
propped_params = {
k: (v, True)
for k, v, in params.items()
if any([mask in k for mask in compiler_cfg.tvm_constnat_prop_mask])
}
else:
propped_params = {k: (v, True) for k, v, in params.items()}
tvm_mod = tvm.IRModule.from_expr(
tvm.relay.build_module.bind_params_by_name(tvm_mod["main"], propped_params)
)

else:
tvm_mod = tvm.IRModule.from_expr(
tvm.relay.build_module.bind_params_by_name(tvm_mod["main"], {})
)


elif framework == "tensorflow":
# TODO: Destupidify this! (Maybe we can sort by a substring of the weight names to make this more efficient)
found_weights = []
Expand Down

0 comments on commit 07386f3

Please sign in to comment.