Commit 579d201

Merge branch 'sa/quant_mod_refactor' of github.com:neuralmagic/sparseml into sa/quant_mod_refactor

horheynm committed May 1, 2024
2 parents bf7d0f6 + 90795bd
Showing 57 changed files with 1,875 additions and 970 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test-check.yaml
@@ -165,7 +165,7 @@ jobs:
       - name: "Clean sparsezoo directory"
         run: rm -r sparsezoo/
       - name: "⚙️ Install dependencies"
-        run: pip3 install .[dev,torchvision,onnxruntime]
+        run: pip3 install .[dev,torchvision,onnxruntime,transformers]
       - name: "🔬 Running pytorch tests"
         run: make test TARGETS=pytorch
   compat-pytorch-1_9-pytorch-tests:
@@ -194,7 +194,7 @@ jobs:
       - name: "Clean sparsezoo directory"
         run: rm -r sparsezoo/
       - name: "⚙️ Install dependencies"
-        run: pip3 install .[dev,torchvision,onnxruntime] torch==1.9.1
+        run: pip3 install .[dev,torchvision,onnxruntime,transformers]
       - name: "🔬 Running pytorch tests"
         run: make test TARGETS=pytorch
   compat-pytorch-1_9-onnx-tests:
@@ -30,17 +30,17 @@ class SFTTrainer(SessionManagerMixIn, TRLSFTTrainer):
 
 The new `SFTTrainer` class can now apply SparseML recipes and modifiers during
 supervised finetuning, with full support for all of the original TRL features. The full
-class is defined in [sft_trainer.py](sft_trainer.py) and requires very minimal
+class is defined in the script `sft_trainer.py` and requires very minimal
 additional code: just a dataset load override to support passing in tokenized datasets
 to the Trainer.
 
 ### Examples
 
-[ex_trl_sft_data.py](ex_trl_sft_data.py): finetunes a 50% sparse Llama-7b model,
+* Script `ex_trl_sft_data.py`: finetunes a 50% sparse Llama-7b model,
 using TRL's dataset preprocessing. Sparsity is maintained throughout training by
 applying a `ConstantPruningModifier` recipe to the `SFTTrainer`
 
-[ex_trl_distillation.py](ex_trl_distillation.py): finetunes a 50% sparse Llama-7b
+* Script `ex_trl_distillation.py`: finetunes a 50% sparse Llama-7b
 model using knowledge distillation from a dense Llama-7b model. Sparsity is maintained
 throughout training with a `ConstantPruningModifier` and layer-wise knowledge
 distillation is handled by the `OutputDistillationModifier`
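
Below is a minimal, hypothetical sketch of what such a finetuning run could look like, assuming the local `sft_trainer.py` module from this README is importable and that `SessionManagerMixIn` accepts the recipe through a `recipe` keyword; the recipe string, model identifier, and dataset are placeholders rather than contents of the example scripts.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from sft_trainer import SFTTrainer  # local module described above

# Placeholder recipe: keep the existing sparsity pattern fixed during finetuning.
recipe = """
finetuning_stage:
  pruning_modifiers:
    ConstantPruningModifier:
      targets: ["re:.*self_attn.*_proj.weight"]
      start: 0
"""

model_id = "path/to/50-percent-sparse-llama-7b"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = load_dataset("imdb", split="train[:1%]")

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",   # TRL tokenizes this column
    max_seq_length=512,
    recipe=recipe,               # assumed kwarg consumed by SessionManagerMixIn
)
trainer.train()
```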
22 changes: 15 additions & 7 deletions src/sparseml/core/recipe/recipe.py
@@ -580,15 +580,23 @@ def _get_yaml_dict(self) -> Dict[str, Any]:
         # populate stages
         stages = original_recipe_dict["stages"]
         for stage_name, stage_list in stages.items():
-            # stage is always a list of size 1
-            stage = stage_list[0]
-            stage_dict = get_yaml_serializable_stage_dict(modifiers=stage["modifiers"])
+            for idx, stage in enumerate(stage_list):
+                if len(stage_list) > 1:
+                    # resolve name clashes caused by combining recipes with
+                    # duplicate stage names
+                    final_stage_name = f"{stage_name}_{idx}"
+                else:
+                    final_stage_name = stage_name
+                stage_dict = get_yaml_serializable_stage_dict(
+                    modifiers=stage["modifiers"]
+                )
 
-            # infer run_type from stage
-            if run_type := stage.get("run_type"):
-                stage_dict["run_type"] = run_type
+                # infer run_type from stage
+                if run_type := stage.get("run_type"):
+                    stage_dict["run_type"] = run_type
 
-            yaml_recipe_dict[stage_name] = stage_dict
+                yaml_recipe_dict[final_stage_name] = stage_dict
         return yaml_recipe_dict
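
A small self-contained illustration of the de-duplication behaviour this loop introduces, using a stand-in `stages` dictionary rather than a real `Recipe` object (the stage and modifier names below are made up):

```python
# Stand-in for original_recipe_dict["stages"] after two recipes were combined.
stages = {
    "quant_stage": [  # duplicate name: appears in both combined recipes
        {"modifiers": {"QuantizationModifier": {"ignore": ["lm_head"]}}},
        {"modifiers": {"QuantizationModifier": {"ignore": []}}},
    ],
    "prune_stage": [  # unique name: kept as-is
        {"modifiers": {"ConstantPruningModifier": {"start": 0}}},
    ],
}

yaml_recipe_dict = {}
for stage_name, stage_list in stages.items():
    for idx, stage in enumerate(stage_list):
        # suffix with the index only when the stage name is duplicated
        final_name = f"{stage_name}_{idx}" if len(stage_list) > 1 else stage_name
        yaml_recipe_dict[final_name] = stage["modifiers"]

print(sorted(yaml_recipe_dict))
# ['prune_stage', 'quant_stage_0', 'quant_stage_1']
```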


@@ -11,21 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from copy import deepcopy
-
-from torch import nn
-
-from sparseml.transformers.sparsification.modification import modify_model
-from sparseml.transformers.sparsification.modification.modification_objects import (
-    QATLinear,
-)
-
-
-def test_modifying_mobilebert(mobilebert_model):
-
-    mobilebert_ = deepcopy(mobilebert_model)
-    mobilebert = modify_model(mobilebert_model)
-
-    assert isinstance(mobilebert_.embeddings.embedding_transformation, nn.Linear)
-    assert isinstance(mobilebert.embeddings.embedding_transformation, QATLinear)
+# flake8: noqa
+from .modify_model import modify_model
@@ -14,7 +14,7 @@
 
 """
 Set of helper objects that are used to modify
-the HuggingFace transformer models
+the quantized models
 """
 
 import torch
@@ -15,31 +15,34 @@
 import logging
 import os
 
-import torch
-
-from sparseml.transformers.sparsification.modification.registry import (
-    ModificationRegistry,
-)
+from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
 
 
 _LOGGER = logging.getLogger(__name__)
 
 
-def modify_model(model: torch.nn.Module, disable: int = False) -> torch.nn.Module:
+def modify_model(
+    model: "torch.nn.Module", disable: bool = False  # noqa: F821
+) -> "torch.nn.Module":  # noqa: F821
     """
-    Modify the original transformers model so that it is
-    compatible with the SparseML library.
+    Modify the original model so that it is
+    compatible with the quantization format required by the
+    SparseML library.
     The model will be modified, if there exist a modification
     function for the model in the registry of modifications.
     Otherwise, the original model will be returned.
-    :param model: The original HuggingFace transformers model
-    :return: The potentially modified model
+    :param model: The original model to be modified
+    :param disable: If True, the modification will be disabled
+    :return: The potentially modified model to support
+        SparseML quantization
    """
     model_name = model.__class__.__name__
-    NM_DISABLE_TRANSFORMERS_MODIFICATION = os.environ.get(
-        "NM_DISABLE_TRANSFORMERS_MODIFICATION", "False"
+    NM_DISABLE_QUANTIZATION_MODIFICATION = os.environ.get(
+        "NM_DISABLE_QUANTIZATION_MODIFICATION", "False"
     ).lower() in ["true", "1"]
 
     try:
         modification_func = ModificationRegistry.get_value_from_registry(model_name)
     except KeyError:
@@ -50,7 +53,7 @@ def modify_model(model: torch.nn.Module, disable: int = False) -> torch.nn.Modul
         )
         return model
 
-    if NM_DISABLE_TRANSFORMERS_MODIFICATION:
+    if NM_DISABLE_QUANTIZATION_MODIFICATION:
         _LOGGER.debug(
             "Application of the modification function to model "
             "disabled through the environment variable."
@@ -65,6 +68,6 @@ def modify_model(model: torch.nn.Module, disable: int = False) -> torch.nn.Modul
         return model
 
     _LOGGER.info(
-        f"Modifying the model {model_name} to be compatible with SparseML library"
+        f"Modifying the model {model_name} to be compatible with SparseML quantization"
     )
     return modification_func(model)
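
A short sketch of calling the relocated `modify_model` helper directly, using the new import path shown in this diff; the checkpoint name is a placeholder and the example assumes the `transformers` package is installed:

```python
import os

from transformers import AutoModel

# Importing the transformers-side modification package registers the per-model
# modification functions (modifying_bert.py and friends) with the registry.
import sparseml.transformers.sparsification.modification  # noqa: F401
from sparseml.modifiers.quantization.modification import modify_model

# BertModel has a registered modification function in this diff, so its
# self-attention blocks are swapped for quantizable variants.
model = modify_model(AutoModel.from_pretrained("bert-base-uncased"))

# The renamed environment variable, read inside modify_model at call time,
# turns the whole mechanism off.
os.environ["NM_DISABLE_QUANTIZATION_MODIFICATION"] = "1"
untouched = modify_model(AutoModel.from_pretrained("bert-base-uncased"))
```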
@@ -11,27 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from sparseml.transformers.sparsification.modification.base import (
-    check_transformers_version,
-)
 from sparsezoo.utils.registry import RegistryMixin
 
 
 class ModificationRegistry(RegistryMixin):
     """
     A registry for modification functions that can be applied to models
-    so that they can be used in the context of sparseml.transformers
+    so that they can be compatible with the quantization format required by the
+    SparseML library.
     """
-
-    @classmethod
-    def get_value_from_registry(cls, name: str):
-        """
-        Extends the base class method to check the transformers version after
-        successfully retrieving the value from the registry. The motivation is
-        to ensure that the transformers version falls within the supported range
-        before we proceed with model modification.
-        """
-        retrieved_value = super().get_value_from_registry(name)
-        check_transformers_version()
-        return retrieved_value
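
For reference, a hypothetical registration against the trimmed-down registry might look like the sketch below; `TinyNet` and its no-op `modify` function are invented for illustration, while the `register` and `get_value_from_registry` calls mirror the usage visible elsewhere in this diff:

```python
import torch

from sparseml.modifiers.quantization.modification.registry import ModificationRegistry


class TinyNet(torch.nn.Module):  # invented model class, illustration only
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.fc(x)


@ModificationRegistry.register(name="TinyNet")
def modify(model: torch.nn.Module) -> torch.nn.Module:
    # A real modification function would swap submodules for quantizable ones;
    # this placeholder just returns the model unchanged.
    return model


modification_func = ModificationRegistry.get_value_from_registry("TinyNet")
modified = modification_func(TinyNet())
```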
6 changes: 6 additions & 0 deletions src/sparseml/modifiers/quantization/pytorch.py
@@ -20,6 +20,7 @@
 
 from sparseml.core import Event, EventType, State
 from sparseml.modifiers.quantization.base import QuantizationModifier
+from sparseml.modifiers.quantization.modification import modify_model
 from sparseml.modifiers.quantization.utils.helpers import (
     configure_module_bn_wrappers,
     freeze_bn_stats,
@@ -73,11 +74,16 @@ def __init__(self, **kwargs):
 
     def on_initialize_structure(self, state: State, **kwargs):
         module = state.model.model
+        # before the structure is modified to support quantization,
+        # we need to potentially modify the model architecture
+        module = modify_model(module)
         self._enable_module_qat(module)
         state.model.model.apply(torch.quantization.disable_observer)
 
     def on_initialize(self, state: State, **kwargs) -> bool:
         raise_if_torch_quantization_not_available()
         module = state.model.model
+        module = modify_model(module)
         if self.end and self.end != -1:
             raise ValueError(
                 "end_epoch is disabled for QuantizationModifier and can only be set to"
4 changes: 1 addition & 3 deletions src/sparseml/modifiers/smoothquant/base.py
@@ -16,8 +16,6 @@
 from dataclasses import dataclass
 from typing import Dict, Generic, List, Optional, Tuple, TypeVar
 
-from pydantic import Field
-
 from sparseml.core import Modifier
 from sparseml.core.model import ModifiableModel
 from sparseml.core.model.base import LT
@@ -98,7 +96,7 @@ class SmoothQuantModifier(Modifier):
         use the whole dataset
     """
 
-    smoothing_strength: float = Field(validation_alias="alpha", default=0.5)
+    smoothing_strength: float = 0.5
     mappings: List[Tuple]
     ignore: Optional[List[str]] = None
     num_calibration_steps: Optional[int] = None
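
With the pydantic `Field` alias removed, the smoothing strength is now set only through the `smoothing_strength` name. A minimal construction sketch (the mapping regexes are illustrative and model-dependent):

```python
from sparseml.modifiers.smoothquant.base import SmoothQuantModifier

modifier = SmoothQuantModifier(
    smoothing_strength=0.5,  # plain float default; the former "alpha" alias is gone
    mappings=[
        # (layers to smooth, preceding layer to fold the scales into) - illustrative
        (["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"),
    ],
)
assert modifier.smoothing_strength == 0.5
```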
1 change: 1 addition & 0 deletions src/sparseml/transformers/sparsification/__init__.py
@@ -19,6 +19,7 @@
 
 # flake8: noqa
 
+from .modification import *
 from .question_answering import *
 from .sparse_config import *
 from .sparse_model import *
20 changes: 13 additions & 7 deletions src/sparseml/transformers/sparsification/modification/__init__.py
@@ -11,11 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 # flake8: noqa
-from .modify_model import modify_model
-from .modifying_bert import *
-from .modifying_distilbert import *
-from .modifying_llama import *
-from .modifying_mistral import *
-from .modifying_mobilebert import *
-from .modifying_opt import *
+# isort:skip_file
+
+# the modification module that adds modifications
+# for transformers models to enable quantization
+
+# import all the modification functions for the different models
+from .modifying_bert import modify
+from .modifying_llama import modify
+from .modifying_mistral import modify
+from .modifying_distilbert import modify
+from .modifying_mobilebert import modify
+from .modifying_opt import modify
@@ -14,68 +14,59 @@
 
 """
 Modification to the original Bert model required in the
-context of SparseML
+context of SparseML quantization
 """
 
 
 import logging
 import math
 from typing import Optional, Tuple
 
 import torch
 from torch import nn
-from transformers.models.bert.modeling_bert import BertAttention, BertSelfAttention
+from transformers.models.bert.modeling_bert import BertSelfAttention
 
+from sparseml.modifiers.quantization.modification.modification_objects import QATMatMul
+from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
 from sparseml.pytorch.utils.helpers import swap_modules
-from sparseml.transformers.sparsification.modification.modification_objects import (
-    QATMatMul,
+from sparseml.transformers.sparsification.modification.base import (
+    check_transformers_version,
 )
-from sparseml.transformers.sparsification.modification.registry import (
-    ModificationRegistry,
-)
 
 
 _LOGGER = logging.getLogger(__name__)
 
 
 @ModificationRegistry.register(name="BertModel", alias=["BertForQuestionAnswering"])
 def modify(model: nn.Module) -> nn.Module:
     """
     Modify the Bert model to be compatible with SparseML
     quantization
-    1. Replaces the MultiHeadSelfAttention modules with
-       MultiHeadSelfAttentionWithQuantizableMatmuls modules
-    Note: This function will not alter any of the alternatives
-    to the MultiHeadSelfAttention module such as BertAttention
+    Replaces the attention modules with
+    MultiHeadSelfAttentionWithQuantizableMatmuls modules
     :param model: the original Bert model
     :return: the modified Bert model
     """
+    check_transformers_version()
     for name, submodule in model.named_modules():
-        if isinstance(submodule, BertSelfAttention):
+        if isinstance(submodule, BertSelfAttention) and not isinstance(
+            submodule, BertSelfAttentionWithQuantizableMatmuls
+        ):
             swap_modules(
                 model, name, BertSelfAttentionWithQuantizableMatmuls(submodule)
             )
-        elif isinstance(submodule, BertAttention):
-            _LOGGER.debug(
-                f"The model contains {submodule.__class__.__name__} "
-                "module, which will not be modified"
-            )
     return model
 
 
 class BertSelfAttentionWithQuantizableMatmuls(BertSelfAttention):
     """
-    Wrapper around the original BertSelfAttention module to replace the
+    Wrapper around the original attention module to replace the
     matmul operations with quantizable matmul operations
-    :param bert_self_attention: the original BertSelfAttention module
+    :param bert_self_attention: the original attention module to be
+        wrapped and modified
     """
 
     def __init__(self, bert_self_attention: BertSelfAttention):
         self.__class__ = type(
-            bert_self_attention.__class__.__name__,
+            self.__class__.__name__,
             (self.__class__, bert_self_attention.__class__),
             {},
         )
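
The `__init__` above relies on a dynamic-subclassing trick: the wrapper rebuilds its own class so the instance is simultaneously an instance of the wrapper and of the wrapped attention class. A self-contained sketch of the same pattern on a generic module is shown below; `QuantizableWrapper` is invented for illustration, and the `__dict__` adoption step is an assumption added so the sketch runs on its own:

```python
import torch
from torch import nn


class QuantizableWrapper(nn.Module):
    """Wrap an existing module while keeping isinstance checks against its class."""

    def __init__(self, wrapped: nn.Module):
        # Rebuild this instance's class as a subclass of both the wrapper and the
        # wrapped module's class, mirroring the pattern in the diff above.
        self.__class__ = type(
            self.__class__.__name__,
            (self.__class__, wrapped.__class__),
            {},
        )
        # Adopt the wrapped module's attributes (parameters, buffers, submodules)
        # so the wrapper behaves exactly like the original module.
        self.__dict__ = wrapped.__dict__


original = nn.Linear(4, 4)
wrapped = QuantizableWrapper(original)

assert isinstance(wrapped, nn.Linear)           # usable wherever a Linear is expected
assert isinstance(wrapped, QuantizableWrapper)  # detectable, enabling the
                                                # "not isinstance" guard in modify()
print(wrapped(torch.randn(2, 4)).shape)         # torch.Size([2, 4])
```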