Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update TF QAT docs. Deprecate TF create_compressed_model method #3217

Merged
merged 19 commits into from
Feb 5, 2025
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ def transform_fn(data_item):
calibration_dataset = nncf.Dataset(val_dataset, transform_fn)
# Step 3: Run the quantization pipeline
quantized_model = nncf.quantize(model, calibration_dataset)
# Step 4: Remove auxiliary layers and operations added during the quantization process,
# resulting in a clean, fully quantized model ready for deployment.
stripped_model = nncf.strip(quantized_model)
```

</details>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Use NNCF for Quantization Aware Training in PyTorch
# Use NNCF for Quantization Aware Training

This is a step-by-step tutorial on how to integrate the NNCF package into the existing PyTorch project (please see the [TensorFlow quantization documentation](../other_algorithms/LegacyQuantization.md) for integration tutorial for the existing TensorFlow project).
The use case implies that the user already has a training pipeline that reproduces training of the model in the floating point precision and pretrained model.
This is a step-by-step tutorial on how to integrate the NNCF package into the existing PyTorch or TensorFlow projects.
The use case implies that the user already has a training pipeline that reproduces training of the model in the floating point precision and pretrained model.
The task is to prepare this model for accelerated inference by simulating the compression at train time.
Please refer to this [document](/docs/usage/training_time_compression/other_algorithms/LegacyQuantization.md) for details of the implementation.

Expand All @@ -11,11 +11,24 @@ Please refer to this [document](/docs/usage/training_time_compression/other_algo

Quantize the model using the [Post Training Quantization](../../post_training_compression/post_training_quantization/Usage.md) method.

<details open><summary><b>PyTorch</b></summary>

```python
model = TorchModel() # instance of torch.nn.Module
quantized_model = nncf.quantize(model, ...)
```

</details>

<details><summary><b>TensorFlow</b></summary>

```python
model = TensorFlowModel() # instance of tf.keras.Model
quantized_model = nncf.quantize(model, ...)
```

</details>

### Step 2: Run the training pipeline

At this point, the NNCF is fully integrated into your training pipeline.
Expand All @@ -27,27 +40,46 @@ Important points you should consider when training your networks with compressio

### Step 3: Export the compressed model

After the compressed model has been fine-tuned to acceptable accuracy and compression stages, you can export it. There are two ways to export a model:
After the compressed model has been fine-tuned to acceptable accuracy and compression stages, you can export it.

<details open><summary><b>PyTorch</b></summary>

Trace the model via inference in framework operations.

1. Trace the model via inference in framework operations.
```python
# To OpenVINO format
import openvino as ov
ov_quantized_model = ov.convert_model(quantized_model.cpu(), example_input=dummy_input)
```

</details>

<details><summary><b>TensorFlow</b></summary>

```python
# To OpenVINO format
import openvino as ov

# Removes auxiliary layers and operations added during the quantization process,
# resulting in a clean, fully quantized model ready for deployment.
stripped_model = nncf.strip(quantized_model)

ov_quantized_model = ov.convert_model(stripped_model, share_weights=False)
alexsu52 marked this conversation as resolved.
Show resolved Hide resolved
```

```python
# To OpenVINO format
import openvino as ov
ov_quantized_model = ov.convert_model(quantized_model.cpu(), example_input=dummy_input)
```
</details>

## Saving and loading compressed models
alexsu52 marked this conversation as resolved.
Show resolved Hide resolved

<details open><summary><b>PyTorch</b></summary>

The complete information about compression is defined by a compressed model and a NNCF config.
The model characterizes the weights and topology of the network. The NNCF config - how to restore additional modules intoduced by NNCF.
The NNCF config can be obtained by `quantized_model.nncf.get_config()` on saving and passed to the
`nncf.torch.load_from_config` helper function to load additional modules from the given NNCF config.
The quantized model saving allows to load quantized modules to the target model in a new python process and
requires only example input for the target module, corresponding NNCF config and the quantized model state dict.

### Saving and loading compressed models in PyTorch

```python
# save part
quantized_model = nncf.quantize(model, calibration_dataset)
Expand All @@ -70,10 +102,52 @@ quantized_model.load_state_dict(state_dict)

You can save the `compressed_model` object `torch.save` as usual: via `state_dict` and `load_state_dict` methods.

</details>

<details><summary><b>TensorFlow</b></summary>

To save a model checkpoint, use the following API:

```python
from nncf.tensorflow import get_config
from nncf.tensorflow.callbacks.checkpoint_callback import CheckpointManagerCallback

config = get_config(quantized_model)
checkpoint = tf.train.Checkpoint(model=quantized_model,
config=config,
... # the rest of the user-defined objects to save
)
alexsu52 marked this conversation as resolved.
Show resolved Hide resolved
callbacks = []
callbacks.append(CheckpointManagerCallback(checkpoint, path_to_checkpoint))
...
quantized_model.fit(..., callbacks=callbacks)
```

To restore the model from checkpoint, use the following API:

```python
from nncf.tensorflow import ModelConfig
from nncf.tensorflow import load_from_config

checkpoint = tf.train.Checkpoint(config=ModelConfig())
checkpoint.restore(path_to_checkpoint)

quantized_model = load_from_config(model, checkpoint.config)
alexsu52 marked this conversation as resolved.
Show resolved Hide resolved

checkpoint = tf.train.Checkpoint(model=quantized_model
... # the rest of the user-defined objects to load
)
checkpoint.restore(path_to_checkpoint)
```

</details>

## Advanced usage

### Compression of custom modules

<details open><summary><b>PyTorch</b></summary>

With no target model code modifications, NNCF only supports native PyTorch modules with respect to trainable parameter (weight) compressed, such as `torch.nn.Conv2d`.
If your model contains a custom, non-PyTorch standard module with trainable weights that should be compressed, you can register it using the `@nncf.register_module` decorator:

