From 8cebf995ff216fcfd049e01e2800afaf42897c35 Mon Sep 17 00:00:00 2001 From: anzr299 Date: Wed, 29 Jan 2025 17:53:00 +0400 Subject: [PATCH 1/6] init --- .../torch/fx/nncf_graph_builder.py | 3 ++ .../synthetic_transformer.dot | 49 +++++++++++++++++ .../quantized/synthetic_transformer.dot | 53 +++++++++++++++++++ tests/torch/fx/helpers.py | 4 +- tests/torch/fx/test_models.py | 27 +++++++--- 5 files changed, 127 insertions(+), 9 deletions(-) create mode 100644 tests/torch/data/reference_graphs/fx/dynamic_shapes/post_quantization_compressed/synthetic_transformer.dot create mode 100644 tests/torch/data/reference_graphs/fx/dynamic_shapes/quantized/synthetic_transformer.dot diff --git a/nncf/experimental/torch/fx/nncf_graph_builder.py b/nncf/experimental/torch/fx/nncf_graph_builder.py index 36dc3b46e6e..3d162f27400 100644 --- a/nncf/experimental/torch/fx/nncf_graph_builder.py +++ b/nncf/experimental/torch/fx/nncf_graph_builder.py @@ -197,6 +197,9 @@ def get_edge_params( tensor = source_node.meta["val"] if isinstance(tensor, torch.Tensor): tensor_shape = tuple(tensor.shape) + tensor_shape = tuple(str(i) if isinstance(i, torch.SymInt) else i for i in tensor_shape) + if isinstance(tensor, torch.SymInt): + tensor_shape = (str(tensor),) if tensor_shape is None: # TODO(dlyakhov): Refactor algorithms to always have knowns edges shapes. diff --git a/tests/torch/data/reference_graphs/fx/dynamic_shapes/post_quantization_compressed/synthetic_transformer.dot b/tests/torch/data/reference_graphs/fx/dynamic_shapes/post_quantization_compressed/synthetic_transformer.dot new file mode 100644 index 00000000000..f9d6b83d2ed --- /dev/null +++ b/tests/torch/data/reference_graphs/fx/dynamic_shapes/post_quantization_compressed/synthetic_transformer.dot @@ -0,0 +1,49 @@ +strict digraph { +"0 wte_weight" [id=0, type=get_attr]; +"1 linear_bias" [id=1, type=get_attr]; +"2 lm_head_bias" [id=2, type=get_attr]; +"3 input_ids" [id=3, type=input]; +"4 embedding" [id=4, type=embedding]; +"5 embedding_0_0_nncf_smooth_quant_0" [id=5, type=call_module]; +"6 quantize_per_tensor_default" [id=6, type=quantize_per_tensor]; +"7 dequantize_per_tensor_default" [id=7, type=dequantize_per_tensor]; +"8 scale_updated_constant0" [id=8, type=get_attr]; +"9 compressed_weight_updated_constant0" [id=9, type=get_attr]; +"10 mul_tensor" [id=10, type=mul]; +"11 zero_point_updated_constant0" [id=11, type=get_attr]; +"12 sub_tensor" [id=12, type=sub]; +"13 linear" [id=13, type=linear]; +"14 linear_0_0_nncf_smooth_quant_0" [id=14, type=call_module]; +"15 quantize_per_tensor_default_1" [id=15, type=quantize_per_tensor]; +"16 dequantize_per_tensor_default_1" [id=16, type=dequantize_per_tensor]; +"17 scale_updated_constant1" [id=17, type=get_attr]; +"18 compressed_weight_updated_constant1" [id=18, type=get_attr]; +"19 mul_tensor_1" [id=19, type=mul]; +"20 zero_point_updated_constant1" [id=20, type=get_attr]; +"21 sub_tensor_1" [id=21, type=sub]; +"22 linear_1" [id=22, type=linear]; +"23 output" [id=23, type=output]; +"0 wte_weight" -> "4 embedding" [label="(10, 5)", style=solid]; +"1 linear_bias" -> "13 linear" [label="(5,)", style=solid]; +"2 lm_head_bias" -> "22 linear_1" [label="(10,)", style=solid]; +"3 input_ids" -> "4 embedding" [label="('s0',)", style=solid]; +"4 embedding" -> "5 embedding_0_0_nncf_smooth_quant_0" [label="('s0', 5)", style=solid]; +"5 embedding_0_0_nncf_smooth_quant_0" -> "6 quantize_per_tensor_default" [label="('s0', 5)", style=solid]; +"6 quantize_per_tensor_default" -> "7 dequantize_per_tensor_default" [label="('s0', 5)", style=solid]; +"7 dequantize_per_tensor_default" -> "13 linear" [label="('s0', 5)", style=solid]; +"8 scale_updated_constant0" -> "10 mul_tensor" [label="(5, 1)", style=solid]; +"9 compressed_weight_updated_constant0" -> "10 mul_tensor" [label="(5, 5)", style=solid]; +"10 mul_tensor" -> "12 sub_tensor" [label="(5, 5)", style=solid]; +"11 zero_point_updated_constant0" -> "12 sub_tensor" [label="(5, 1)", style=solid]; +"12 sub_tensor" -> "13 linear" [label="(5, 5)", style=solid]; +"13 linear" -> "14 linear_0_0_nncf_smooth_quant_0" [label="('s0', 5)", style=solid]; +"14 linear_0_0_nncf_smooth_quant_0" -> "15 quantize_per_tensor_default_1" [label="('s0', 5)", style=solid]; +"15 quantize_per_tensor_default_1" -> "16 dequantize_per_tensor_default_1" [label="('s0', 5)", style=solid]; +"16 dequantize_per_tensor_default_1" -> "22 linear_1" [label="('s0', 5)", style=solid]; +"17 scale_updated_constant1" -> "19 mul_tensor_1" [label="(10, 1)", style=solid]; +"18 compressed_weight_updated_constant1" -> "19 mul_tensor_1" [label="(10, 5)", style=solid]; +"19 mul_tensor_1" -> "21 sub_tensor_1" [label="(10, 5)", style=solid]; +"20 zero_point_updated_constant1" -> "21 sub_tensor_1" [label="(10, 1)", style=solid]; +"21 sub_tensor_1" -> "22 linear_1" [label="(10, 5)", style=solid]; +"22 linear_1" -> "23 output" [label="('s0', 10)", style=solid]; +} diff --git a/tests/torch/data/reference_graphs/fx/dynamic_shapes/quantized/synthetic_transformer.dot b/tests/torch/data/reference_graphs/fx/dynamic_shapes/quantized/synthetic_transformer.dot new file mode 100644 index 00000000000..0a0ad795c9a --- /dev/null +++ b/tests/torch/data/reference_graphs/fx/dynamic_shapes/quantized/synthetic_transformer.dot @@ -0,0 +1,53 @@ +strict digraph { +"0 wte_weight" [id=0, type=get_attr]; +"1 linear_bias" [id=1, type=get_attr]; +"2 lm_head_bias" [id=2, type=get_attr]; +"3 input_ids" [id=3, type=input]; +"4 embedding" [id=4, type=embedding]; +"5 embedding_0_0_nncf_smooth_quant_0" [id=5, type=call_module]; +"6 quantize_per_tensor_default" [id=6, type=quantize_per_tensor]; +"7 dequantize_per_tensor_default" [id=7, type=dequantize_per_tensor]; +"8 linear_scale_0" [id=8, type=get_attr]; +"9 linear_zero_point_0" [id=9, type=get_attr]; +"10 compressed_weight_updated_constant0" [id=10, type=get_attr]; +"11 quantize_per_channel_default" [id=11, type=quantize_per_channel]; +"12 dequantize_per_channel_default" [id=12, type=dequantize_per_channel]; +"13 linear" [id=13, type=linear]; +"14 linear_0_0_nncf_smooth_quant_0" [id=14, type=call_module]; +"15 quantize_per_tensor_default_1" [id=15, type=quantize_per_tensor]; +"16 dequantize_per_tensor_default_1" [id=16, type=dequantize_per_tensor]; +"17 linear_1_scale_0" [id=17, type=get_attr]; +"18 linear_1_zero_point_0" [id=18, type=get_attr]; +"19 compressed_weight_updated_constant1" [id=19, type=get_attr]; +"20 quantize_per_channel_default_1" [id=20, type=quantize_per_channel]; +"21 dequantize_per_channel_default_1" [id=21, type=dequantize_per_channel]; +"22 linear_1" [id=22, type=linear]; +"23 output" [id=23, type=output]; +"0 wte_weight" -> "4 embedding" [label="(10, 5)", style=solid]; +"1 linear_bias" -> "13 linear" [label="(5,)", style=solid]; +"2 lm_head_bias" -> "22 linear_1" [label="(10,)", style=solid]; +"3 input_ids" -> "4 embedding" [label="('s0',)", style=solid]; +"4 embedding" -> "5 embedding_0_0_nncf_smooth_quant_0" [label="('s0', 5)", style=solid]; +"5 embedding_0_0_nncf_smooth_quant_0" -> "6 quantize_per_tensor_default" [label="('s0', 5)", style=solid]; +"6 quantize_per_tensor_default" -> "7 dequantize_per_tensor_default" [label="('s0', 5)", style=solid]; +"7 dequantize_per_tensor_default" -> "13 linear" [label="('s0', 5)", style=solid]; +"8 linear_scale_0" -> "11 quantize_per_channel_default" [label="(5,)", style=solid]; +"8 linear_scale_0" -> "12 dequantize_per_channel_default" [label="(5,)", style=solid]; +"9 linear_zero_point_0" -> "11 quantize_per_channel_default" [label="(5,)", style=solid]; +"9 linear_zero_point_0" -> "12 dequantize_per_channel_default" [label="(5,)", style=solid]; +"10 compressed_weight_updated_constant0" -> "11 quantize_per_channel_default" [label="(5, 5)", style=solid]; +"11 quantize_per_channel_default" -> "12 dequantize_per_channel_default" [label="(5, 5)", style=solid]; +"12 dequantize_per_channel_default" -> "13 linear" [label="(5, 5)", style=solid]; +"13 linear" -> "14 linear_0_0_nncf_smooth_quant_0" [label="('s0', 5)", style=solid]; +"14 linear_0_0_nncf_smooth_quant_0" -> "15 quantize_per_tensor_default_1" [label="('s0', 5)", style=solid]; +"15 quantize_per_tensor_default_1" -> "16 dequantize_per_tensor_default_1" [label="('s0', 5)", style=solid]; +"16 dequantize_per_tensor_default_1" -> "22 linear_1" [label="('s0', 5)", style=solid]; +"17 linear_1_scale_0" -> "20 quantize_per_channel_default_1" [label="(10,)", style=solid]; +"17 linear_1_scale_0" -> "21 dequantize_per_channel_default_1" [label="(10,)", style=solid]; +"18 linear_1_zero_point_0" -> "20 quantize_per_channel_default_1" [label="(10,)", style=solid]; +"18 linear_1_zero_point_0" -> "21 dequantize_per_channel_default_1" [label="(10,)", style=solid]; +"19 compressed_weight_updated_constant1" -> "20 quantize_per_channel_default_1" [label="(10, 5)", style=solid]; +"20 quantize_per_channel_default_1" -> "21 dequantize_per_channel_default_1" [label="(10, 5)", style=solid]; +"21 dequantize_per_channel_default_1" -> "22 linear_1" [label="(10, 5)", style=solid]; +"22 linear_1" -> "23 output" [label="('s0', 10)", style=solid]; +} diff --git a/tests/torch/fx/helpers.py b/tests/torch/fx/helpers.py index b51f6da16f4..9e599b38bdb 100644 --- a/tests/torch/fx/helpers.py +++ b/tests/torch/fx/helpers.py @@ -124,7 +124,7 @@ def visualize_fx_model(model: torch.fx.GraphModule, output_svg_path: str): def get_torch_fx_model( - model: torch.nn.Module, ex_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]] + model: torch.nn.Module, ex_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]], dynamic_shapes=None ) -> torch.fx.GraphModule: """ Converts given module to GraphModule. @@ -151,7 +151,7 @@ def get_torch_fx_model( model.eval() with torch.no_grad(): with disable_patching(): - return torch.export.export_for_training(model, args=device_ex_input).module() + return torch.export.export_for_training(model, args=device_ex_input, dynamic_shapes=dynamic_shapes).module() def get_torch_fx_model_q_transformed(model: torch.nn.Module, ex_input: torch.Tensor) -> torch.fx.GraphModule: diff --git a/tests/torch/fx/test_models.py b/tests/torch/fx/test_models.py index 39538f8b7d9..17da34e7d52 100644 --- a/tests/torch/fx/test_models.py +++ b/tests/torch/fx/test_models.py @@ -44,10 +44,15 @@ from tests.torch.test_models.synthetic import MultiBranchesConnectedModel from tests.torch.test_models.synthetic import ShortTransformer from tests.torch.test_models.synthetic import YOLO11N_SDPABlock +from torch.export.dynamic_shapes import Dim FX_DIR_NAME = Path("fx") -FX_QUANTIZED_DIR_NAME = Path("fx") / "quantized" -FX_QUANTIZED_COMPRESSED_DIR_NAME = Path("fx") / "post_quantization_compressed" +FX_QUANTIZED_DIR_NAME = FX_DIR_NAME / "quantized" +FX_QUANTIZED_COMPRESSED_DIR_NAME = FX_DIR_NAME / "post_quantization_compressed" + +FX_DYNAMIC_DIR = FX_DIR_NAME / "dynamic_shapes" +FX_DYNAMIC_QUANTIZED_DIR_NAME = FX_DYNAMIC_DIR / "quantized" +FX_DYNAMIC_QUANTIZED_COMPRESSED_DIR_NAME = FX_DYNAMIC_DIR / "post_quantization_compressed" @dataclass @@ -170,19 +175,23 @@ def test_model(test_case: ModelCase): ), ) - +@pytest.mark.parametrize("enable_dynamic_shapes", [True, False]) @pytest.mark.parametrize("compress_weights", [True, False]) @pytest.mark.parametrize( ("model_case", "quantization_parameters", "compress_n_qdq"), TEST_MODELS_QUANIZED, ids=[m[0].model_id for m in TEST_MODELS_QUANIZED], ) -def test_quantized_model(model_case: ModelCase, quantization_parameters, compress_weights: bool, compress_n_qdq: int): +def test_quantized_model(model_case: ModelCase, quantization_parameters, compress_weights: bool, compress_n_qdq: int, enable_dynamic_shapes: bool): model = model_case.model_builder() dtype = torch.int32 if model_case.model_id == "synthetic_transformer" else torch.float32 example_input = torch.ones(model_case.input_shape, dtype=dtype) - - fx_model = get_torch_fx_model(model, example_input) + dynamic_shapes = None + enable_dynamic_shapes = model_case.model_id == "synthetic_transformer" and enable_dynamic_shapes + if(enable_dynamic_shapes): + dynamic_shapes = [(Dim.AUTO)] + + fx_model = get_torch_fx_model(model, example_input, dynamic_shapes=dynamic_shapes) def transform_fn(data_item): return data_item.to("cpu") @@ -198,7 +207,11 @@ def transform_fn(data_item): # Uncomment to visualize torch fx graph # from tests.torch.fx.helpers import visualize_fx_model # visualize_fx_model(quantized_model, f"{model_case.model_id}_int8.svg") - save_dir = FX_QUANTIZED_COMPRESSED_DIR_NAME if compress_weights else FX_QUANTIZED_DIR_NAME + if dynamic_shapes: + save_dir = FX_DYNAMIC_QUANTIZED_COMPRESSED_DIR_NAME if compress_weights else FX_DYNAMIC_QUANTIZED_DIR_NAME + else: + save_dir = FX_QUANTIZED_COMPRESSED_DIR_NAME if compress_weights else FX_QUANTIZED_DIR_NAME + nncf_graph = GraphConverter.create_nncf_graph(quantized_model) check_graph(nncf_graph, get_dot_filename(model_case.model_id), save_dir, extended=True) q_nodes, dq_nodes = count_q_dq(quantized_model) From fbd5779a41e9a28474918a654bd64ab6b47f3984 Mon Sep 17 00:00:00 2001 From: anzr299 Date: Wed, 29 Jan 2025 17:53:11 +0400 Subject: [PATCH 2/6] fix --- tests/torch/fx/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/torch/fx/test_models.py b/tests/torch/fx/test_models.py index 17da34e7d52..6e8eb8ae457 100644 --- a/tests/torch/fx/test_models.py +++ b/tests/torch/fx/test_models.py @@ -189,7 +189,7 @@ def test_quantized_model(model_case: ModelCase, quantization_parameters, compres dynamic_shapes = None enable_dynamic_shapes = model_case.model_id == "synthetic_transformer" and enable_dynamic_shapes if(enable_dynamic_shapes): - dynamic_shapes = [(Dim.AUTO)] + dynamic_shapes = [(Dim.AUTO,)] fx_model = get_torch_fx_model(model, example_input, dynamic_shapes=dynamic_shapes) From eda05491b3e2d4569c1709a53dd3cb8e9fa0492e Mon Sep 17 00:00:00 2001 From: anzr299 Date: Wed, 29 Jan 2025 18:05:13 +0400 Subject: [PATCH 3/6] pre-commit fix --- tests/torch/fx/test_models.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/torch/fx/test_models.py b/tests/torch/fx/test_models.py index 6e8eb8ae457..7ca4bc6e822 100644 --- a/tests/torch/fx/test_models.py +++ b/tests/torch/fx/test_models.py @@ -25,6 +25,7 @@ import torch.utils.data import torch.utils.data.distributed import torchvision.models as models +from torch.export.dynamic_shapes import Dim import nncf from nncf.common.graph.graph import NNCFNodeName @@ -44,7 +45,6 @@ from tests.torch.test_models.synthetic import MultiBranchesConnectedModel from tests.torch.test_models.synthetic import ShortTransformer from tests.torch.test_models.synthetic import YOLO11N_SDPABlock -from torch.export.dynamic_shapes import Dim FX_DIR_NAME = Path("fx") FX_QUANTIZED_DIR_NAME = FX_DIR_NAME / "quantized" @@ -175,6 +175,7 @@ def test_model(test_case: ModelCase): ), ) + @pytest.mark.parametrize("enable_dynamic_shapes", [True, False]) @pytest.mark.parametrize("compress_weights", [True, False]) @pytest.mark.parametrize( @@ -182,15 +183,21 @@ def test_model(test_case: ModelCase): TEST_MODELS_QUANIZED, ids=[m[0].model_id for m in TEST_MODELS_QUANIZED], ) -def test_quantized_model(model_case: ModelCase, quantization_parameters, compress_weights: bool, compress_n_qdq: int, enable_dynamic_shapes: bool): +def test_quantized_model( + model_case: ModelCase, + quantization_parameters, + compress_weights: bool, + compress_n_qdq: int, + enable_dynamic_shapes: bool, +): model = model_case.model_builder() dtype = torch.int32 if model_case.model_id == "synthetic_transformer" else torch.float32 example_input = torch.ones(model_case.input_shape, dtype=dtype) dynamic_shapes = None enable_dynamic_shapes = model_case.model_id == "synthetic_transformer" and enable_dynamic_shapes - if(enable_dynamic_shapes): + if enable_dynamic_shapes: dynamic_shapes = [(Dim.AUTO,)] - + fx_model = get_torch_fx_model(model, example_input, dynamic_shapes=dynamic_shapes) def transform_fn(data_item): From 36aca2fe35e212f7395aeb05f803f00b01fdc66d Mon Sep 17 00:00:00 2001 From: anzr299 Date: Wed, 29 Jan 2025 19:01:05 +0400 Subject: [PATCH 4/6] Add test case to check data type on edge tensor shape --- tests/torch/fx/test_models.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/torch/fx/test_models.py b/tests/torch/fx/test_models.py index 7ca4bc6e822..f46bbfca734 100644 --- a/tests/torch/fx/test_models.py +++ b/tests/torch/fx/test_models.py @@ -227,6 +227,21 @@ def transform_fn(data_item): check_fq_values(quantized_model) check_compressed_post_quantized(quantized_model) +def test_dynamic_edge(): + model = MultiBranchesConnectedModel() + ex_inputs = torch.ones((1, 3, 3, 3)) + dynamic_shapes = [(Dim.AUTO,Dim.AUTO,Dim.AUTO,Dim.AUTO,)] + fx_model = get_torch_fx_model(model, ex_inputs, dynamic_shapes=dynamic_shapes) + nncf_graph = GraphConverter.create_nncf_graph(fx_model) + + for edge in nncf_graph.get_all_edges(): + edge_shape = edge.tensor_shape + assert isinstance(edge_shape, tuple) + for dim in edge_shape: + assert isinstance(dim, (int, str)) + assert not isinstance(dim, torch.SymInt) + + print("All edges have valid shape types (int or str).") def check_fq_values(quantized_model): for node in quantized_model.graph.nodes: From 3057cf20541b6e17b421713c78b4b2c3c93c3ac9 Mon Sep 17 00:00:00 2001 From: anzr299 Date: Wed, 29 Jan 2025 19:01:41 +0400 Subject: [PATCH 5/6] clean --- tests/torch/fx/test_models.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/torch/fx/test_models.py b/tests/torch/fx/test_models.py index f46bbfca734..bdec90477e3 100644 --- a/tests/torch/fx/test_models.py +++ b/tests/torch/fx/test_models.py @@ -241,8 +241,6 @@ def test_dynamic_edge(): assert isinstance(dim, (int, str)) assert not isinstance(dim, torch.SymInt) - print("All edges have valid shape types (int or str).") - def check_fq_values(quantized_model): for node in quantized_model.graph.nodes: if node.target not in DEQUANTIZE_NODE_TARGETS: From e25e5945c483110c22afc5cb5ff21942185ec9ad Mon Sep 17 00:00:00 2001 From: anzr299 Date: Wed, 29 Jan 2025 19:27:00 +0400 Subject: [PATCH 6/6] pre commit fix --- tests/torch/fx/test_models.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/torch/fx/test_models.py b/tests/torch/fx/test_models.py index bdec90477e3..4b967e2a436 100644 --- a/tests/torch/fx/test_models.py +++ b/tests/torch/fx/test_models.py @@ -227,10 +227,18 @@ def transform_fn(data_item): check_fq_values(quantized_model) check_compressed_post_quantized(quantized_model) + def test_dynamic_edge(): model = MultiBranchesConnectedModel() ex_inputs = torch.ones((1, 3, 3, 3)) - dynamic_shapes = [(Dim.AUTO,Dim.AUTO,Dim.AUTO,Dim.AUTO,)] + dynamic_shapes = [ + ( + Dim.AUTO, + Dim.AUTO, + Dim.AUTO, + Dim.AUTO, + ) + ] fx_model = get_torch_fx_model(model, ex_inputs, dynamic_shapes=dynamic_shapes) nncf_graph = GraphConverter.create_nncf_graph(fx_model) @@ -241,6 +249,7 @@ def test_dynamic_edge(): assert isinstance(dim, (int, str)) assert not isinstance(dim, torch.SymInt) + def check_fq_values(quantized_model): for node in quantized_model.graph.nodes: if node.target not in DEQUANTIZE_NODE_TARGETS: