Fix distillation to be compatible with neural-compressor v2.1 (#260)
* Fix distillation for neural-compressor v2.1

* fix style

* update neural compressor min version

* fix style
echarlaix authored Mar 29, 2023
1 parent 30b5f71 commit 31ed7f8
Showing 4 changed files with 59 additions and 42 deletions.
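For orientation, here is a minimal sketch of the distillation workflow this commit fixes, mirroring the tests updated below. The model name, dataset, and the neural_compressor / optimum import paths are assumptions for illustration, not part of the commit.

# Illustrative sketch only -- model, dataset and import paths are assumptions
# mirroring the distillation test touched by this commit.
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    default_data_collator,
)
from neural_compressor.config import DistillationConfig, KnowledgeDistillationLossConfig  # assumed INC 2.1 paths
from optimum.intel.neural_compressor import INCTrainer  # assumed import path

model_id = "distilbert-base-uncased"  # hypothetical student (and teacher) checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
student = AutoModelForSequenceClassification.from_pretrained(model_id)
teacher = AutoModelForSequenceClassification.from_pretrained(model_id)

dataset = load_dataset("glue", "sst2").map(
    lambda ex: tokenizer(ex["sentence"], padding="max_length", max_length=128, truncation=True),
    batched=True,
)

distillation_config = DistillationConfig(
    teacher_model=teacher,
    criterion=KnowledgeDistillationLossConfig(temperature=1.0, loss_weights=[0.5, 0.5]),
)

trainer = INCTrainer(
    model=student,
    distillation_config=distillation_config,
    task="sequence-classification",
    args=TrainingArguments("distillation_output", num_train_epochs=1.0, do_train=True, do_eval=False),
    train_dataset=dataset["train"].select(range(8)),
    eval_dataset=dataset["validation"].select(range(8)),
    tokenizer=tokenizer,
    data_collator=default_data_collator,
)
trainer.train()
trainer.save_model(save_onnx_model=False)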
4 changes: 2 additions & 2 deletions optimum/intel/neural_compressor/quantization.py
@@ -67,7 +67,7 @@

logger = logging.getLogger(__name__)

NEURAL_COMPRESSOR_MINIMUM_VERSION = "2.0.0"
NEURAL_COMPRESSOR_MINIMUM_VERSION = "2.1.0"

if is_neural_compressor_version("<", NEURAL_COMPRESSOR_MINIMUM_VERSION):
raise ImportError(
@@ -112,7 +112,7 @@ def __init__(
"""
super().__init__()
self._original_model = model
self.eval_fn = eval_fn
self.eval_fn = eval_fn if eval_fn is not None else lambda model: 1
self.calibration_fn = calibration_fn
self.task = task
self.seed = seed
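
The fallback added above gives Intel Neural Compressor's accuracy-driven tuning a constant metric when no evaluation function is supplied; a minimal sketch of the effect (the names here are illustrative, not from the commit):

# Minimal sketch: when no eval_fn is provided, a constant metric stands in,
# so accuracy-aware tuning always sees the same score and effectively becomes a no-op check.
eval_fn = None  # e.g. the user did not pass an evaluation function
metric_fn = eval_fn if eval_fn is not None else (lambda model: 1)
print(metric_fn(object()))  # -> 1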
77 changes: 47 additions & 30 deletions optimum/intel/neural_compressor/trainer.py
@@ -19,6 +19,7 @@
import time
import warnings
from collections.abc import Mapping
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import datasets
@@ -32,22 +33,37 @@
from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm
from transformers import Trainer
from transformers.data.data_collator import DataCollator
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.deepspeed import deepspeed_init
from transformers.file_utils import WEIGHTS_NAME

# Integrations must be imported before ML frameworks:
from transformers.integrations import hp_params
from transformers.modeling_utils import get_parameter_dtype, unwrap_model
from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype, unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.pytorch_utils import is_torch_less_than_1_11
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer import TRAINER_STATE_NAME
from transformers.trainer_callback import TrainerState
from transformers.trainer_callback import TrainerCallback, TrainerState
from transformers.trainer_pt_utils import IterableDatasetShard
from transformers.trainer_utils import HPSearchBackend, ShardedDDPOption, TrainOutput, has_length, speed_metrics
from transformers.trainer_utils import (
EvalPrediction,
HPSearchBackend,
ShardedDDPOption,
TrainOutput,
has_length,
speed_metrics,
)
from transformers.training_args import TrainingArguments
from transformers.utils import is_sagemaker_mp_enabled, logging

from neural_compressor import training
from neural_compressor.compression import DistillationCallbacks
from neural_compressor.conf.pythonic_config import _BaseQuantizationConfig
from neural_compressor.experimental.export import torch_to_fp32_onnx, torch_to_int8_onnx
from neural_compressor.model.torch_model import PyTorchModel
from optimum.exporters import TasksManager

from ..utils.import_utils import is_neural_compressor_version
from .utils import MIN_QDQ_ONNX_OPSET, ONNX_WEIGHTS_NAME, TRAINING_ARGS_NAME
@@ -63,22 +79,6 @@
logger = logging.get_logger(__name__)


from itertools import chain
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

from transformers.data.data_collator import DataCollator
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.training_args import TrainingArguments

from neural_compressor import training
from neural_compressor.conf.pythonic_config import _BaseQuantizationConfig
from neural_compressor.model.torch_model import PyTorchModel
from optimum.exporters import TasksManager


class INCTrainer(Trainer):
"""
INCTrainer enables Intel Neural Compression quantization aware training, pruning and distillation.
@@ -129,6 +129,7 @@ def __init__(
self.pruning_config = pruning_config
self.distillation_config = distillation_config
self._compression_manager = None
self.distillation_callback = None
self.save_onnx_model = save_onnx_model

# Attach dtype and architecture to the config
@@ -150,6 +151,11 @@ def __init__(
self.model = self._compression_manager.model.model
self.model_wrapped = self.model

for callback in self._compression_manager.callbacks.callbacks_list:
if isinstance(callback, DistillationCallbacks):
self.distillation_callback = callback
break

def _inner_training_loop(
self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
):
@@ -518,7 +524,12 @@ def _inner_training_loop

return TrainOutput(self.state.global_step, train_loss, metrics)

def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False, save_onnx_model=None):
def save_model(
self,
output_dir: Optional[str] = None,
_internal_call: bool = False,
save_onnx_model: Optional[bool] = None,
):
"""
Will save the model, so you can reload it using `from_pretrained()`.
Will only save from the main process.
@@ -529,9 +540,17 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
output_dir = self.args.output_dir

if self.args.should_save:
self._save(output_dir=output_dir, save_onnx_model=save_onnx_model)
self._save(
output_dir=output_dir,
save_onnx_model=save_onnx_model,
)

def _save(self, output_dir: Optional[str] = None, state_dict=None, save_onnx_model=False):
def _save(
self,
output_dir=None,
state_dict=None,
save_onnx_model=False,
):
# If we are executing this function, we are the process zero, so we don't check for that.
output_dir = output_dir if output_dir is not None else self.args.output_dir

@@ -689,11 +708,11 @@ def compute_loss(self, model, inputs, return_outputs=False):
teacher_outputs = self.distillation_config.teacher_model(**inputs)
teacher_outputs = self._get_logits(teacher_outputs)

if teacher_outputs is not None:
if teacher_outputs is not None and self.distillation_callback is not None:
distillation_loss = self.compute_distillation_loss(student_outputs, teacher_outputs)
loss *= self._compression_manager.callbacks.callbacks.criterion.loss_weights[0]
loss += distillation_loss * self._compression_manager.callbacks.callbacks.criterion.loss_weights[1]
loss /= sum(self._compression_manager.callbacks.callbacks.criterion.loss_weights)
loss *= self.distillation_callback.criterion.loss_weights[0]
loss += distillation_loss * self.distillation_callback.criterion.loss_weights[1]
loss /= sum(self.distillation_callback.criterion.loss_weights)

if isinstance(outputs, dict):
outputs["loss"] = loss
@@ -730,13 +749,11 @@ def compute_distillation_loss(self, student_outputs, teacher_outputs):
How the distillation loss is computed given the student and teacher outputs.
"""
distillation_loss = None
temperature = self._compression_manager.callbacks.callbacks.criterion.temperature
temperature = self.distillation_callback.criterion.temperature
for student_output, teacher_output in zip(student_outputs, teacher_outputs):
student_output = student_output / temperature
teacher_output = teacher_output / temperature
loss = self._compression_manager.callbacks.callbacks.criterion.teacher_student_loss_cal(
student_output, teacher_output
)
loss = self.distillation_callback.criterion.teacher_student_loss_cal(student_output, teacher_output)
distillation_loss = loss if distillation_loss is None else distillation_loss + loss
distillation_loss *= temperature**2
return distillation_loss
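
The method above divides each student/teacher output pair by the temperature, applies the criterion's teacher-student loss, sums over outputs, and rescales by temperature**2. A self-contained sketch using a softened KL divergence as a stand-in for criterion.teacher_student_loss_cal (an assumption, not the commit's actual criterion):

import torch
import torch.nn.functional as F

def kd_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor, temperature: float = 2.0) -> torch.Tensor:
    # Soften both distributions with the temperature, then compare them;
    # the temperature**2 factor mirrors the rescaling done in compute_distillation_loss.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * temperature**2

# Usage sketch with random logits for a batch of 4 examples and 2 classes
loss = kd_loss(torch.randn(4, 2), torch.randn(4, 2))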
2 changes: 1 addition & 1 deletion setup.py
@@ -26,7 +26,7 @@

EXTRAS_REQUIRE = {
"neural-compressor": [
"neural-compressor>=2.0.0",
"neural-compressor>=2.1.0",
"onnx",
"onnxruntime",
"torch<2.0.0", # remove after neural-compressor next release
18 changes: 9 additions & 9 deletions tests/neural_compressor/test_optimization.py
@@ -91,7 +91,7 @@ def test_dynamic_accuracy_strategy_quantization(self):
eval_dataset = load_dataset("squad", split="validation").select(range(64))
task_evaluator = evaluate.evaluator("question-answering")
qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)
tolerance_criterion = 0.05
tolerance_criterion = 0.1

def eval_fn(model):
qa_pipeline.model = model
@@ -221,8 +221,8 @@ def compute_metrics(p: EvalPrediction):
quantization_config=quantization_config,
task="sequence-classification",
args=TrainingArguments(tmp_dir, num_train_epochs=1.0, do_train=True, do_eval=False),
train_dataset=dataset["train"].select(range(64)),
eval_dataset=dataset["validation"].select(range(64)),
train_dataset=dataset["train"].select(range(8)),
eval_dataset=dataset["validation"].select(range(8)),
compute_metrics=compute_metrics,
tokenizer=tokenizer,
data_collator=default_data_collator,
@@ -280,8 +280,8 @@ def compute_metrics(p: EvalPrediction):
pruning_config=pruning_config,
task="sequence-classification",
args=TrainingArguments(tmp_dir, num_train_epochs=1.0, do_train=True, do_eval=False),
train_dataset=dataset["train"].select(range(64)),
eval_dataset=dataset["validation"].select(range(64)),
train_dataset=dataset["train"].select(range(8)),
eval_dataset=dataset["validation"].select(range(8)),
compute_metrics=compute_metrics,
tokenizer=tokenizer,
data_collator=default_data_collator,
@@ -307,7 +307,7 @@ def test_magnitude_pruning(self):
pruning_config = WeightPruningConfig(
pruning_type="magnitude",
start_step=0,
end_step=15,
end_step=1,
target_sparsity=target_sparsity,
pruning_scope="local",
)
@@ -330,7 +330,7 @@ def compute_metrics(p: EvalPrediction):
task="sequence-classification",
args=TrainingArguments(tmp_dir, num_train_epochs=2.0, do_train=True, do_eval=False),
train_dataset=dataset["train"].select(range(64)),
eval_dataset=dataset["validation"].select(range(64)),
eval_dataset=dataset["validation"].select(range(4)),
compute_metrics=compute_metrics,
tokenizer=tokenizer,
data_collator=default_data_collator,
@@ -371,8 +371,8 @@ def compute_metrics(p: EvalPrediction):
distillation_config=distillation_config,
task="sequence-classification",
args=TrainingArguments(tmp_dir, num_train_epochs=2.0, do_train=True, do_eval=False),
train_dataset=dataset["train"].select(range(64)),
eval_dataset=dataset["validation"].select(range(64)),
train_dataset=dataset["train"].select(range(8)),
eval_dataset=dataset["validation"].select(range(8)),
compute_metrics=compute_metrics,
tokenizer=tokenizer,
data_collator=default_data_collator,