Expand All @@ -91,4 +165,9 @@ If registered module should be ignored by specific algorithms use `ignored_algor

In the example above, the NNCF-compressed models that contain instances of `MyModule` will have the corresponding modules extended with functionality that will allow NNCF to quantize the `weight` parameter of `MyModule` before it takes part in `MyModule`'s `forward` calculation.

See a PyTorch [example](/examples/quantization_aware_training/torch/resnet18/README.md) for **Quantization** Compression scenario on Tiny ImageNet-200 dataset.
</details>

## Examples

- See a PyTorch [example](/examples/quantization_aware_training/torch/resnet18/README.md) for **Quantization** Compression scenario on Tiny ImageNet-200 dataset.
- See a TensorFlow [example](/examples/quantization_aware_training/tensorflow/mobilenet_v2/README.md) for **Quantization** Compression scenario on imagenette/320px-v2 dataset.
3 changes: 3 additions & 0 deletions nncf/tensorflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,13 @@
)
from nncf.tensorflow.helpers import create_compressed_model as create_compressed_model
from nncf.tensorflow.helpers.callback_creation import create_compression_callbacks as create_compression_callbacks
from nncf.tensorflow.helpers.model_creation import load_from_config
from nncf.tensorflow.initialization import register_default_init_args as register_default_init_args
from nncf.tensorflow.pruning.filter_pruning import algorithm as filter_pruning_algorithm

# Required for correct COMPRESSION_ALGORITHMS registry functioning
from nncf.tensorflow.quantization import algorithm as quantization_algorithm
from nncf.tensorflow.sparsity.magnitude import algorithm as magnitude_sparsity_algorithm
from nncf.tensorflow.sparsity.rb import algorithm as rb_sparsity_algorithm
from nncf.tensorflow.utils.state import ModelConfig
from nncf.tensorflow.utils.state import get_config
35 changes: 35 additions & 0 deletions nncf/tensorflow/helpers/model_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from nncf import NNCFConfig
from nncf.api.compression import CompressionAlgorithmController
from nncf.common.compression import BaseCompressionAlgorithmController as BaseController
from nncf.common.deprecation import warning_deprecated
from nncf.common.utils.api_marker import api
from nncf.config.extractors import extract_algorithm_names
from nncf.config.telemetry_extractors import CompressionStartedFromConfig
Expand All @@ -29,8 +30,12 @@
from nncf.tensorflow.algorithm_selector import get_compression_algorithm_builder
from nncf.tensorflow.api.composite_compression import TFCompositeCompressionAlgorithmBuilder
from nncf.tensorflow.api.compression import TFCompressionAlgorithmBuilder
from nncf.tensorflow.graph.model_transformer import TFModelTransformer
from nncf.tensorflow.graph.transformations.layout import TFTransformationLayout
from nncf.tensorflow.graph.utils import is_keras_layer_model
from nncf.tensorflow.helpers.utils import get_built_model
from nncf.tensorflow.quantization.algorithm import QuantizationBuilder
from nncf.tensorflow.utils.state import ModelConfig


def create_compression_algorithm_builder(config: NNCFConfig, should_init: bool) -> TFCompressionAlgorithmBuilder:
Expand Down Expand Up @@ -80,6 +85,19 @@ def create_compressed_model(
:return: A tuple of the compression controller for the requested algorithm(s) and the model object with additional
modifications necessary to enable algorithm-specific compression during fine-tuning.
"""

warning_deprecated(
"The 'nncf.tensorflow.create_compressed_model' function is deprecated and will be removed in a "
"future release.\n"
"To perform post training quantization (PTQ) or quantization aware training (QAT),"
" use the nncf.quantize() API:\n"
" - https://github.com/openvinotoolkit/nncf?tab=readme-ov-file#post-training-quantization\n"
" - https://github.com/openvinotoolkit/nncf?tab=readme-ov-file#training-time-quantization\n"
"Examples:\n"
" - https://github.com/openvinotoolkit/nncf/tree/develop/examples/post_training_quantization/tensorflow\n"
" - https://github.com/openvinotoolkit/nncf/tree/develop/examples/quantization_aware_training/tensorflow"
)

if is_experimental_quantization(config):
if is_keras_layer_model(model):
raise ValueError(
Expand Down Expand Up @@ -126,3 +144,20 @@ def get_input_signature(config: NNCFConfig):
input_signature.append(tf.TensorSpec(shape=shape, dtype=tf.float32))

return input_signature if len(input_signature) > 1 else input_signature[0]


def load_from_config(model: tf.keras.Model, config: ModelConfig) -> tf.keras.Model:
"""
TODO(TF)

:param model:
:parem config:
:return:
"""
transformation_layout = TFTransformationLayout()
# pylint: disable=protected-access
insertion_commands, _ = QuantizationBuilder._build_insertion_commands_for_quantizer_setup(config.quantizer_setup)
for command in insertion_commands:
transformation_layout.register(command)
model_transformer = TFModelTransformer(model)
return model_transformer.transform(transformation_layout)
23 changes: 14 additions & 9 deletions nncf/tensorflow/quantization/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,14 +346,17 @@ def _get_half_range(
return True
return False

def _create_quantizer(self, name: str, qspec: TFQuantizerSpec) -> Quantizer:
@staticmethod
def _create_quantizer(name: str, qspec: TFQuantizerSpec) -> Quantizer:
quantizer_cls = NNCF_QUANTIZATION_OPERATIONS.get(qspec.mode)
return quantizer_cls(name, qspec)

@staticmethod
def _build_insertion_commands_for_quantizer_setup(
alexsu52 marked this conversation as resolved.
Show resolved Hide resolved
self, quantizer_setup: TFQuantizationSetup
) -> List[TFInsertionCommand]:
quantizer_setup: TFQuantizationSetup,
) -> Tuple[List[TFInsertionCommand], List[str]]:
insertion_commands = []
op_names = []
quantization_points = quantizer_setup.get_quantization_points()
non_unified_scales_quantization_point_ids = set(range(len(quantization_points)))

Expand All @@ -365,7 +368,7 @@ def _build_insertion_commands_for_quantizer_setup(
quantizer_spec = qp.quantizer_spec
op_name = qp.op_name + "/unified_scale_group"
quantizer = FakeQuantize(quantizer_spec, name=op_name)
self._op_names.append(quantizer.op_name)
op_names.append(quantizer.op_name)
target_points = []
for us_qp_id in unified_scales_group:
non_unified_scales_quantization_point_ids.discard(us_qp_id)
Expand All @@ -387,24 +390,26 @@ def _build_insertion_commands_for_quantizer_setup(
quantizer_spec = quantization_point.quantizer_spec
target_point = quantization_point.target_point
if quantization_point.is_weight_quantization():
quantizer = self._create_quantizer(op_name, quantizer_spec)
self._op_names.append(op_name)
quantizer = QuantizationBuilder._create_quantizer(op_name, quantizer_spec)
op_names.append(op_name)
else:
quantizer = FakeQuantize(quantizer_spec, name=op_name)
self._op_names.append(quantizer.op_name)
op_names.append(quantizer.op_name)
command = TFInsertionCommand(
target_point=target_point,
callable_object=quantizer,
priority=TransformationPriority.QUANTIZATION_PRIORITY,
)
insertion_commands.append(command)
return insertion_commands
return insertion_commands, op_names

def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLayout:
transformations = TFTransformationLayout()
if self._quantizer_setup is None:
self._quantizer_setup = self._get_quantizer_setup(model)
insertion_commands = self._build_insertion_commands_for_quantizer_setup(self._quantizer_setup)
insertion_commands, self._op_names = QuantizationBuilder._build_insertion_commands_for_quantizer_setup(
self._quantizer_setup
)
for command in insertion_commands:
transformations.register(command)
return transformations
Expand Down
5 changes: 4 additions & 1 deletion nncf/tensorflow/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,9 @@ def quantize_impl(
]
)

_, compressed_model = create_compressed_model(model=model, config=nncf_config)
compression_ctrl, compressed_model = create_compressed_model(model=model, config=nncf_config)
alexsu52 marked this conversation as resolved.
Show resolved Hide resolved

nncf_model_config = compression_ctrl.get_compression_state()["builder_state"]["quantization"]
setattr(compressed_model, "_nncf_model_config", nncf_model_config)
alexsu52 marked this conversation as resolved.
Show resolved Hide resolved

return compressed_model
44 changes: 43 additions & 1 deletion nncf/tensorflow/utils/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
# limitations under the License.

import json
from typing import Any, Dict
from typing import Any, Dict, Optional

import tensorflow as tf

from nncf.common.compression import BaseCompressionAlgorithmController
from nncf.tensorflow.quantization.algorithm import TFQuantizationSetup

# TODO(achurkin): remove pylint ignore after 120296 ticked is fixed

Expand Down Expand Up @@ -86,3 +87,44 @@ def deserialize(self, string_value: str) -> None:
:param string_value: A serialized compression state.
"""
self._state = json.loads(string_value)


class ModelConfig(tf.train.experimental.PythonState):
"""
TODO(TF)
"""

def __init__(self, quantizer_setup: Optional[TFQuantizationSetup] = None):
""" """
self.quantizer_setup = quantizer_setup

def serialize(self) -> str:
"""
Callback to serialize the model config.

:return: A serialized model config.
"""
data = {"quantizer_setup": self.quantizer_setup.get_state()}
return json.dumps(data)

def deserialize(self, string_value: str) -> None:
"""
Callback to deserialize the model config.

:param string_value: A serialized model config.
"""
data = json.loads(string_value)
self.quantizer_setup = TFQuantizationSetup.from_state(data["quantizer_setup"])


def get_config(model: tf.keras.Model) -> ModelConfig:
"""
TODO(TF)

:param model:
:return:
"""
data = getattr(model, "_nncf_model_config")
delattr(model, "_nncf_model_config")
quantizer_setup = TFQuantizationSetup.from_state(data["quantizer_setup"])
return ModelConfig(quantizer_setup)
2 changes: 1 addition & 1 deletion nncf/torch/model_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def create_compressed_model(
warning_deprecated(
"The 'nncf.torch.create_compressed_model' function is deprecated and will be removed in a future release.\n"
"To perform post training quantization (PTQ) or quantization aware training (QAT),"
" use the new nncf.quantize() API:\n"
" use the nncf.quantize() API:\n"
" - https://github.com/openvinotoolkit/nncf?tab=readme-ov-file#post-training-quantization\n"
" - https://github.com/openvinotoolkit/nncf?tab=readme-ov-file#training-time-quantization\n"
"Examples:\n"
Expand Down