From 0f4c1074c3b9ba5f187a44026675d50ad952c107 Mon Sep 17 00:00:00 2001 From: vkovinicTT Date: Mon, 18 Nov 2024 09:39:47 +0000 Subject: [PATCH 1/2] remove verifyConfig --- python/tvm/contrib/forge_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/tvm/contrib/forge_utils.py b/python/tvm/contrib/forge_utils.py index 56f8daea51..01ae5f53ff 100644 --- a/python/tvm/contrib/forge_utils.py +++ b/python/tvm/contrib/forge_utils.py @@ -13,7 +13,6 @@ import tvm from tvm.relay import ExprVisitor -from forge.verify.config import VerifyConfig from forge.config import CompilerConfig from forge.tvm_utils import flatten_inputs, flatten_structured_output from forge.tensor import to_pt_tensors @@ -24,13 +23,13 @@ def extract_framework_model_outputs( framework: str, model, inputs, - verify_cfg: VerifyConfig, + verify_tvm_compile: bool = False, path=None, input_dict={}, ): framework_outputs = [] - if verify_cfg is None or not verify_cfg.verify_tvm_compile: + if verify_tvm_compile: return framework_outputs if framework == "pytorch": From 8cd46b30f8b6ef079a9fb457d1f6fc661b4f69c9 Mon Sep 17 00:00:00 2001 From: vkovinicTT Date: Mon, 18 Nov 2024 10:22:26 +0000 Subject: [PATCH 2/2] verify_cfg -> verify_tvm_compile --- python/tvm/contrib/forge_compile.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/contrib/forge_compile.py b/python/tvm/contrib/forge_compile.py index 1f905079fb..c367e58043 100644 --- a/python/tvm/contrib/forge_compile.py +++ b/python/tvm/contrib/forge_compile.py @@ -338,7 +338,7 @@ def compile_pytorch_for_forge(torchmod, *inputs, graph_name, compiler_cfg, verif framework="pytorch", model=torchmod, inputs=inputs, - verify_cfg=verify_cfg, + verify_tvm_compile=verify_cfg.verify_tvm_compile, ) # (Temporary): Remove when forge supports dropout @@ -548,7 +548,7 @@ def compile_onnx_for_forge(onnx_mod, path, *inputs, graph_name, compiler_cfg, ve framework="onnx", model=onnx_mod, inputs=inputs, - verify_cfg=verify_cfg, + verify_tvm_compile=verify_cfg.verify_tvm_compile, path=path, input_dict=input_dict, ) @@ -590,7 +590,7 @@ def compile_tflite_for_forge(module, path, *inputs, graph_name, compiler_cfg, ve framework="tflite", model=module, inputs=inputs, - verify_cfg=verify_cfg, + verify_tvm_compile=verify_cfg.verify_tvm_compile, path=path, ) @@ -689,7 +689,7 @@ def compile_jax_for_forge(jaxmodel, *inputs, graph_name, compiler_cfg, verify_cf framework="jax", model=jaxmodel, inputs=inputs, - verify_cfg=verify_cfg, + verify_tvm_compile=verify_cfg.verify_tvm_compile, ) if compiler_cfg.enable_tvm_jax_freeze_large_model: @@ -775,7 +775,7 @@ def compile_tf_for_forge(tfmod, *inputs, graph_name, compiler_cfg, verify_cfg=No framework="tensorflow", model=tfmod, inputs=inputs, - verify_cfg=verify_cfg, + verify_tvm_compile=verify_cfg.verify_tvm_compile, ) # Trace module & get graph definition