diff --git a/docs/src/test.md b/docs/src/test.md index 52c9a63ae..22beae0df 100644 --- a/docs/src/test.md +++ b/docs/src/test.md @@ -92,6 +92,11 @@ Full list of supported query parameters | TEST_ID | Id of a test containing test parameters | test_single | | ID_FILE | Path to a file containing test ids | test_ids | +Test configuration parameters + +| Parameter | Description | Supported by commands | +| ------------------------- | --------------------------------------------------------------------------------------------- | ------------------------------------- | +| SKIP_FORGE_VERIFICATION | Skip Forge model verification including model compiling and inference | all | To check supported values and options for each query parameter please run command `print_query_docs`. diff --git a/forge/test/operators/pytorch/test_all.py b/forge/test/operators/pytorch/test_all.py index 93c881186..ec114b7eb 100644 --- a/forge/test/operators/pytorch/test_all.py +++ b/forge/test/operators/pytorch/test_all.py @@ -418,6 +418,10 @@ def print_query_params(cls, max_width=80): cls.print_query_values(max_width) print("Query examples:") cls.print_query_examples(max_width) + print("Configuration parameters:") + cls.print_configuration_params(max_width) + print("Configuration examples:") + cls.print_configuration_examples(max_width) @classmethod def print_query_values(cls, max_width=80): @@ -500,6 +504,28 @@ def print_query_examples(cls, max_width=80): cls.print_formatted_parameters(parameters, max_width, headers=["Parameter", "Examples"]) + @classmethod + def print_configuration_params(cls, max_width=80): + + parameters = [ + { + "name": "SKIP_FORGE_VERIFICATION", + "description": f"Skip Forge model verification including model compiling and inference", + "default": "false", + }, + ] + + cls.print_formatted_parameters(parameters, max_width, headers=["Parameter", "Description", "Default"]) + + @classmethod + def print_configuration_examples(cls, max_width=80): + + parameters = [ + {"name": "SKIP_FORGE_VERIFICATION", "description": "export SKIP_FORGE_VERIFICATION=true"}, + ] + + cls.print_formatted_parameters(parameters, max_width, headers=["Parameter", "Examples"]) + @classmethod def print_formatted_parameters(cls, parameters, max_width=80, headers=["Parameter", "Description"]): for param in parameters: diff --git a/forge/test/operators/utils/__init__.py b/forge/test/operators/utils/__init__.py index 79780d22c..7f5d22f4b 100644 --- a/forge/test/operators/utils/__init__.py +++ b/forge/test/operators/utils/__init__.py @@ -13,6 +13,7 @@ from .utils import LoggerUtils from .utils import RateLimiter from .utils import FrameworkModelType +from .features import TestFeaturesConfiguration from .plan import InputSource from .plan import TestVector from .plan import TestCollection @@ -41,6 +42,7 @@ "VerifyUtils", "LoggerUtils", "RateLimiter", + "TestFeaturesConfiguration", "FrameworkModelType", "InputSource", "TestVector", diff --git a/forge/test/operators/utils/compat.py b/forge/test/operators/utils/compat.py index 4668ac12e..b3be654ca 100644 --- a/forge/test/operators/utils/compat.py +++ b/forge/test/operators/utils/compat.py @@ -11,6 +11,7 @@ from typing import Optional, List, Union from forge import ForgeModule, Module, DepricatedVerifyConfig +from forge.tensor import to_pt_tensors from forge.op_repo import TensorShape from forge.verify.compare import compare_with_golden from forge.verify.verify import verify @@ -326,3 +327,52 @@ def verify_module_for_inputs( forge_inputs = [forge.Tensor.create_from_torch(input, dev_data_format=dev_data_format) for input in inputs] compiled_model = forge.compile(model, sample_inputs=forge_inputs) verify(inputs, model, compiled_model, verify_config) + + +def verify_module_for_inputs_torch( + model: Module, + inputs: List[torch.Tensor], + verify_config: Optional[VerifyConfig] = VerifyConfig(), +): + + verify_torch(inputs, model, verify_config) + + +def verify_torch( + inputs: List[torch.Tensor], + framework_model: torch.nn.Module, + verify_cfg: VerifyConfig = VerifyConfig(), +): + """ + Verify the pytorch model with the given inputs + """ + if not verify_cfg.enabled: + logger.warning("Verification is disabled") + return + + # 0th step: input checks + + # Check if inputs are of the correct type + if not inputs: + raise ValueError("Input tensors must be provided") + for input_tensor in inputs: + if not isinstance(input_tensor, verify_cfg.supported_tensor_types): + raise TypeError( + f"Input tensor must be of type {verify_cfg.supported_tensor_types}, but got {type(input_tensor)}" + ) + + if not isinstance(framework_model, verify_cfg.framework_model_types): + raise TypeError( + f"Framework model must be of type {verify_cfg.framework_model_types}, but got {type(framework_model)}" + ) + + # 1st step: run forward pass for the networks + fw_out = framework_model(*inputs) + + # 2nd step: apply preprocessing (push tensors to cpu, perform any reshape if necessary, + # cast from tensorflow tensors to pytorch tensors if needed) + if not isinstance(fw_out, torch.Tensor): + fw_out = to_pt_tensors(fw_out) + + fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out + return fw_out diff --git a/forge/test/operators/utils/features.py b/forge/test/operators/utils/features.py new file mode 100644 index 000000000..ec5cd6136 --- /dev/null +++ b/forge/test/operators/utils/features.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +import os + + +class TestFeaturesConfiguration: + """Store test features configuration""" + + __test__ = False # Disable pytest collection + + @staticmethod + def get_env_property(env_var: str, default_value: str): + return os.getenv(env_var, default_value) + + SKIP_FORGE_VERIFICATION = get_env_property("SKIP_FORGE_VERIFICATION", "false").lower() == "true" diff --git a/forge/test/operators/utils/utils.py b/forge/test/operators/utils/utils.py index 91674879a..ddacff052 100644 --- a/forge/test/operators/utils/utils.py +++ b/forge/test/operators/utils/utils.py @@ -26,8 +26,14 @@ from forge.verify.config import VerifyConfig from .compat import TestDevice -from .compat import create_torch_inputs, verify_module_for_inputs, verify_module_for_inputs_deprecated +from .compat import ( + create_torch_inputs, + verify_module_for_inputs, + verify_module_for_inputs_deprecated, + verify_module_for_inputs_torch, +) from .datatypes import ValueRanges +from .features import TestFeaturesConfiguration # All supported framework model types @@ -130,6 +136,7 @@ def verify( warm_reset: bool = False, deprecated_verification: bool = True, verify_config: Optional[VerifyConfig] = VerifyConfig(), + skip_forge_verification: bool = TestFeaturesConfiguration.SKIP_FORGE_VERIFICATION, ): """Perform Forge verification on the model @@ -146,6 +153,8 @@ def verify( random_seed: Random seed warm_reset: Warm reset the device before verification deprecated_verification: Use deprecated verification method + verify_config: Verification configuration + skip_forge_verification: Skip verification with Forge module """ cls.setup( @@ -168,6 +177,12 @@ def verify( pcc=pcc, dev_data_format=dev_data_format, ) + elif skip_forge_verification: + verify_module_for_inputs_torch( + model=model, + inputs=inputs, + verify_config=verify_config, + ) else: cls.verify_module_for_inputs( model=model,