Add NNCF pruning and distillation in documentation (#197)
echarlaix authored Mar 2, 2023
1 parent 21c24c5 commit 5da170f
Showing 1 changed file with 79 additions and 2 deletions.
81 changes: 79 additions & 2 deletions docs/source/optimization_ov.mdx
@@ -16,7 +16,7 @@ limitations under the License.

# Optimization

- 🤗 Optimum Intel provides an `optimum.openvino` package that enables you to apply a variety of model compression methods such as quantization on many models hosted on the 🤗 hub using the [NNCF](https://docs.openvino.ai/2022.1/docs_nncf_introduction.html) framework.
+ 🤗 Optimum Intel provides an `optimum.openvino` package that enables you to apply a variety of model compression methods, such as quantization and pruning, to many models hosted on the 🤗 hub using the [NNCF](https://docs.openvino.ai/2022.1/docs_nncf_introduction.html) framework.


## Post-training optimization
@@ -65,7 +65,7 @@ The `quantize()` method applies post-training static quantization and exports the

## Training-time optimization

- Apart from optimizing a model after training like post-training quantization above, `optimum.openvino` also provides optimization methods during training, namely Quantization-Aware Training (QAT).
+ Apart from optimizing a model after training, such as the post-training quantization above, `optimum.openvino` also provides optimization methods during training, namely Quantization-Aware Training (QAT) and Joint Pruning, Quantization and Distillation (JPQD).


### Quantization-Aware Training (QAT)
@@ -118,6 +118,83 @@ metrics = trainer.evaluate()
trainer.save_model()
```


### Joint Pruning, Quantization and Distillation (JPQD)

Beyond quantization, compression methods like pruning and distillation are commonly used to further improve task performance and efficiency. Structured pruning slims a model to lower its computational demands, while distillation leverages the knowledge of a teacher, usually a larger model, to improve model prediction. Combining these methods with quantization can produce an optimized model with significant efficiency gains while retaining good task accuracy. In `optimum.openvino`, `OVTrainer` provides the capability to jointly prune, quantize and distill a model during training. Following is an example of how to perform the optimization on BERT-base for the sst-2 task.

First, we create a config dictionary to specify the target algorithms. Since `optimum.openvino` relies on NNCF as the backend, the config format follows the NNCF specifications (see [here](https://github.com/openvinotoolkit/nncf/tree/develop/docs/compression_algorithms)). In the example config below, we specify pruning and quantization in a list of compression algorithms with their hyperparameters. The pruning method closely resembles the work of [Lagunas et al., 2021, Block Pruning For Faster Transformers](https://arxiv.org/pdf/2109.04838.pdf), whereas the quantization refers to QAT. With this configuration, the model under optimization will be initialized with pruning and quantization operators at the beginning of the training.

```python
compression_config = [
    {
        "algorithm": "movement_sparsity",
        "params": {
            "warmup_start_epoch": 1,
            "warmup_end_epoch": 4,
            "importance_regularization_factor": 0.01,
            "enable_structured_masking": True,
        },
        "sparse_structure_by_scopes": [
            {"mode": "block", "sparse_factors": [32, 32], "target_scopes": "{re}.*BertAttention.*"},
            {"mode": "per_dim", "axis": 0, "target_scopes": "{re}.*BertIntermediate.*"},
            {"mode": "per_dim", "axis": 1, "target_scopes": "{re}.*BertOutput.*"},
        ],
        "ignored_scopes": ["{re}.*NNCFEmbedding", "{re}.*pooler.*", "{re}.*LayerNorm.*"],
    },
    {
        "algorithm": "quantization",
        "weights": {"mode": "symmetric"},
        "activations": {"mode": "symmetric"},
    },
]
```

> Known limitation: Structured pruning with movement sparsity currently only supports the *BERT, Wav2vec2 and Swin* families of models. See [here](https://github.com/openvinotoolkit/nncf/blob/develop/nncf/experimental/torch/sparsity/movement/MovementSparsity.md) for more information.

Once the config is ready, we can develop the training pipeline as in the snippet below. Since we are customizing joint compression with the config above, notice that `OVConfig` is initialized with the config dictionary (parsing the JSON file into a Python dictionary is skipped for brevity). As for distillation, a teacher model needs to be loaded, which works just like normal model loading with the transformers API. `OVTrainingArguments` extends transformers' `TrainingArguments` with distillation hyperparameters, i.e. the distillation weight and temperature, for ease of use. The snippet below shows how we load a teacher model and create the training arguments with `OVTrainingArguments`. Subsequently, the teacher model, the instantiated `OVConfig` and the `OVTrainingArguments` are fed to `OVTrainer`. Voilà! That is all we need; the rest of the pipeline is identical to native transformers training.

```diff
-from optimum.intel.openvino.trainer import OVConfig, OVTrainer
+from optimum.intel.openvino import OVConfig, OVTrainer, OVTrainingArguments

# Load teacher model
+teacher_model = AutoModelForSequenceClassification.from_pretrained(teacher_model_or_path)

-ov_config = OVConfig()
+ov_config = OVConfig(compression=compression_config)

trainer = OVTrainer(
    model=model,
+   teacher_model=teacher_model,
-   args=TrainingArguments(save_dir, num_train_epochs=1.0, do_train=True, do_eval=True),
+   args=OVTrainingArguments(save_dir, num_train_epochs=1.0, do_train=True, do_eval=True, distillation_temperature=3, distillation_weight=0.9),
    train_dataset=dataset["train"].select(range(300)),
    eval_dataset=dataset["validation"],
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
    data_collator=default_data_collator,
+   ov_config=ov_config,
    task="sequence-classification",
)

# Train the model as usual; pruning, quantization and distillation are applied internally during training
train_result = trainer.train()
metrics = trainer.evaluate()
# Export the quantized model to OpenVINO IR format and save it
trainer.save_model()
```
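For intuition on the two distillation hyperparameters, most distillation setups blend the regular task loss with a temperature-softened KL-divergence term between student and teacher logits. The sketch below illustrates this common formulation only; the helper name is hypothetical and this is not necessarily the exact loss computed internally by `OVTrainer`.

```python
import torch.nn.functional as F

def jpqd_distillation_loss(student_logits, teacher_logits, task_loss, temperature=3.0, weight=0.9):
    # Illustrative sketch of the standard distillation loss, not OVTrainer internals.
    # Soften both distributions with the temperature before comparing them.
    distill_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature**2  # rescale so gradients keep a comparable magnitude
    # `weight` balances imitating the teacher against fitting the labels.
    return weight * distill_loss + (1.0 - weight) * task_loss
```

With `distillation_weight=0.9` as in the snippet above, the objective is dominated by the teacher-imitation term, which typically helps the pruned and quantized student recover accuracy.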

For more details on movement sparsity and how to configure it, see the NNCF documentation [here](https://github.com/openvinotoolkit/nncf/blob/develop/nncf/experimental/torch/sparsity/movement/MovementSparsity.md).

For more on the algorithms available in NNCF, see the documentation [here](https://github.com/openvinotoolkit/nncf/tree/develop/docs/compression_algorithms).

For complete JPQD scripts, please refer to examples provided [here](https://github.com/huggingface/optimum-intel/tree/main/examples/openvino).


## Inference with Transformers pipeline

After applying quantization to our model, we can easily load it with our `OVModelFor<Task>` classes and perform inference with OpenVINO Runtime using the Transformers [pipelines](https://huggingface.co/docs/transformers/main/en/main_classes/pipelines).
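For example, a model saved with `trainer.save_model()` above can be reloaded and run through a pipeline as follows (a minimal sketch, assuming a sequence-classification checkpoint in `save_dir`):

```python
from transformers import AutoTokenizer, pipeline
from optimum.intel.openvino import OVModelForSequenceClassification

# Load the exported OpenVINO IR model and its tokenizer from the save directory
model = OVModelForSequenceClassification.from_pretrained(save_dir)
tokenizer = AutoTokenizer.from_pretrained(save_dir)

# Run inference with OpenVINO Runtime through the standard Transformers pipeline API
cls_pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
print(cls_pipe("He's a dreadful magician."))
```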
