From 7504f94ea1429309f84408086e9668bf13c761e4 Mon Sep 17 00:00:00 2001 From: mramanathan Date: Thu, 9 Jan 2025 04:43:05 +0000 Subject: [PATCH] Add enable_tvm_constant_prop compiler flag --- python/tvm/contrib/forge_utils.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/python/tvm/contrib/forge_utils.py b/python/tvm/contrib/forge_utils.py index 01ae5f53f..677dbfd7f 100644 --- a/python/tvm/contrib/forge_utils.py +++ b/python/tvm/contrib/forge_utils.py @@ -186,23 +186,21 @@ def construct_tvm_ir(framework: str, model, tvm_mod, params, compiler_cfg: Compi if framework == "pytorch": param_name_lookup = {} - if not compiler_cfg.enable_tvm_constant_prop: - tvm_mod = tvm.IRModule.from_expr( - tvm.relay.build_module.bind_params_by_name(tvm_mod["main"], {}) - ) - else: - if len(compiler_cfg.tvm_constnat_prop_mask): - propped_params = { + if params is not None: + propped_params = { k: (v, True) for k, v, in params.items() if any([mask in k for mask in compiler_cfg.tvm_constnat_prop_mask]) } - else: - propped_params = {k: (v, True) for k, v, in params.items()} tvm_mod = tvm.IRModule.from_expr( tvm.relay.build_module.bind_params_by_name(tvm_mod["main"], propped_params) ) - + else: + tvm_mod = tvm.IRModule.from_expr( + tvm.relay.build_module.bind_params_by_name(tvm_mod["main"], {}) + ) + + elif framework == "tensorflow": # TODO: Destupidify this! (Maybe we can sort by a substring of the weight names to make this more efficient) found_weights = []