diff --git a/nncf/experimental/torch/fx/nncf_graph_builder.py b/nncf/experimental/torch/fx/nncf_graph_builder.py index 36dc3b46e6e..3d162f27400 100644 --- a/nncf/experimental/torch/fx/nncf_graph_builder.py +++ b/nncf/experimental/torch/fx/nncf_graph_builder.py @@ -197,6 +197,9 @@ def get_edge_params( tensor = source_node.meta["val"] if isinstance(tensor, torch.Tensor): tensor_shape = tuple(tensor.shape) + tensor_shape = tuple(str(i) if isinstance(i, torch.SymInt) else i for i in tensor_shape) + if isinstance(tensor, torch.SymInt): + tensor_shape = (str(tensor),) if tensor_shape is None: # TODO(dlyakhov): Refactor algorithms to always have knowns edges shapes. diff --git a/tests/torch/data/reference_graphs/fx/dynamic_shapes/post_quantization_compressed/synthetic_transformer.dot b/tests/torch/data/reference_graphs/fx/dynamic_shapes/post_quantization_compressed/synthetic_transformer.dot new file mode 100644 index 00000000000..f9d6b83d2ed --- /dev/null +++ b/tests/torch/data/reference_graphs/fx/dynamic_shapes/post_quantization_compressed/synthetic_transformer.dot @@ -0,0 +1,49 @@ +strict digraph { +"0 wte_weight" [id=0, type=get_attr]; +"1 linear_bias" [id=1, type=get_attr]; +"2 lm_head_bias" [id=2, type=get_attr]; +"3 input_ids" [id=3, type=input]; +"4 embedding" [id=4, type=embedding]; +"5 embedding_0_0_nncf_smooth_quant_0" [id=5, type=call_module]; +"6 quantize_per_tensor_default" [id=6, type=quantize_per_tensor]; +"7 dequantize_per_tensor_default" [id=7, type=dequantize_per_tensor]; +"8 scale_updated_constant0" [id=8, type=get_attr]; +"9 compressed_weight_updated_constant0" [id=9, type=get_attr]; +"10 mul_tensor" [id=10, type=mul]; +"11 zero_point_updated_constant0" [id=11, type=get_attr]; +"12 sub_tensor" [id=12, type=sub]; +"13 linear" [id=13, type=linear]; +"14 linear_0_0_nncf_smooth_quant_0" [id=14, type=call_module]; +"15 quantize_per_tensor_default_1" [id=15, type=quantize_per_tensor]; +"16 dequantize_per_tensor_default_1" [id=16, type=dequantize_per_tensor]; +"17 scale_updated_constant1" [id=17, type=get_attr]; +"18 compressed_weight_updated_constant1" [id=18, type=get_attr]; +"19 mul_tensor_1" [id=19, type=mul]; +"20 zero_point_updated_constant1" [id=20, type=get_attr]; +"21 sub_tensor_1" [id=21, type=sub]; +"22 linear_1" [id=22, type=linear]; +"23 output" [id=23, type=output]; +"0 wte_weight" -> "4 embedding" [label="(10, 5)", style=solid]; +"1 linear_bias" -> "13 linear" [label="(5,)", style=solid]; +"2 lm_head_bias" -> "22 linear_1" [label="(10,)", style=solid]; +"3 input_ids" -> "4 embedding" [label="('s0',)", style=solid]; +"4 embedding" -> "5 embedding_0_0_nncf_smooth_quant_0" [label="('s0', 5)", style=solid]; +"5 embedding_0_0_nncf_smooth_quant_0" -> "6 quantize_per_tensor_default" [label="('s0', 5)", style=solid]; +"6 quantize_per_tensor_default" -> "7 dequantize_per_tensor_default" [label="('s0', 5)", style=solid]; +"7 dequantize_per_tensor_default" -> "13 linear" [label="('s0', 5)", style=solid]; +"8 scale_updated_constant0" -> "10 mul_tensor" [label="(5, 1)", style=solid]; +"9 compressed_weight_updated_constant0" -> "10 mul_tensor" [label="(5, 5)", style=solid]; +"10 mul_tensor" -> "12 sub_tensor" [label="(5, 5)", style=solid]; +"11 zero_point_updated_constant0" -> "12 sub_tensor" [label="(5, 1)", style=solid]; +"12 sub_tensor" -> "13 linear" [label="(5, 5)", style=solid]; +"13 linear" -> "14 linear_0_0_nncf_smooth_quant_0" [label="('s0', 5)", style=solid]; +"14 linear_0_0_nncf_smooth_quant_0" -> "15 quantize_per_tensor_default_1" [label="('s0', 5)", style=solid]; +"15 quantize_per_tensor_default_1" -> "16 dequantize_per_tensor_default_1" [label="('s0', 5)", style=solid]; +"16 dequantize_per_tensor_default_1" -> "22 linear_1" [label="('s0', 5)", style=solid]; +"17 scale_updated_constant1" -> "19 mul_tensor_1" [label="(10, 1)", style=solid]; +"18 compressed_weight_updated_constant1" -> "19 mul_tensor_1" [label="(10, 5)", style=solid]; +"19 mul_tensor_1" -> "21 sub_tensor_1" [label="(10, 5)", style=solid]; +"20 zero_point_updated_constant1" -> "21 sub_tensor_1" [label="(10, 1)", style=solid]; +"21 sub_tensor_1" -> "22 linear_1" [label="(10, 5)", style=solid]; +"22 linear_1" -> "23 output" [label="('s0', 10)", style=solid]; +} diff --git a/tests/torch/data/reference_graphs/fx/dynamic_shapes/quantized/synthetic_transformer.dot b/tests/torch/data/reference_graphs/fx/dynamic_shapes/quantized/synthetic_transformer.dot new file mode 100644 index 00000000000..0a0ad795c9a --- /dev/null +++ b/tests/torch/data/reference_graphs/fx/dynamic_shapes/quantized/synthetic_transformer.dot @@ -0,0 +1,53 @@ +strict digraph { +"0 wte_weight" [id=0, type=get_attr]; +"1 linear_bias" [id=1, type=get_attr]; +"2 lm_head_bias" [id=2, type=get_attr]; +"3 input_ids" [id=3, type=input]; +"4 embedding" [id=4, type=embedding]; +"5 embedding_0_0_nncf_smooth_quant_0" [id=5, type=call_module]; +"6 quantize_per_tensor_default" [id=6, type=quantize_per_tensor]; +"7 dequantize_per_tensor_default" [id=7, type=dequantize_per_tensor]; +"8 linear_scale_0" [id=8, type=get_attr]; +"9 linear_zero_point_0" [id=9, type=get_attr]; +"10 compressed_weight_updated_constant0" [id=10, type=get_attr]; +"11 quantize_per_channel_default" [id=11, type=quantize_per_channel]; +"12 dequantize_per_channel_default" [id=12, type=dequantize_per_channel]; +"13 linear" [id=13, type=linear]; +"14 linear_0_0_nncf_smooth_quant_0" [id=14, type=call_module]; +"15 quantize_per_tensor_default_1" [id=15, type=quantize_per_tensor]; +"16 dequantize_per_tensor_default_1" [id=16, type=dequantize_per_tensor]; +"17 linear_1_scale_0" [id=17, type=get_attr]; +"18 linear_1_zero_point_0" [id=18, type=get_attr]; +"19 compressed_weight_updated_constant1" [id=19, type=get_attr]; +"20 quantize_per_channel_default_1" [id=20, type=quantize_per_channel]; +"21 dequantize_per_channel_default_1" [id=21, type=dequantize_per_channel]; +"22 linear_1" [id=22, type=linear]; +"23 output" [id=23, type=output]; +"0 wte_weight" -> "4 embedding" [label="(10, 5)", style=solid]; +"1 linear_bias" -> "13 linear" [label="(5,)", style=solid]; +"2 lm_head_bias" -> "22 linear_1" [label="(10,)", style=solid]; +"3 input_ids" -> "4 embedding" [label="('s0',)", style=solid]; +"4 embedding" -> "5 embedding_0_0_nncf_smooth_quant_0" [label="('s0', 5)", style=solid]; +"5 embedding_0_0_nncf_smooth_quant_0" -> "6 quantize_per_tensor_default" [label="('s0', 5)", style=solid]; +"6 quantize_per_tensor_default" -> "7 dequantize_per_tensor_default" [label="('s0', 5)", style=solid]; +"7 dequantize_per_tensor_default" -> "13 linear" [label="('s0', 5)", style=solid]; +"8 linear_scale_0" -> "11 quantize_per_channel_default" [label="(5,)", style=solid]; +"8 linear_scale_0" -> "12 dequantize_per_channel_default" [label="(5,)", style=solid]; +"9 linear_zero_point_0" -> "11 quantize_per_channel_default" [label="(5,)", style=solid]; +"9 linear_zero_point_0" -> "12 dequantize_per_channel_default" [label="(5,)", style=solid]; +"10 compressed_weight_updated_constant0" -> "11 quantize_per_channel_default" [label="(5, 5)", style=solid]; +"11 quantize_per_channel_default" -> "12 dequantize_per_channel_default" [label="(5, 5)", style=solid]; +"12 dequantize_per_channel_default" -> "13 linear" [label="(5, 5)", style=solid]; +"13 linear" -> "14 linear_0_0_nncf_smooth_quant_0" [label="('s0', 5)", style=solid]; +"14 linear_0_0_nncf_smooth_quant_0" -> "15 quantize_per_tensor_default_1" [label="('s0', 5)", style=solid]; +"15 quantize_per_tensor_default_1" -> "16 dequantize_per_tensor_default_1" [label="('s0', 5)", style=solid]; +"16 dequantize_per_tensor_default_1" -> "22 linear_1" [label="('s0', 5)", style=solid]; +"17 linear_1_scale_0" -> "20 quantize_per_channel_default_1" [label="(10,)", style=solid]; +"17 linear_1_scale_0" -> "21 dequantize_per_channel_default_1" [label="(10,)", style=solid]; +"18 linear_1_zero_point_0" -> "20 quantize_per_channel_default_1" [label="(10,)", style=solid]; +"18 linear_1_zero_point_0" -> "21 dequantize_per_channel_default_1" [label="(10,)", style=solid]; +"19 compressed_weight_updated_constant1" -> "20 quantize_per_channel_default_1" [label="(10, 5)", style=solid]; +"20 quantize_per_channel_default_1" -> "21 dequantize_per_channel_default_1" [label="(10, 5)", style=solid]; +"21 dequantize_per_channel_default_1" -> "22 linear_1" [label="(10, 5)", style=solid]; +"22 linear_1" -> "23 output" [label="('s0', 10)", style=solid]; +} diff --git a/tests/torch/fx/helpers.py b/tests/torch/fx/helpers.py index b51f6da16f4..9e599b38bdb 100644 --- a/tests/torch/fx/helpers.py +++ b/tests/torch/fx/helpers.py @@ -124,7 +124,7 @@ def visualize_fx_model(model: torch.fx.GraphModule, output_svg_path: str): def get_torch_fx_model( - model: torch.nn.Module, ex_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]] + model: torch.nn.Module, ex_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]], dynamic_shapes=None ) -> torch.fx.GraphModule: """ Converts given module to GraphModule. @@ -151,7 +151,7 @@ def get_torch_fx_model( model.eval() with torch.no_grad(): with disable_patching(): - return torch.export.export_for_training(model, args=device_ex_input).module() + return torch.export.export_for_training(model, args=device_ex_input, dynamic_shapes=dynamic_shapes).module() def get_torch_fx_model_q_transformed(model: torch.nn.Module, ex_input: torch.Tensor) -> torch.fx.GraphModule: diff --git a/tests/torch/fx/test_models.py b/tests/torch/fx/test_models.py index 39538f8b7d9..4b967e2a436 100644 --- a/tests/torch/fx/test_models.py +++ b/tests/torch/fx/test_models.py @@ -25,6 +25,7 @@ import torch.utils.data import torch.utils.data.distributed import torchvision.models as models +from torch.export.dynamic_shapes import Dim import nncf from nncf.common.graph.graph import NNCFNodeName @@ -46,8 +47,12 @@ from tests.torch.test_models.synthetic import YOLO11N_SDPABlock FX_DIR_NAME = Path("fx") -FX_QUANTIZED_DIR_NAME = Path("fx") / "quantized" -FX_QUANTIZED_COMPRESSED_DIR_NAME = Path("fx") / "post_quantization_compressed" +FX_QUANTIZED_DIR_NAME = FX_DIR_NAME / "quantized" +FX_QUANTIZED_COMPRESSED_DIR_NAME = FX_DIR_NAME / "post_quantization_compressed" + +FX_DYNAMIC_DIR = FX_DIR_NAME / "dynamic_shapes" +FX_DYNAMIC_QUANTIZED_DIR_NAME = FX_DYNAMIC_DIR / "quantized" +FX_DYNAMIC_QUANTIZED_COMPRESSED_DIR_NAME = FX_DYNAMIC_DIR / "post_quantization_compressed" @dataclass @@ -171,18 +176,29 @@ def test_model(test_case: ModelCase): ) +@pytest.mark.parametrize("enable_dynamic_shapes", [True, False]) @pytest.mark.parametrize("compress_weights", [True, False]) @pytest.mark.parametrize( ("model_case", "quantization_parameters", "compress_n_qdq"), TEST_MODELS_QUANIZED, ids=[m[0].model_id for m in TEST_MODELS_QUANIZED], ) -def test_quantized_model(model_case: ModelCase, quantization_parameters, compress_weights: bool, compress_n_qdq: int): +def test_quantized_model( + model_case: ModelCase, + quantization_parameters, + compress_weights: bool, + compress_n_qdq: int, + enable_dynamic_shapes: bool, +): model = model_case.model_builder() dtype = torch.int32 if model_case.model_id == "synthetic_transformer" else torch.float32 example_input = torch.ones(model_case.input_shape, dtype=dtype) + dynamic_shapes = None + enable_dynamic_shapes = model_case.model_id == "synthetic_transformer" and enable_dynamic_shapes + if enable_dynamic_shapes: + dynamic_shapes = [(Dim.AUTO,)] - fx_model = get_torch_fx_model(model, example_input) + fx_model = get_torch_fx_model(model, example_input, dynamic_shapes=dynamic_shapes) def transform_fn(data_item): return data_item.to("cpu") @@ -198,7 +214,11 @@ def transform_fn(data_item): # Uncomment to visualize torch fx graph # from tests.torch.fx.helpers import visualize_fx_model # visualize_fx_model(quantized_model, f"{model_case.model_id}_int8.svg") - save_dir = FX_QUANTIZED_COMPRESSED_DIR_NAME if compress_weights else FX_QUANTIZED_DIR_NAME + if dynamic_shapes: + save_dir = FX_DYNAMIC_QUANTIZED_COMPRESSED_DIR_NAME if compress_weights else FX_DYNAMIC_QUANTIZED_DIR_NAME + else: + save_dir = FX_QUANTIZED_COMPRESSED_DIR_NAME if compress_weights else FX_QUANTIZED_DIR_NAME + nncf_graph = GraphConverter.create_nncf_graph(quantized_model) check_graph(nncf_graph, get_dot_filename(model_case.model_id), save_dir, extended=True) q_nodes, dq_nodes = count_q_dq(quantized_model) @@ -208,6 +228,28 @@ def transform_fn(data_item): check_compressed_post_quantized(quantized_model) +def test_dynamic_edge(): + model = MultiBranchesConnectedModel() + ex_inputs = torch.ones((1, 3, 3, 3)) + dynamic_shapes = [ + ( + Dim.AUTO, + Dim.AUTO, + Dim.AUTO, + Dim.AUTO, + ) + ] + fx_model = get_torch_fx_model(model, ex_inputs, dynamic_shapes=dynamic_shapes) + nncf_graph = GraphConverter.create_nncf_graph(fx_model) + + for edge in nncf_graph.get_all_edges(): + edge_shape = edge.tensor_shape + assert isinstance(edge_shape, tuple) + for dim in edge_shape: + assert isinstance(dim, (int, str)) + assert not isinstance(dim, torch.SymInt) + + def check_fq_values(quantized_model): for node in quantized_model.graph.nodes: if node.target not in DEQUANTIZE_NODE_TARGETS: