From aa7cedfd7ed6b4a61f7822bde8a029b7f58fa385 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Thu, 25 Apr 2024 16:50:05 -0400
Subject: [PATCH 1/8] Fix Consecutive Recipe Application Test (#2255)

---
 src/sparseml/core/recipe/recipe.py | 22 +++++++++++++++-------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/src/sparseml/core/recipe/recipe.py b/src/sparseml/core/recipe/recipe.py
index f6ab08af1e6..1a2a21d3f03 100644
--- a/src/sparseml/core/recipe/recipe.py
+++ b/src/sparseml/core/recipe/recipe.py
@@ -580,15 +580,23 @@ def _get_yaml_dict(self) -> Dict[str, Any]:
         # populate stages
         stages = original_recipe_dict["stages"]
         for stage_name, stage_list in stages.items():
-            # stage is always a list of size 1
-            stage = stage_list[0]
-            stage_dict = get_yaml_serializable_stage_dict(modifiers=stage["modifiers"])
+            for idx, stage in enumerate(stage_list):
+                if len(stage_list) > 1:
+                    # resolve name clashes caused by combining recipes with
+                    # duplicate stage names
+                    final_stage_name = f"{stage_name}_{idx}"
+                else:
+                    final_stage_name = stage_name
+                stage_dict = get_yaml_serializable_stage_dict(
+                    modifiers=stage["modifiers"]
+                )
+
+                # infer run_type from stage
+                if run_type := stage.get("run_type"):
+                    stage_dict["run_type"] = run_type
 
-            # infer run_type from stage
-            if run_type := stage.get("run_type"):
-                stage_dict["run_type"] = run_type
+                yaml_recipe_dict[final_stage_name] = stage_dict
 
-            yaml_recipe_dict[stage_name] = stage_dict
         return yaml_recipe_dict
 

From 3a47c49d0c23130aa63ac37f26854edbe63a709f Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Thu, 25 Apr 2024 17:21:11 -0400
Subject: [PATCH 2/8] Fix failing quality check (#2254)

---
 tests/sparseml/transformers/test_clear_ml.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/tests/sparseml/transformers/test_clear_ml.py b/tests/sparseml/transformers/test_clear_ml.py
index c64a765d176..fd21eddc8ca 100644
--- a/tests/sparseml/transformers/test_clear_ml.py
+++ b/tests/sparseml/transformers/test_clear_ml.py
@@ -11,20 +11,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
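A quick note on the stage de-duplication introduced in PATCH 1/8 above: when recipes are combined, two of them can contribute a stage with the same name, and the serializer now suffixes an index instead of keeping only the first stage. A minimal, self-contained sketch of just that renaming rule, not SparseML code; the placeholder `stages` dict mirrors the shape `_get_yaml_dict` iterates, and the modifier payloads are illustrative:

```python
# Two combined recipes both define "test_stage", so the serialized keys
# become "test_stage_0" and "test_stage_1"; a single-entry stage list
# keeps its original name.
stages = {
    "test_stage": [
        {"modifiers": {"ConstantPruningModifier": {}}},
        {"modifiers": {"QuantizationModifier": {}}},
    ],
    "other_stage": [{"modifiers": {"SmoothQuantModifier": {}}}],
}

yaml_recipe_dict = {}
for stage_name, stage_list in stages.items():
    for idx, stage in enumerate(stage_list):
        final_stage_name = (
            f"{stage_name}_{idx}" if len(stage_list) > 1 else stage_name
        )
        yaml_recipe_dict[final_stage_name] = {"modifiers": stage["modifiers"]}

print(sorted(yaml_recipe_dict))
# ['other_stage', 'test_stage_0', 'test_stage_1']
```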
-
 from pathlib import Path
 
+import pytest
 import torch
 
+
 try:
     from clearml import Task
-except Exception as err:
-    clearml = None
+
+    is_clearml = True
+except Exception:
+    is_clearml = False
 
 from sparseml.transformers import train
 
 
-@pytest.mark.skipif(clearml is None, reason="clearML not installed")
+@pytest.mark.skipif(not is_clearml, reason="clearML not installed")
 def test_finetune_wout_recipe(tmp_path: Path):
     recipe_str = None
     model = "Xenova/llama2.c-stories15M"

From 22e3e58d9f7009b7934b4f6994fd73ecc19fffb0 Mon Sep 17 00:00:00 2001
From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
Date: Fri, 26 Apr 2024 15:40:37 +0200
Subject: [PATCH 3/8] initial commit (#2257)

---
 .../tutorials/text-generation/trl_mixin/README.md | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/README.md b/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/README.md
index 25c3b54976b..61fa42af000 100644
--- a/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/README.md
+++ b/integrations/huggingface-transformers/tutorials/text-generation/trl_mixin/README.md
@@ -30,17 +30,17 @@ class SFTTrainer(SessionManagerMixIn, TRLSFTTrainer):
 
 The new `SFTTrainer` class can now apply SparseML recipes and modifiers during
 supervised finetuning, with full support for all of the original TRL features. The full
-class is defined in [sft_trainer.py](sft_trainer.py) and requires very minimal
+class is defined in the script `sft_trainer.py` and requires very minimal
 additional code: just a dataset load override to support passing in tokenized
 datasets to the Trainer.
 
 ### Examples
 
-[ex_trl_sft_data.py](ex_trl_sft_data.py): finetunes a 50% sparse Llama-7b model,
+* Script `ex_trl_sft_data.py`: finetunes a 50% sparse Llama-7b model,
 using TRL's dataset preprocessing. Sparsity is maintained throughout training by
 applying a `ConstantPruningModifier` recipe to the `SFTTrainer`
 
-[ex_trl_distillation.py](ex_trl_distillation.py): finetunes a 50% sparse Llama-7b
+* Script `ex_trl_distillation.py`: finetunes a 50% sparse Llama-7b
 model using knowledge distillation from a dense Llama-7b model.
Sparsity is maintained throughout training with a `ConstantPruningModifier` and layer-wise knowledge distillation is handled by the `OutputDistillationModifier` \ No newline at end of file From c7f3d0291252d7402684f834e550c1f7146663a3 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Fri, 26 Apr 2024 16:36:03 -0400 Subject: [PATCH 4/8] [OneShot] Update unit tests (#2231) --- .github/workflows/test-check.yaml | 4 +- src/sparseml/modifiers/smoothquant/base.py | 4 +- .../logarithmic_equalization}/__init__.py | 0 .../logarithmic_equalization/test_base.py | 55 ++ .../modifiers/pruning/sparsegpt/__init__.py | 13 + .../modifiers/pruning/sparsegpt/test_base.py | 48 ++ .../modifiers/pruning/wanda/test_base.py | 45 +- .../modifiers/quantization/test_base.py | 114 ++-- .../modifiers/smoothquant/__init__.py | 13 + .../modifiers/smoothquant/test_base.py | 51 ++ .../logarithmic_equalization/__init__.py | 13 + .../logarithmic_equalization/test_pytorch.py | 47 ++ .../pytorch/modifiers/obcq/test_pytorch.py | 155 ----- .../modifiers/pruning/sparsegpt/__init__.py | 13 + .../pruning/sparsegpt/test_pytorch.py | 180 ++++++ .../modifiers/pruning/wanda/test_pytorch.py | 49 +- .../modifiers/quantization/test_pytorch.py | 165 ++--- .../pytorch/modifiers/smoothquant/__init__.py | 13 + .../modifiers/smoothquant/test_pytorch.py | 47 ++ .../finetune/data/test_dataset_loading.py | 573 ++++++++++-------- 20 files changed, 1041 insertions(+), 561 deletions(-) rename tests/sparseml/{pytorch/modifiers/obcq => modifiers/logarithmic_equalization}/__init__.py (100%) create mode 100644 tests/sparseml/modifiers/logarithmic_equalization/test_base.py create mode 100644 tests/sparseml/modifiers/pruning/sparsegpt/__init__.py create mode 100644 tests/sparseml/modifiers/pruning/sparsegpt/test_base.py create mode 100644 tests/sparseml/modifiers/smoothquant/__init__.py create mode 100644 tests/sparseml/modifiers/smoothquant/test_base.py create mode 100644 tests/sparseml/pytorch/modifiers/logarithmic_equalization/__init__.py create mode 100644 tests/sparseml/pytorch/modifiers/logarithmic_equalization/test_pytorch.py delete mode 100644 tests/sparseml/pytorch/modifiers/obcq/test_pytorch.py create mode 100644 tests/sparseml/pytorch/modifiers/pruning/sparsegpt/__init__.py create mode 100644 tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py create mode 100644 tests/sparseml/pytorch/modifiers/smoothquant/__init__.py create mode 100644 tests/sparseml/pytorch/modifiers/smoothquant/test_pytorch.py diff --git a/.github/workflows/test-check.yaml b/.github/workflows/test-check.yaml index 9f668546d99..c8ac153b5cb 100644 --- a/.github/workflows/test-check.yaml +++ b/.github/workflows/test-check.yaml @@ -165,7 +165,7 @@ jobs: - name: "Clean sparsezoo directory" run: rm -r sparsezoo/ - name: "⚙️ Install dependencies" - run: pip3 install .[dev,torchvision,onnxruntime] + run: pip3 install .[dev,torchvision,onnxruntime] safetensors>=0.4.1 - name: "🔬 Running pytorch tests" run: make test TARGETS=pytorch compat-pytorch-1_9-pytorch-tests: @@ -194,7 +194,7 @@ jobs: - name: "Clean sparsezoo directory" run: rm -r sparsezoo/ - name: "⚙️ Install dependencies" - run: pip3 install .[dev,torchvision,onnxruntime] torch==1.9.1 + run: pip3 install .[dev,torchvision,onnxruntime] torch==1.9.1 safetensors>=0.4.1 - name: "🔬 Running pytorch tests" run: make test TARGETS=pytorch compat-pytorch-1_9-onnx-tests: diff --git a/src/sparseml/modifiers/smoothquant/base.py b/src/sparseml/modifiers/smoothquant/base.py index 41f7983c873..f499808b106 100644 --- 
a/src/sparseml/modifiers/smoothquant/base.py +++ b/src/sparseml/modifiers/smoothquant/base.py @@ -16,8 +16,6 @@ from dataclasses import dataclass from typing import Dict, Generic, List, Optional, Tuple, TypeVar -from pydantic import Field - from sparseml.core import Modifier from sparseml.core.model import ModifiableModel from sparseml.core.model.base import LT @@ -98,7 +96,7 @@ class SmoothQuantModifier(Modifier): use the whole dataset """ - smoothing_strength: float = Field(validation_alias="alpha", default=0.5) + smoothing_strength: float = 0.5 mappings: List[Tuple] ignore: Optional[List[str]] = None num_calibration_steps: Optional[int] = None diff --git a/tests/sparseml/pytorch/modifiers/obcq/__init__.py b/tests/sparseml/modifiers/logarithmic_equalization/__init__.py similarity index 100% rename from tests/sparseml/pytorch/modifiers/obcq/__init__.py rename to tests/sparseml/modifiers/logarithmic_equalization/__init__.py diff --git a/tests/sparseml/modifiers/logarithmic_equalization/test_base.py b/tests/sparseml/modifiers/logarithmic_equalization/test_base.py new file mode 100644 index 00000000000..a43ee42b13e --- /dev/null +++ b/tests/sparseml/modifiers/logarithmic_equalization/test_base.py @@ -0,0 +1,55 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import pytest + +from sparseml.core.factory import ModifierFactory +from sparseml.core.framework import Framework +from sparseml.modifiers.logarithmic_equalization.base import ( + LogarithmicEqualizationModifier, +) +from sparseml.modifiers.smoothquant.base import SmoothQuantModifier +from tests.sparseml.modifiers.conf import setup_modifier_factory + + +@pytest.mark.unit +class TestLogarithmicEqualizationIsRegistered(unittest.TestCase): + def setUp(self): + self.kwargs = dict( + smoothing_strength=0.3, + mappings=[(["layer1", "layer2"], "layer3")], + ) + setup_modifier_factory() + + def test_log_equalization_is_registered(self): + modifier = ModifierFactory.create( + type_="LogarithmicEqualizationModifier", + framework=Framework.general, + allow_experimental=False, + allow_registered=True, + **self.kwargs, + ) + + self.assertIsInstance( + modifier, + LogarithmicEqualizationModifier, + "PyTorch LogarithmicEqualizationModifier not registered", + ) + + self.assertIsInstance(modifier, SmoothQuantModifier) + self.assertEqual(modifier.smoothing_strength, self.kwargs["smoothing_strength"]) + self.assertEqual(modifier.mappings, self.kwargs["mappings"]) diff --git a/tests/sparseml/modifiers/pruning/sparsegpt/__init__.py b/tests/sparseml/modifiers/pruning/sparsegpt/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/tests/sparseml/modifiers/pruning/sparsegpt/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/sparseml/modifiers/pruning/sparsegpt/test_base.py b/tests/sparseml/modifiers/pruning/sparsegpt/test_base.py
new file mode 100644
index 00000000000..43f4e8d3ffa
--- /dev/null
+++ b/tests/sparseml/modifiers/pruning/sparsegpt/test_base.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+import pytest
+
+from sparseml.core.factory import ModifierFactory
+from sparseml.core.framework import Framework
+from sparseml.modifiers.obcq.base import SparseGPTModifier
+from tests.sparseml.modifiers.conf import setup_modifier_factory
+
+
+@pytest.mark.unit
+class TestSparseGPTIsRegistered(unittest.TestCase):
+    def setUp(self):
+        self.kwargs = dict(
+            sparsity=0.5,
+            targets="__ALL_PRUNABLE__",
+        )
+        setup_modifier_factory()
+
+    def test_sparsegpt_is_registered(self):
+        type_ = ModifierFactory.create(
+            type_="SparseGPTModifier",
+            framework=Framework.general,
+            allow_experimental=False,
+            allow_registered=True,
+            **self.kwargs,
+        )
+
+        self.assertIsInstance(
+            type_,
+            SparseGPTModifier,
+            "PyTorch SparseGPTModifier not registered",
+        )
diff --git a/tests/sparseml/modifiers/pruning/wanda/test_base.py b/tests/sparseml/modifiers/pruning/wanda/test_base.py
index 8dcb682020d..ccfac16dc17 100644
--- a/tests/sparseml/modifiers/pruning/wanda/test_base.py
+++ b/tests/sparseml/modifiers/pruning/wanda/test_base.py
@@ -13,27 +13,36 @@
 # limitations under the License.
+import unittest + +import pytest + from sparseml.core.factory import ModifierFactory from sparseml.core.framework import Framework from sparseml.modifiers.pruning.wanda.base import WandaPruningModifier from tests.sparseml.modifiers.conf import setup_modifier_factory -def test_wanda_is_registered(): - - kwargs = dict( - sparsity=0.5, - targets="__ALL_PRUNABLE__", - ) - setup_modifier_factory() - type_ = ModifierFactory.create( - type_="WandaPruningModifier", - framework=Framework.general, - allow_experimental=False, - allow_registered=True, - **kwargs, - ) - - assert isinstance( - type_, WandaPruningModifier - ), "PyTorch ConstantPruningModifier not registered" +@pytest.mark.unit +class TestWandaIsRegistered(unittest.TestCase): + def setUp(self): + self.kwargs = dict( + sparsity=0.5, + targets="__ALL_PRUNABLE__", + ) + setup_modifier_factory() + + def test_wanda_is_registered(self): + type_ = ModifierFactory.create( + type_="WandaPruningModifier", + framework=Framework.general, + allow_experimental=False, + allow_registered=True, + **self.kwargs, + ) + + self.assertIsInstance( + type_, + WandaPruningModifier, + "PyTorch WandaPruningModifier not registered", + ) diff --git a/tests/sparseml/modifiers/quantization/test_base.py b/tests/sparseml/modifiers/quantization/test_base.py index cd5fab0e755..064d8dcb671 100644 --- a/tests/sparseml/modifiers/quantization/test_base.py +++ b/tests/sparseml/modifiers/quantization/test_base.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import unittest + +import pytest + from sparseml.core.event import Event from sparseml.core.factory import ModifierFactory from sparseml.core.framework import Framework @@ -19,65 +23,75 @@ from tests.sparseml.modifiers.conf import setup_modifier_factory -def test_quantization_registered(): - setup_modifier_factory() +@pytest.mark.unit +class TestQuantizationRegistered(unittest.TestCase): + def setUp(self): + setup_modifier_factory() + self.kwargs = dict(index=0, group="quantization", start=2.0, end=-1.0) - kwargs = dict(index=0, group="quantization", start=2.0, end=-1.0) - quant_obj = ModifierFactory.create( - type_="QuantizationModifier", - framework=Framework.general, - allow_experimental=False, - allow_registered=True, - **kwargs, - ) + def test_quantization_registered(self): + quant_obj = ModifierFactory.create( + type_="QuantizationModifier", + framework=Framework.general, + allow_experimental=False, + allow_registered=True, + **self.kwargs, + ) - assert isinstance(quant_obj, QuantizationModifier) + self.assertIsInstance(quant_obj, QuantizationModifier) -def test_end_epochs(): - start = 0.0 - scheme = dict( - input_activations=dict(num_bits=8, symmetric=True), - weights=dict(num_bits=6, symmetric=False), - ) +@pytest.mark.unit +class TestEndEpochs(unittest.TestCase): + def setUp(self): + self.start = 0.0 + self.scheme = dict( + input_activations=dict(num_bits=8, symmetric=True), + weights=dict(num_bits=6, symmetric=False), + ) - disable_quant_epoch, freeze_bn_epoch = None, None - obj_modifier = QuantizationModifier( - start=start, - scheme=scheme, - disable_quantization_observer_epoch=disable_quant_epoch, - freeze_bn_stats_epoch=freeze_bn_epoch, - ) + def test_end_epochs(self): + disable_quant_epoch, freeze_bn_epoch = None, None + obj_modifier = QuantizationModifier( + start=self.start, + scheme=self.scheme, + disable_quantization_observer_epoch=disable_quant_epoch, + freeze_bn_stats_epoch=freeze_bn_epoch, + ) - assert 
obj_modifier.calculate_disable_observer_epoch() == -1 - assert obj_modifier.calculate_freeze_bn_stats_epoch() == -1 + self.assertEqual(obj_modifier.calculate_disable_observer_epoch(), -1) + self.assertEqual(obj_modifier.calculate_freeze_bn_stats_epoch(), -1) - for epoch in range(3): - event = Event(steps_per_epoch=1, global_step=epoch) - assert not obj_modifier.check_should_disable_observer(event) - assert not obj_modifier.check_should_freeze_bn_stats(event) + for epoch in range(3): + event = Event(steps_per_epoch=1, global_step=epoch) + assert not obj_modifier.check_should_disable_observer(event) + assert not obj_modifier.check_should_freeze_bn_stats(event) - disable_quant_epoch, freeze_bn_epoch = 3.5, 5.0 - obj_modifier = QuantizationModifier( - start=start, - scheme=scheme, - disable_quantization_observer_epoch=disable_quant_epoch, - freeze_bn_stats_epoch=freeze_bn_epoch, - ) + disable_quant_epoch, freeze_bn_epoch = 3.5, 5.0 + obj_modifier = QuantizationModifier( + start=self.start, + scheme=self.scheme, + disable_quantization_observer_epoch=disable_quant_epoch, + freeze_bn_stats_epoch=freeze_bn_epoch, + ) - assert obj_modifier.calculate_disable_observer_epoch() == disable_quant_epoch - assert obj_modifier.calculate_freeze_bn_stats_epoch() == freeze_bn_epoch + self.assertEqual( + obj_modifier.calculate_disable_observer_epoch(), disable_quant_epoch + ) + self.assertEqual( + obj_modifier.calculate_freeze_bn_stats_epoch(), freeze_bn_epoch + ) - for epoch in range(4): - event = Event(steps_per_epoch=1, global_step=epoch) - assert not obj_modifier.check_should_disable_observer(event) - assert not obj_modifier.check_should_freeze_bn_stats(event) + for epoch in range(4): + event = Event(steps_per_epoch=1, global_step=epoch) + assert not obj_modifier.check_should_disable_observer(event) + assert not obj_modifier.check_should_freeze_bn_stats(event) - event = Event(steps_per_epoch=1, global_step=4) - assert obj_modifier.check_should_disable_observer(event) - assert not obj_modifier.check_should_freeze_bn_stats(event) - - for epoch in range(5, 8): - event = Event(steps_per_epoch=1, global_step=epoch) + event = Event(steps_per_epoch=1, global_step=4) assert obj_modifier.check_should_disable_observer(event) - assert obj_modifier.check_should_freeze_bn_stats(event) + assert not obj_modifier.check_should_freeze_bn_stats(event) + + for epoch in range(5, 8): + event = Event(steps_per_epoch=1, global_step=epoch) + assert obj_modifier.check_should_disable_observer(event) + assert obj_modifier.check_should_freeze_bn_stats(event) diff --git a/tests/sparseml/modifiers/smoothquant/__init__.py b/tests/sparseml/modifiers/smoothquant/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/tests/sparseml/modifiers/smoothquant/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
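For orientation while reading the smoothquant tests below: each entry in `mappings` is a two-element tuple, a list of layer names to balance followed by the name of the layer whose activations get smoothed, which is exactly the shape the PyTorch resolution test asserts. A minimal sketch of that structure, using the placeholder layer names from the tests:

```python
# Shape of a SmoothQuant mapping as the tests construct it:
# ([names_of_layers_to_balance], name_of_layer_to_smooth)
mappings = [(["layer1", "layer2"], "layer3")]

balance_layers, smooth_name = mappings[0]
assert smooth_name == "layer3"  # second element: the smoothed layer
assert balance_layers == ["layer1", "layer2"]  # first element: balanced layers
```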
diff --git a/tests/sparseml/modifiers/smoothquant/test_base.py b/tests/sparseml/modifiers/smoothquant/test_base.py new file mode 100644 index 00000000000..f3c29d3fd69 --- /dev/null +++ b/tests/sparseml/modifiers/smoothquant/test_base.py @@ -0,0 +1,51 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import pytest + +from sparseml.core.factory import ModifierFactory +from sparseml.core.framework import Framework +from sparseml.modifiers.smoothquant.base import SmoothQuantModifier +from tests.sparseml.modifiers.conf import setup_modifier_factory + + +@pytest.mark.unit +class TestSmoothQuantIsRegistered(unittest.TestCase): + def setUp(self): + self.kwargs = dict( + smoothing_strength=0.3, + mappings=[(["layer1", "layer2"], "layer3")], + ) + setup_modifier_factory() + + def test_smooth_quant_is_registered(self): + modifier = ModifierFactory.create( + type_="SmoothQuantModifier", + framework=Framework.general, + allow_experimental=False, + allow_registered=True, + **self.kwargs, + ) + + self.assertIsInstance( + modifier, + SmoothQuantModifier, + "PyTorch SmoothQuant not registered", + ) + + self.assertEqual(modifier.smoothing_strength, self.kwargs["smoothing_strength"]) + self.assertEqual(modifier.mappings, self.kwargs["mappings"]) diff --git a/tests/sparseml/pytorch/modifiers/logarithmic_equalization/__init__.py b/tests/sparseml/pytorch/modifiers/logarithmic_equalization/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/tests/sparseml/pytorch/modifiers/logarithmic_equalization/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sparseml/pytorch/modifiers/logarithmic_equalization/test_pytorch.py b/tests/sparseml/pytorch/modifiers/logarithmic_equalization/test_pytorch.py new file mode 100644 index 00000000000..e0209ac7e69 --- /dev/null +++ b/tests/sparseml/pytorch/modifiers/logarithmic_equalization/test_pytorch.py @@ -0,0 +1,47 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import pytest +from torch.nn import Linear + +from sparseml.core import State +from sparseml.core.framework import Framework +from sparseml.core.model import ModifiableModel +from sparseml.modifiers.logarithmic_equalization.pytorch import ( + LogarithmicEqualizationModifierPyTorch, +) +from tests.sparseml.pytorch.helpers import LinearNet + + +@pytest.mark.unit +class TestLogEqualizationMapping(unittest.TestCase): + def setUp(self): + self.model = ModifiableModel(framework=Framework.pytorch, model=LinearNet()) + self.state = State(framework=Framework.pytorch, model=self.model) + + def test_successful_map(self): + mappings = [(["seq.fc2"], "seq.block1.fc1")] + modifier = LogarithmicEqualizationModifierPyTorch(mappings=mappings) + + modifier.ignore = [] + modifier.resolved_mappings_ = modifier._resolve_mappings(self.state.model) + + self.assertEqual(len(modifier.resolved_mappings_), len(mappings)) + + mapping = modifier.resolved_mappings_[0] + self.assertEqual(mapping.smooth_name, mappings[0][1]) + self.assertIsInstance(mapping.smooth_layer, Linear) + self.assertIsInstance(mapping.balance_layers[0], Linear) diff --git a/tests/sparseml/pytorch/modifiers/obcq/test_pytorch.py b/tests/sparseml/pytorch/modifiers/obcq/test_pytorch.py deleted file mode 100644 index 0df15ebfdd2..00000000000 --- a/tests/sparseml/pytorch/modifiers/obcq/test_pytorch.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import pytest - -from sparseml.core.framework import Framework -from sparseml.core.model import ModifiableModel -from sparseml.modifiers.obcq.pytorch import SparseGPTModifierPyTorch -from sparseml.modifiers.quantization import QuantizationModifier -from sparseml.modifiers.quantization.pytorch import QuantizationModifierPyTorch -from tests.sparseml.modifiers.conf import LifecyleTestingHarness, setup_modifier_factory -from tests.sparseml.pytorch.helpers import LinearNet - - -@pytest.mark.parametrize( - "sparsity,targets", - [ - ([0.5, 0.2], "__ALL__"), # type mismatch - ([0.2, 0.1, 0.3], ["seq.fc1", "seq.fc2"]), # length mismatch - ([0.3, 0.4], ["re:.*fc1", "re:.*fc2"]), # regex not supported - ], -) -def test_invalid_layerwise_recipes_raise_exceptions(sparsity, targets): - setup_modifier_factory() - model = LinearNet() - - kwargs = dict( - sparsity=sparsity, - block_size=128, - quantize=False, - targets=targets, - ) - modifier = SparseGPTModifierPyTorch(**kwargs) - testing_harness = LifecyleTestingHarness(model=model, start=-1) - - # confirm invalid layerwise recipes fail at initialization - with pytest.raises(ValueError): - modifier.initialize(testing_harness.get_state()) - - -def test_successful_layerwise_recipe(): - setup_modifier_factory() - model = LinearNet() - - sparsities = [0.5, 0.2] - targets = ["seq.fc1", "seq.fc2"] - kwargs = dict(sparsity=sparsities, block_size=128, quantize=False, targets=targets) - modifier = SparseGPTModifierPyTorch(**kwargs) - modifier.compressible_layers_ = {"seq.fc1": None, "seq.fc2": None} - modifier.model = ModifiableModel(framework=Framework.pytorch, model=model) - found_compressible_layers = modifier.compressible_layers() - modifier.compressible_layers_ = found_compressible_layers - modifier._validate_layerwise_sparsity() - - # ensure layers names successfully match up with model - assert len(found_compressible_layers) == len(targets) - - -def test_create_default_quant_modifier(): - setup_modifier_factory() - kwargs = dict(sparsity=0.5, block_size=128, quantize=True) - - modifier = SparseGPTModifierPyTorch(**kwargs) - assert modifier.quantization_modifier_ is None - - testing_harness = LifecyleTestingHarness(model=LinearNet()) - modifier.on_initialize_structure(testing_harness.get_state()) - assert modifier.quantize - assert isinstance(modifier.quantization_modifier_, QuantizationModifier) - - should_be_default_quant_scheme = modifier.quantization_modifier_.scheme - assert should_be_default_quant_scheme.input_activations.num_bits == 8 - assert not should_be_default_quant_scheme.input_activations.symmetric - assert should_be_default_quant_scheme.weights.num_bits == 8 - assert should_be_default_quant_scheme.weights.symmetric - - -def test_set_quant_if_modifer_already_exists(): - setup_modifier_factory() - - model = LinearNet() - kwargs = dict( - scheme=dict( - input_activations=dict(num_bits=8, symmetric=True), - weights=dict(num_bits=4, symmetric=False), - ), - ) - - modifier = QuantizationModifierPyTorch(**kwargs) - testing_harness = LifecyleTestingHarness(model=model, start=-1) - - assert not testing_harness.get_state().model.qat_active() - modifier.initialize(testing_harness.get_state()) - assert testing_harness.get_state().model.qat_active() - - kwargs = dict(sparsity=0.5, block_size=128, quantize=False) - modifier = SparseGPTModifierPyTorch(**kwargs) - assert not modifier.quantize - modifier.on_initialize_structure(testing_harness.get_state()) - - # quantization modifier not owned by SparseGPT - assert modifier.quantization_modifier_ is None - 
- # since quantization modifier is already applied, quantization must be set in OBCQ - assert modifier.quantize - - -def test_set_quant_in_sparsegpt(): - setup_modifier_factory() - - quant_kwargs = { - "scheme": { - "input_activations": { - "num_bits": 8, - "symmetric": False, - "strategy": "tensor", - "kwargs": {}, - }, - "weights": { - "num_bits": 4, - "symmetric": True, - "strategy": "channel", - "kwargs": {}, - }, - } - } - quant_config = {"QuantizationModifier": quant_kwargs} - - kwargs = dict(sparsity=0.5, block_size=128, quantize=quant_config) - - modifier = SparseGPTModifierPyTorch(**kwargs) - assert modifier.quantization_modifier_ is None - - testing_harness = LifecyleTestingHarness(model=LinearNet()) - modifier.on_initialize_structure(testing_harness.get_state()) - assert modifier.quantize - assert isinstance(modifier.quantization_modifier_, QuantizationModifier) - - dict_scheme = dict(modifier.quantization_modifier_.scheme) - assert dict(dict_scheme["weights"]) == quant_kwargs["scheme"]["weights"] - assert ( - dict(dict_scheme["input_activations"]) - == quant_kwargs["scheme"]["input_activations"] - ) diff --git a/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/__init__.py b/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py new file mode 100644 index 00000000000..87558f5a625 --- /dev/null +++ b/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py @@ -0,0 +1,180 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
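The rewritten tests below pair `unittest.TestCase` classes with `parameterized.expand`, so each row of a case table becomes its own collected test method. A minimal standalone sketch of that pattern, with illustrative rows rather than the SparseGPT ones:

```python
import unittest

from parameterized import parameterized


class TestCaseTable(unittest.TestCase):
    @parameterized.expand(
        [
            [[0.5, 0.2], "__ALL__"],
            [[0.2, 0.1, 0.3], ["seq.fc1", "seq.fc2"]],
        ]
    )
    def test_rows_expand(self, sparsity, targets):
        # parameterized generates one method per row at collection time,
        # so a failure reports which row broke
        self.assertGreaterEqual(len(sparsity), 2)


if __name__ == "__main__":
    unittest.main()
```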
+
+import unittest
+
+import pytest
+
+from parameterized import parameterized
+from sparseml.core.framework import Framework
+from sparseml.core.model import ModifiableModel
+from sparseml.modifiers.obcq.pytorch import SparseGPTModifierPyTorch
+from sparseml.modifiers.quantization import QuantizationModifier
+from sparseml.modifiers.quantization.pytorch import QuantizationModifierPyTorch
+from tests.sparseml.modifiers.conf import LifecyleTestingHarness, setup_modifier_factory
+from tests.sparseml.pytorch.helpers import LinearNet
+from tests.testing_utils import requires_torch
+
+
+@pytest.mark.unit
+@requires_torch
+class TestInvalidLayerwiseRecipesRaiseExceptions(unittest.TestCase):
+    def setUp(self):
+        setup_modifier_factory()
+
+    @parameterized.expand(
+        [
+            [[0.5, 0.2], "__ALL__"],
+            [[0.2, 0.1, 0.3], ["seq.fc1", "seq.fc2"]],
+            [[0.3, 0.4], ["re:.*fc1", "re:.*fc2"]],
+        ]
+    )
+    def test_invalid_layerwise_recipes_raise_exceptions(self, sparsity, targets):
+        setup_modifier_factory()
+        kwargs = dict(
+            sparsity=sparsity,
+            block_size=128,
+            quantize=False,
+            targets=targets,
+        )
+        modifier = SparseGPTModifierPyTorch(**kwargs)
+        testing_harness = LifecyleTestingHarness(model=LinearNet(), start=-1)
+
+        # confirm invalid layerwise recipes fail at initialization
+        with self.assertRaises(ValueError):
+            modifier.initialize(testing_harness.get_state())
+
+
+@pytest.mark.unit
+@requires_torch
+class TestSuccessfulLayerwiseRecipe(unittest.TestCase):
+    def setUp(self):
+        setup_modifier_factory()
+
+    def test_successful_layerwise_recipe(self):
+        sparsities = [0.5, 0.2]
+        targets = ["seq.fc1", "seq.fc2"]
+        kwargs = dict(
+            sparsity=sparsities, block_size=128, quantize=False, targets=targets
+        )
+        modifier = SparseGPTModifierPyTorch(**kwargs)
+        modifier.compressible_layers_ = {"seq.fc1": None, "seq.fc2": None}
+        modifier.model = ModifiableModel(framework=Framework.pytorch, model=LinearNet())
+        found_compressible_layers = modifier.compressible_layers()
+        modifier.compressible_layers_ = found_compressible_layers
+        modifier._validate_layerwise_sparsity()
+
+        # ensure layers names successfully match up with model
+        self.assertEqual(len(found_compressible_layers), len(targets))
+
+
+@pytest.mark.unit
+@requires_torch
+class TestCreateDefaultQuantModifier(unittest.TestCase):
+    def setUp(self):
+        setup_modifier_factory()
+
+    def test_create_default_quant_modifier(self):
+        kwargs = dict(sparsity=0.5, block_size=128, quantize=True)
+
+        modifier = SparseGPTModifierPyTorch(**kwargs)
+        assert modifier.quantization_modifier_ is None
+
+        testing_harness = LifecyleTestingHarness(model=LinearNet())
+        modifier.on_initialize_structure(testing_harness.get_state())
+        assert modifier.quantize
+        assert isinstance(modifier.quantization_modifier_, QuantizationModifier)
+
+        should_be_default_quant_scheme = modifier.quantization_modifier_.scheme
+        self.assertEqual(should_be_default_quant_scheme.input_activations.num_bits, 8)
+        assert not should_be_default_quant_scheme.input_activations.symmetric
+        self.assertEqual(should_be_default_quant_scheme.weights.num_bits, 8)
+        assert should_be_default_quant_scheme.weights.symmetric
+
+
+@pytest.mark.unit
+@requires_torch
+class TestSetQuantIfModifierAlreadyExists(unittest.TestCase):
+    def setUp(self):
+        setup_modifier_factory()
+
+    def test_set_quant_if_modifier_already_exists(self):
+        model = LinearNet()
+        kwargs = dict(
+            scheme=dict(
+                input_activations=dict(num_bits=8, symmetric=True),
+                weights=dict(num_bits=4, symmetric=False),
+            ),
+        )
+
+        modifier =
QuantizationModifierPyTorch(**kwargs) + testing_harness = LifecyleTestingHarness(model=model, start=-1) + + assert not testing_harness.get_state().model.qat_active() + modifier.initialize(testing_harness.get_state()) + assert testing_harness.get_state().model.qat_active() + + kwargs = dict(sparsity=0.5, block_size=128, quantize=False) + modifier = SparseGPTModifierPyTorch(**kwargs) + assert not modifier.quantize + modifier.on_initialize_structure(testing_harness.get_state()) + + # quantization modifier not owned by SparseGPT + assert modifier.quantization_modifier_ is None + + # since quantization modifier is already applied, quantization must be set in + # OBCQ + assert modifier.quantize + + +class TestSetQuantInSparseGPT(unittest.TestCase): + def setUp(self): + setup_modifier_factory() + self.quant_kwargs = { + "scheme": { + "input_activations": { + "num_bits": 8, + "symmetric": False, + "strategy": "tensor", + "kwargs": {}, + }, + "weights": { + "num_bits": 4, + "symmetric": True, + "strategy": "channel", + "kwargs": {}, + }, + } + } + self.quant_config = {"QuantizationModifier": self.quant_kwargs} + + def test_set_quant_in_sparsegpt(self): + kwargs = dict(sparsity=0.5, block_size=128, quantize=self.quant_config) + + modifier = SparseGPTModifierPyTorch(**kwargs) + assert modifier.quantization_modifier_ is None + + testing_harness = LifecyleTestingHarness(model=LinearNet()) + modifier.on_initialize_structure(testing_harness.get_state()) + assert modifier.quantize + self.assertIsInstance(modifier.quantization_modifier_, QuantizationModifier) + + dict_scheme = dict(modifier.quantization_modifier_.scheme) + self.assertEqual( + dict(dict_scheme["weights"]), self.quant_kwargs["scheme"]["weights"] + ) + self.assertEqual( + dict(dict_scheme["input_activations"]), + self.quant_kwargs["scheme"]["input_activations"], + ) diff --git a/tests/sparseml/pytorch/modifiers/pruning/wanda/test_pytorch.py b/tests/sparseml/pytorch/modifiers/pruning/wanda/test_pytorch.py index 2bdca703951..a65959a564a 100644 --- a/tests/sparseml/pytorch/modifiers/pruning/wanda/test_pytorch.py +++ b/tests/sparseml/pytorch/modifiers/pruning/wanda/test_pytorch.py @@ -12,28 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. 
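In the wanda rewrite below, the `WandaPruningModifierPyTorch` import stays inside the test body, so collecting the module never needs torch; the `requires_torch` marker imported from `tests.testing_utils` (presumably a skip decorator) then guards execution. A minimal sketch of that guard style, mirroring the optional-import pattern from PATCH 2/8; the `TORCH_AVAILABLE` flag and test names here are illustrative:

```python
import unittest

import pytest

try:
    import torch  # noqa: F401

    TORCH_AVAILABLE = True
except Exception:
    TORCH_AVAILABLE = False


@pytest.mark.skipif(not TORCH_AVAILABLE, reason="torch not installed")
class TestTorchOnly(unittest.TestCase):
    def test_lazy_import(self):
        # the framework-specific import lives inside the test so that
        # module collection succeeds even when torch is absent
        from torch.nn import Linear

        self.assertTrue(callable(Linear))
```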
+import unittest
+
+import pytest
 
 from sparseml.core.factory import ModifierFactory
 from sparseml.core.framework import Framework
 from tests.sparseml.modifiers.conf import setup_modifier_factory
+from tests.testing_utils import requires_torch
+
+
+@pytest.mark.unit
+@requires_torch
+class TestWandaPytorchIsRegistered(unittest.TestCase):
+    def setUp(self):
+        self.kwargs = dict(
+            sparsity=0.5,
+            targets="__ALL_PRUNABLE__",
+        )
+        setup_modifier_factory()
+
+    def test_wanda_pytorch_is_registered(self):
+        from sparseml.modifiers.pruning.wanda.pytorch import WandaPruningModifierPyTorch
+
+        type_ = ModifierFactory.create(
+            type_="WandaPruningModifier",
+            framework=Framework.pytorch,
+            allow_experimental=False,
+            allow_registered=True,
+            **self.kwargs,
+        )
 
-def test_wanda_pytorch_is_registered():
-    from sparseml.modifiers.pruning.wanda.pytorch import WandaPruningModifierPyTorch
-
-    kwargs = dict(
-        sparsity=0.5,
-        targets="__ALL_PRUNABLE__",
-    )
-    setup_modifier_factory()
-    type_ = ModifierFactory.create(
-        type_="WandaPruningModifier",
-        framework=Framework.pytorch,
-        allow_experimental=False,
-        allow_registered=True,
-        **kwargs,
-    )
-
-    assert isinstance(
-        type_, WandaPruningModifierPyTorch
-    ), "PyTorch ConstantPruningModifier not registered"
+        self.assertIsInstance(
+            type_,
+            WandaPruningModifierPyTorch,
+            "PyTorch WandaPruningModifier not registered",
+        )
diff --git a/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py b/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py
index 58075c2cc2c..6b258b884cb 100644
--- a/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py
+++ b/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py
@@ -12,8 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
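For reference while reading the quantization tests below: the `scheme` argument is a nested dict with an `input_activations` section and a `weights` section. This sketch only restates the shape those tests pass in; the bit widths and strategy are the values used by the one-shot test:

```python
# 8-bit symmetric per-tensor input activations and 4-bit asymmetric
# per-channel weights, as constructed in TestQuantizationOneShot below
scheme = dict(
    input_activations=dict(num_bits=8, symmetric=True),
    weights=dict(num_bits=4, symmetric=False, strategy="channel"),
)

assert scheme["weights"]["strategy"] == "channel"
```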
+import unittest + import pytest +from parameterized import parameterized from sparseml.core import State from sparseml.core.event import Event, EventType from sparseml.core.factory import ModifierFactory @@ -30,6 +33,94 @@ _test_qat_wrapped_module, _test_quantized_module, ) +from tests.testing_utils import requires_torch + + +@pytest.mark.unit +@requires_torch +class TestQuantizationRegistered(unittest.TestCase): + def setUp(self): + setup_modifier_factory() + self.kwargs = dict(index=0, group="quantization", start=2.0, end=-1.0) + + def test_quantization_registered(self): + quant_obj = ModifierFactory.create( + type_="QuantizationModifier", + framework=Framework.pytorch, + allow_experimental=False, + allow_registered=True, + **self.kwargs, + ) + + self.assertIsInstance(quant_obj, QuantizationModifierPyTorch) + + +@pytest.mark.unit +@requires_torch +class TestQuantizationOneShot(unittest.TestCase): + def setUp(self): + scheme = dict( + input_activations=dict(num_bits=8, symmetric=True), + weights=dict(num_bits=4, symmetric=False, strategy="channel"), + ) + self.kwargs = dict(scheme=scheme) + + @parameterized.expand([[ConvNet], [LinearNet]]) + def test_quantization_oneshot(self, model_class): + model = model_class() + state = State(framework=Framework.pytorch, start_event=Event()) + state.update(model=model, start=-1) + + modifier = QuantizationModifierPyTorch(**self.kwargs) + + modifier.initialize(state) + + # for one-shot, we set up quantization on initialization + _test_qat_applied(modifier, model) + + # we shouldn't keep updating stats after one-shot + assert modifier.quantization_observer_disabled_ + + test_start_event = Event(type_=EventType.BATCH_START) + test_end_event = Event(type_=EventType.BATCH_END) + assert not modifier.should_start(test_start_event) + assert not modifier.should_end(test_end_event) + + modifier.finalize(state) + assert modifier.finalized + + +@pytest.mark.unit +@requires_torch +class TestQuantizationTraining(unittest.TestCase): + def setUp(self): + self.start_epoch = 2 + + self.kwargs = dict( + start=self.start_epoch, + scheme=dict( + input_activations=dict(num_bits=8, symmetric=True), + weights=dict(num_bits=4, symmetric=False), + ), + ) + + @parameterized.expand([[ConvNet], [LinearNet]]) + def test_quantization_training(self, model_class): + model = model_class() + + modifier = QuantizationModifierPyTorch(**self.kwargs) + + testing_harness = LifecyleTestingHarness(model=model) + modifier.initialize(testing_harness.get_state()) + assert not modifier.qat_enabled_ + + testing_harness.trigger_modifier_for_epochs(modifier, self.start_epoch) + assert not modifier.qat_enabled_ + testing_harness.trigger_modifier_for_epochs(modifier, self.start_epoch + 1) + _test_qat_applied(modifier, model) + + modifier.finalize(testing_harness.get_state()) + assert modifier.quantization_observer_disabled_ def _test_qat_applied(modifier, model): @@ -67,77 +158,3 @@ def _test_qat_applied(modifier, model): # check all non-target modules are not quantized assert not hasattr(module, "quantization_scheme") assert not hasattr(module, "qconfig") - - -def test_quantization_registered(): - setup_modifier_factory() - - kwargs = dict(index=0, group="quantization", start=2.0, end=-1.0) - quant_obj = ModifierFactory.create( - type_="QuantizationModifier", - framework=Framework.pytorch, - allow_experimental=False, - allow_registered=True, - **kwargs, - ) - - assert isinstance(quant_obj, QuantizationModifierPyTorch) - - -@pytest.mark.parametrize("model_class", [ConvNet, LinearNet]) -def 
test_quantization_oneshot(model_class): - model = model_class() - state = State(framework=Framework.pytorch, start_event=Event()) - state.update(model=model, start=-1) - - scheme = dict( - input_activations=dict(num_bits=8, symmetric=True), - weights=dict(num_bits=4, symmetric=False, strategy="channel"), - ) - kwargs = dict(scheme=scheme) - - modifier = QuantizationModifierPyTorch(**kwargs) - - modifier.initialize(state) - - # for one-shot, we set up quantization on initialization - _test_qat_applied(modifier, model) - - # we shouldn't keep updating stats after one-shot - assert modifier.quantization_observer_disabled_ - - test_start_event = Event(type_=EventType.BATCH_START) - test_end_event = Event(type_=EventType.BATCH_END) - assert not modifier.should_start(test_start_event) - assert not modifier.should_end(test_end_event) - - modifier.finalize(state) - assert modifier.finalized - - -@pytest.mark.parametrize("model_class", [ConvNet, LinearNet]) -def test_quantization_training(model_class): - start_epoch = 2 - - model = model_class() - kwargs = dict( - start=start_epoch, - scheme=dict( - input_activations=dict(num_bits=8, symmetric=True), - weights=dict(num_bits=4, symmetric=False), - ), - ) - - modifier = QuantizationModifierPyTorch(**kwargs) - - testing_harness = LifecyleTestingHarness(model=model) - modifier.initialize(testing_harness.get_state()) - assert not modifier.qat_enabled_ - - testing_harness.trigger_modifier_for_epochs(modifier, start_epoch) - assert not modifier.qat_enabled_ - testing_harness.trigger_modifier_for_epochs(modifier, start_epoch + 1) - _test_qat_applied(modifier, model) - - modifier.finalize(testing_harness.get_state()) - assert modifier.quantization_observer_disabled_ diff --git a/tests/sparseml/pytorch/modifiers/smoothquant/__init__.py b/tests/sparseml/pytorch/modifiers/smoothquant/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/tests/sparseml/pytorch/modifiers/smoothquant/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sparseml/pytorch/modifiers/smoothquant/test_pytorch.py b/tests/sparseml/pytorch/modifiers/smoothquant/test_pytorch.py new file mode 100644 index 00000000000..21943277532 --- /dev/null +++ b/tests/sparseml/pytorch/modifiers/smoothquant/test_pytorch.py @@ -0,0 +1,47 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
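The smoothquant test below resolves mappings against dotted module paths like `seq.fc1` and `seq.fc2` on the shared `LinearNet` helper. A hypothetical stand-in showing where such names come from; the real helper lives in `tests.sparseml.pytorch.helpers` and differs in detail:

```python
from torch import nn


class TinyLinearNet(nn.Module):
    # illustrative stand-in, not the real LinearNet: dotted module paths
    # such as "seq.fc1" come from nesting layers under a named submodule
    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential()
        self.seq.add_module("fc1", nn.Linear(8, 16))
        self.seq.add_module("fc2", nn.Linear(16, 8))

    def forward(self, x):
        return self.seq(x)


names = {name for name, _ in TinyLinearNet().named_modules()}
assert {"seq.fc1", "seq.fc2"} <= names
```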
+ +import unittest + +import pytest +from torch.nn import Linear + +from sparseml.core import State +from sparseml.core.framework import Framework +from sparseml.core.model import ModifiableModel +from sparseml.modifiers.smoothquant.pytorch import SmoothQuantModifierPyTorch +from tests.sparseml.pytorch.helpers import LinearNet +from tests.testing_utils import requires_torch + + +@pytest.mark.unit +@requires_torch +class TestSmoothQuantMapping(unittest.TestCase): + def setUp(self): + self.model = ModifiableModel(framework=Framework.pytorch, model=LinearNet()) + self.state = State(framework=Framework.pytorch, model=self.model) + + def test_successful_map(self): + mappings = [(["seq.fc1"], "seq.fc2")] + modifier = SmoothQuantModifierPyTorch(mappings=mappings) + + modifier.ignore = [] + modifier.resolved_mappings_ = modifier._resolve_mappings(self.state.model) + + self.assertEqual(len(modifier.resolved_mappings_), len(mappings)) + + mapping = modifier.resolved_mappings_[0] + self.assertEqual(mapping.smooth_name, mappings[0][1]) + self.assertIsInstance(mapping.smooth_layer, Linear) + self.assertIsInstance(mapping.balance_layers[0], Linear) diff --git a/tests/sparseml/transformers/finetune/data/test_dataset_loading.py b/tests/sparseml/transformers/finetune/data/test_dataset_loading.py index cd2c230b581..c976e644055 100644 --- a/tests/sparseml/transformers/finetune/data/test_dataset_loading.py +++ b/tests/sparseml/transformers/finetune/data/test_dataset_loading.py @@ -13,272 +13,365 @@ # See the License for the specific language governing permissions and # limitations under the License. +import unittest + import pytest -import torch from datasets import IterableDataset, load_dataset +from parameterized import parameterized from sparseml.transformers.finetune.data import TextGenerationDataset from sparseml.transformers.finetune.data.data_args import DataTrainingArguments -from sparseml.transformers.finetune.data.data_helpers import format_calibration_data -from sparseml.transformers.finetune.model_args import ModelArguments from sparseml.transformers.finetune.runner import StageRunner from sparseml.transformers.finetune.training_args import TrainingArguments +from tests.testing_utils import requires_torch -@pytest.mark.usefixtures("tiny_llama_tokenizer") -def test_concatenation_tokenization(tiny_llama_tokenizer): - data_args = DataTrainingArguments( - dataset="wikitext", - dataset_config_name="wikitext-2-raw-v1", - concatenate_data=True, - ) - wiki_manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, - split="train[:5%]", - tokenizer=tiny_llama_tokenizer, - ) - raw_dataset = wiki_manager.get_raw_dataset() - assert len(raw_dataset) > 0 - assert raw_dataset.split == "train[:5%]" - assert raw_dataset.info.config_name == "wikitext-2-raw-v1" - tokenized_dataset = wiki_manager.tokenize_and_process(raw_dataset) - assert "input_ids" in tokenized_dataset.features - assert "labels" in tokenized_dataset.features - for i in range(len(tokenized_dataset)): - assert len(tokenized_dataset[i]["input_ids"]) == wiki_manager.max_seq_length - - -@pytest.mark.usefixtures("tiny_llama_tokenizer") -def test_no_padding_tokenization(tiny_llama_tokenizer): - data_args = DataTrainingArguments(dataset="open_platypus", pad_to_max_length=False) - op_manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, - split="train[5%:10%]", - tokenizer=tiny_llama_tokenizer, - ) - raw_dataset = op_manager.get_raw_dataset() - assert len(raw_dataset) > 0 - 
ex_item = raw_dataset[0]["text"]
-    assert "Below is an instruction that describes a task" in ex_item
-
-    assert raw_dataset.split == "train[5%:10%]"
-    tokenized_dataset = op_manager.tokenize_and_process(raw_dataset)
-    assert "input_ids" in tokenized_dataset.features
-    assert "labels" in tokenized_dataset.features
-    print(tokenized_dataset[0]["input_ids"])
-
-    for i in range(len(tokenized_dataset)):
-        assert len(tokenized_dataset[i]["input_ids"]) <= op_manager.max_seq_length
-
-
-@pytest.mark.usefixtures("tiny_llama_tokenizer")
-def test_max_seq_len_clipped(tiny_llama_tokenizer):
-    data_args = DataTrainingArguments(dataset="open_platypus", max_seq_length=4096)
-    op_manager = TextGenerationDataset.load_from_registry(
-        data_args.dataset,
-        data_args=data_args,
-        split="train[80%:]",
-        tokenizer=tiny_llama_tokenizer,
-    )
+@pytest.mark.unit
+class TestConcatenationTokenization(unittest.TestCase):
+    def setUp(self):
+        self.data_args = DataTrainingArguments(
+            dataset="wikitext",
+            dataset_config_name="wikitext-2-raw-v1",
+            concatenate_data=True,
+        )
 
-    assert op_manager.max_seq_length == tiny_llama_tokenizer.model_max_length
+    @pytest.fixture(autouse=True)
+    def prepare_fixture(self, tiny_llama_tokenizer):
+        self.tiny_llama_tokenizer = tiny_llama_tokenizer
+
+    def test_concatenation_tokenization(self):
+        wiki_manager = TextGenerationDataset.load_from_registry(
+            self.data_args.dataset,
+            data_args=self.data_args,
+            split="train[:5%]",
+            tokenizer=self.tiny_llama_tokenizer,
+        )
+        raw_dataset = wiki_manager.get_raw_dataset()
+        self.assertGreater(len(raw_dataset), 0)
+        self.assertEqual(raw_dataset.split, "train[:5%]")
+        self.assertEqual(raw_dataset.info.config_name, "wikitext-2-raw-v1")
+        tokenized_dataset = wiki_manager.tokenize_and_process(raw_dataset)
+        self.assertIn("input_ids", tokenized_dataset.features)
+        self.assertIn("labels", tokenized_dataset.features)
+        for i in range(len(tokenized_dataset)):
+            self.assertEqual(
+                len(tokenized_dataset[i]["input_ids"]), wiki_manager.max_seq_length
+            )
+
+
+@pytest.mark.unit
+class TestNoPaddingTokenization(unittest.TestCase):
+    def setUp(self):
+        self.data_args = DataTrainingArguments(
+            dataset="open_platypus", pad_to_max_length=False
+        )
 
+    @pytest.fixture(autouse=True)
+    def prepare_fixture(self, tiny_llama_tokenizer):
+        self.tiny_llama_tokenizer = tiny_llama_tokenizer
+
+    @pytest.mark.usefixtures("tiny_llama_tokenizer")
+    def test_no_padding_tokenization(self):
+        op_manager = TextGenerationDataset.load_from_registry(
+            self.data_args.dataset,
+            data_args=self.data_args,
+            split="train[5%:10%]",
+            tokenizer=self.tiny_llama_tokenizer,
+        )
+        raw_dataset = op_manager.get_raw_dataset()
+        self.assertGreater(len(raw_dataset), 0)
+        ex_item = raw_dataset[0]["text"]
+        self.assertIn("Below is an instruction that describes a task", ex_item)
+
+        self.assertEqual(raw_dataset.split, "train[5%:10%]")
+        tokenized_dataset = op_manager.tokenize_and_process(raw_dataset)
+        self.assertIn("input_ids", tokenized_dataset.features)
+        self.assertIn("labels", tokenized_dataset.features)
+        print(tokenized_dataset[0]["input_ids"])
+
+        for i in range(len(tokenized_dataset)):
+            self.assertLessEqual(
+                len(tokenized_dataset[i]["input_ids"]), op_manager.max_seq_length
+            )
+
+
+@pytest.mark.unit
+class TestMaxSeqLenClipped(unittest.TestCase):
+    def setUp(self):
+        self.data_args = DataTrainingArguments(
+            dataset="open_platypus", max_seq_length=4096
+        )
 
-# test loading percentages works as expected size-wise
-@pytest.mark.usefixtures("tiny_llama_tokenizer")
-def
test_dataset_kwargs_and_percentages(tiny_llama_tokenizer): - data_args = DataTrainingArguments( - dataset="wikitext", - raw_kwargs={ - "data_files": {"train": "wikitext-2-raw-v1/train-00000-of-00001.parquet"} - }, - ) - c4_manager_a = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, - split="train[5%:10%]", - tokenizer=tiny_llama_tokenizer, - ) - raw_dataset_a = c4_manager_a.get_raw_dataset() + @pytest.fixture(autouse=True) + def prepare_fixture(self, tiny_llama_tokenizer): + self.tiny_llama_tokenizer = tiny_llama_tokenizer - c4_manager_b = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, - split="train[5%:15%]", - tokenizer=tiny_llama_tokenizer, - ) - raw_dataset_b = c4_manager_b.get_raw_dataset() - - assert len(raw_dataset_b) == 2 * len(raw_dataset_a) - - -@pytest.mark.usefixtures("tiny_llama_tokenizer") -@pytest.mark.parametrize( - "dataset_key,dataset_config,split,do_concat", - [ - ("ptb", "penn_treebank", "train[:5%]", False), - ("gsm8k", "main", "train[:5%]", True), - ("ultrachat_200k", "default", "train_sft[:2%]", False), - ], -) -def test_datasets(tiny_llama_tokenizer, dataset_key, dataset_config, split, do_concat): - data_args = DataTrainingArguments( - dataset=dataset_key, - dataset_config_name=dataset_config, - concatenate_data=do_concat, - ) - manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, - split=split, - tokenizer=tiny_llama_tokenizer, - ) - raw_dataset = manager.get_raw_dataset() - assert len(raw_dataset) > 0 - assert raw_dataset.split == split - assert raw_dataset.info.config_name == dataset_config + def test_max_seq_len_clipped(self): + op_manager = TextGenerationDataset.load_from_registry( + self.data_args.dataset, + data_args=self.data_args, + split="train[80%:]", + tokenizer=self.tiny_llama_tokenizer, + ) - tokenized_dataset = manager.tokenize_and_process(raw_dataset) - assert "input_ids" in tokenized_dataset.features - assert "labels" in tokenized_dataset.features - for i in range(len(tokenized_dataset)): - if do_concat: - assert len(tokenized_dataset[i]["input_ids"]) == manager.max_seq_length - else: - assert len(tokenized_dataset[i]["input_ids"]) <= manager.max_seq_length + self.assertEqual( + op_manager.max_seq_length, self.tiny_llama_tokenizer.model_max_length + ) -@pytest.mark.skip("Dataset load broken on Hugging Face") -@pytest.mark.usefixtures("tiny_llama_tokenizer") -def test_evol(tiny_llama_tokenizer): - data_args = DataTrainingArguments( - dataset="evolcodealpaca", - dataset_config_name=None, - concatenate_data=False, - ) - evol_manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, - split="train[:2%]", - tokenizer=tiny_llama_tokenizer, - ) - raw_dataset = evol_manager.get_raw_dataset() - assert len(raw_dataset) > 0 - assert raw_dataset.split == "train[:2%]" - - tokenized_dataset = evol_manager.tokenize_and_process(raw_dataset) - assert "input_ids" in tokenized_dataset.features - assert "labels" in tokenized_dataset.features - for i in range(len(tokenized_dataset)): - assert len(tokenized_dataset[i]["input_ids"]) <= evol_manager.max_seq_length - - -@pytest.mark.usefixtures("tiny_llama_tokenizer") -def test_dvc_dataloading(tiny_llama_tokenizer): - data_args = DataTrainingArguments( - dataset="csv", - dataset_path="dvc://workshop/satellite-data/jan_train.csv", - dvc_data_repository="https://github.com/iterative/dataset-registry.git", - ) - manager = TextGenerationDataset( - text_column="", - 
data_args=data_args, - split="train", - tokenizer=tiny_llama_tokenizer, - ) +@pytest.mark.unit +class TestDatasetKwargsAndPercent(unittest.TestCase): + def setUp(self): + self.data_args = DataTrainingArguments( + dataset="wikitext", + raw_kwargs={ + "data_files": { + "train": "wikitext-2-raw-v1/train-00000-of-00001.parquet" + } + }, + ) - raw_dataset = manager.get_raw_dataset() - assert len(raw_dataset) > 0 - assert isinstance(raw_dataset[0], dict) + @pytest.fixture(autouse=True) + def prepare_fixture(self, tiny_llama_tokenizer): + self.tiny_llama_tokenizer = tiny_llama_tokenizer + def test_dataset_kwargs_and_percentages(self): -@pytest.mark.usefixtures("tiny_llama_tokenizer") -def test_stream_loading(tiny_llama_tokenizer): - data_args = DataTrainingArguments( - dataset="wikitext", - dataset_config_name="wikitext-2-raw-v1", - concatenate_data=True, - streaming=True, - ) - manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, - split="train", - tokenizer=tiny_llama_tokenizer, - ) + c4_manager_a = TextGenerationDataset.load_from_registry( + self.data_args.dataset, + data_args=self.data_args, + split="train[5%:10%]", + tokenizer=self.tiny_llama_tokenizer, + ) + raw_dataset_a = c4_manager_a.get_raw_dataset() + + c4_manager_b = TextGenerationDataset.load_from_registry( + self.data_args.dataset, + data_args=self.data_args, + split="train[5%:15%]", + tokenizer=self.tiny_llama_tokenizer, + ) + raw_dataset_b = c4_manager_b.get_raw_dataset() + + self.assertEqual(len(raw_dataset_b), 2 * len(raw_dataset_a)) + + +@pytest.mark.unit +class TestDatasets(unittest.TestCase): + @pytest.fixture(autouse=True) + def prepare_fixture(self, tiny_llama_tokenizer): + self.tiny_llama_tokenizer = tiny_llama_tokenizer - raw_dataset = manager.get_raw_dataset() - processed = manager.tokenize_and_process(raw_dataset) - assert isinstance(processed, IterableDataset) - with pytest.raises(TypeError): - # in streaming mode we don't know the length of the dataset - _ = len(processed) - - # confirm tokenization of streamed item works correctly - item = next(iter(processed)) - assert "labels" in item - assert len(item["input_ids"]) == manager.max_seq_length - - -@pytest.mark.usefixtures("tiny_llama_tokenizer") -@pytest.mark.parametrize( - "split_def", [("train"), ("train[60%:]"), ({"train": "train[:20%]"}), (None)] -) -def test_split_loading(split_def, tiny_llama_tokenizer): - data_args = DataTrainingArguments(dataset="open_platypus", splits=split_def) - training_args = TrainingArguments(do_train=True, output_dir="dummy") - model_args = ModelArguments(model=None) - stage_runner = StageRunner( - model_args=model_args, data_args=data_args, training_args=training_args + @parameterized.expand( + [ + ["ptb", "penn_treebank", "train[:5%]", False], + ["gsm8k", "main", "train[:5%]", True], + ["ultrachat_200k", "default", "train_sft[:2%]", False], + ] ) - stage_runner.populate_datasets(tokenizer=tiny_llama_tokenizer) + def test_datasets(self, dataset_key, dataset_config, split, do_concat): + data_args = DataTrainingArguments( + dataset=dataset_key, + dataset_config_name=dataset_config, + concatenate_data=do_concat, + ) + manager = TextGenerationDataset.load_from_registry( + data_args.dataset, + data_args=data_args, + split=split, + tokenizer=self.tiny_llama_tokenizer, + ) + raw_dataset = manager.get_raw_dataset() + self.assertGreater(len(raw_dataset), 0) + self.assertEqual(raw_dataset.split, split) + self.assertEqual(raw_dataset.info.config_name, dataset_config) + + tokenized_dataset = 
manager.tokenize_and_process(raw_dataset) + self.assertIn("input_ids", tokenized_dataset.features) + self.assertIn("labels", tokenized_dataset.features) + for i in range(len(tokenized_dataset)): + if do_concat: + self.assertEqual( + len(tokenized_dataset[i]["input_ids"]), manager.max_seq_length + ) + else: + self.assertLessEqual( + len(tokenized_dataset[i]["input_ids"]), manager.max_seq_length + ) - train_dataset = stage_runner.get_dataset_split("train") - assert train_dataset is not None - assert isinstance(train_dataset[0], dict) +@pytest.mark.skip("Dataset load broken on Hugging Face") +@pytest.mark.unit +class TestEvol(unittest.TestCase): + @pytest.fixture(autouse=True) + def prepare_fixture(self, tiny_llama_tokenizer): + self.tiny_llama_tokenizer = tiny_llama_tokenizer + + def setUp(self): + self.data_args = DataTrainingArguments( + dataset="evolcodealpaca", + dataset_config_name=None, + concatenate_data=False, + ) -def test_load_tokenized_data(tiny_llama_tokenizer): - dataset = load_dataset("garage-bAInd/Open-Platypus")["train"] - NUM_CALIB_SAMPS = 256 - MAX_SEQ_LEN = 512 - dataset = dataset.shuffle(seed=42).select(range(NUM_CALIB_SAMPS)) + def test_evol(self): + evol_manager = TextGenerationDataset.load_from_registry( + self.data_args.dataset, + data_args=self.data_args, + split="train[:2%]", + tokenizer=self.tiny_llama_tokenizer, + ) + raw_dataset = evol_manager.get_raw_dataset() + self.assertGreater(len(raw_dataset), 0) + self.assertEqual(raw_dataset.split, "train[:2%]") + + tokenized_dataset = evol_manager.tokenize_and_process(raw_dataset) + self.assertIn("input_ids", tokenized_dataset.features) + self.assertIn("labels", tokenized_dataset.features) + for i in range(len(tokenized_dataset)): + self.assertLessEqual( + len(tokenized_dataset[i]["input_ids"]), evol_manager.max_seq_length + ) + + +@pytest.mark.unit +class TestDVCLoading(unittest.TestCase): + def setUp(self): + self.data_args = DataTrainingArguments( + dataset="csv", + dataset_path="dvc://workshop/satellite-data/jan_train.csv", + dvc_data_repository="https://github.com/iterative/dataset-registry.git", + ) - def preprocess(sample): - concat_text = "INPUT: " + sample.get("input", "") - concat_text += "INSTRUCTIONS: " + sample.get("instruction", "") - concat_text += "OUTPUT: " + sample.get("output", "") + @pytest.fixture(autouse=True) + def prepare_fixture(self, tiny_llama_tokenizer): + self.tiny_llama_tokenizer = tiny_llama_tokenizer - return tiny_llama_tokenizer( - concat_text, padding=False, max_length=MAX_SEQ_LEN, truncation=True + def test_dvc_dataloading(self): + manager = TextGenerationDataset( + text_column="", + data_args=self.data_args, + split="train", + tokenizer=self.tiny_llama_tokenizer, ) - tokenized_dataset = dataset.map( - preprocess, remove_columns=["input", "output", "instruction", "data_source"] - ) - stage_runner = StageRunner( - model_args=None, - data_args=DataTrainingArguments( - dataset=tokenized_dataset, shuffle_calibration_samples=False - ), - training_args=TrainingArguments(do_oneshot=True), - ) - stage_runner.populate_datasets(tokenizer=None) - calib_dataset = stage_runner.get_dataset_split("calibration") - assert len(calib_dataset) == NUM_CALIB_SAMPS - data_cols = calib_dataset.column_names - assert len(data_cols) == 2 - assert "input_ids" in data_cols and "attention_mask" in data_cols - - # confirm turning shuffle off works - calib_dataloader = format_calibration_data( - tokenized_dataset=calib_dataset, - num_calibration_samples=NUM_CALIB_SAMPS, - 
do_shuffle=stage_runner._data_args.shuffle_calibration_samples, + raw_dataset = manager.get_raw_dataset() + self.assertGreater(len(raw_dataset), 0) + self.assertIsInstance(raw_dataset[0], dict) + + +@pytest.mark.unit +class TestStreamLoading(unittest.TestCase): + def setUp(self): + self.data_args = DataTrainingArguments( + dataset="wikitext", + dataset_config_name="wikitext-2-raw-v1", + concatenate_data=True, + streaming=True, + ) + + @pytest.fixture(autouse=True) + def prepare_fixture(self, tiny_llama_tokenizer): + self.tiny_llama_tokenizer = tiny_llama_tokenizer + + def test_stream_loading(self): + manager = TextGenerationDataset.load_from_registry( + self.data_args.dataset, + data_args=self.data_args, + split="train", + tokenizer=self.tiny_llama_tokenizer, + ) + + raw_dataset = manager.get_raw_dataset() + processed = manager.tokenize_and_process(raw_dataset) + self.assertIsInstance(processed, IterableDataset) + with pytest.raises(TypeError): + # in streaming mode we don't know the length of the dataset + _ = len(processed) + + # confirm tokenization of streamed item works correctly + item = next(iter(processed)) + self.assertIn("labels", item) + self.assertEqual(len(item["input_ids"]), manager.max_seq_length) + + +@pytest.mark.unit +class TestSplitLoading(unittest.TestCase): + @pytest.fixture(autouse=True) + def prepare_fixture(self, tiny_llama_tokenizer): + self.tiny_llama_tokenizer = tiny_llama_tokenizer + + @parameterized.expand( + [["train"], ["train[60%:]"], [{"train": "train[:20%]"}], [None]] ) - assert len(calib_dataloader) == NUM_CALIB_SAMPS - dataloader_sample = next(iter(calib_dataloader))["input_ids"] - diff = dataloader_sample - torch.Tensor(calib_dataset[0]["input_ids"]) - assert torch.sum(diff) == 0 + def test_split_loading(self, split_def): + from sparseml.transformers.finetune.model_args import ModelArguments + + data_args = DataTrainingArguments(dataset="open_platypus", splits=split_def) + training_args = TrainingArguments(do_train=True, output_dir="dummy") + model_args = ModelArguments(model=None) + stage_runner = StageRunner( + model_args=model_args, data_args=data_args, training_args=training_args + ) + stage_runner.populate_datasets(tokenizer=self.tiny_llama_tokenizer) + + train_dataset = stage_runner.get_dataset_split("train") + assert train_dataset is not None + self.assertIsInstance(train_dataset[0], dict) + + +@requires_torch +@pytest.mark.unit +class TestTokenizationDataset(unittest.TestCase): + @pytest.fixture(autouse=True) + def prepare_fixture(self, tiny_llama_tokenizer): + self.tiny_llama_tokenizer = tiny_llama_tokenizer + dataset = load_dataset("garage-bAInd/Open-Platypus")["train"] + self.num_calib_samples = 256 + self.max_seq_len = 512 + self.dataset = dataset.shuffle(seed=42).select(range(self.num_calib_samples)) + + def test_load_tokenized_data(self): + import torch + + from sparseml.transformers.finetune.data.data_helpers import ( + format_calibration_data, + ) + + def preprocess(sample): + concat_text = "INPUT: " + sample.get("input", "") + concat_text += "INSTRUCTIONS: " + sample.get("instruction", "") + concat_text += "OUTPUT: " + sample.get("output", "") + + return self.tiny_llama_tokenizer( + concat_text, padding=False, max_length=self.max_seq_len, truncation=True + ) + + tokenized_dataset = self.dataset.map( + preprocess, remove_columns=["input", "output", "instruction", "data_source"] + ) + stage_runner = StageRunner( + model_args=None, + data_args=DataTrainingArguments( + dataset=tokenized_dataset, shuffle_calibration_samples=False + ), + 
training_args=TrainingArguments(do_oneshot=True), + ) + stage_runner.populate_datasets(tokenizer=None) + calib_dataset = stage_runner.get_dataset_split("calibration") + self.assertEqual(len(calib_dataset), self.num_calib_samples) + data_cols = calib_dataset.column_names + self.assertEqual(len(data_cols), 2) + self.assertIn("input_ids", data_cols) + self.assertIn("attention_mask", data_cols) + + # confirm turning shuffle off works + calib_dataloader = format_calibration_data( + tokenized_dataset=calib_dataset, + num_calibration_samples=self.num_calib_samples, + do_shuffle=stage_runner._data_args.shuffle_calibration_samples, + ) + self.assertEqual(len(calib_dataloader), self.num_calib_samples) + dataloader_sample = next(iter(calib_dataloader))["input_ids"] + diff = dataloader_sample - torch.Tensor(calib_dataset[0]["input_ids"]) + self.assertEqual(torch.sum(diff), 0) From 8467ee4521ae42256fa2e7040ce7e7def914242e Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Fri, 26 Apr 2024 21:27:24 -0400 Subject: [PATCH 5/8] [OneShot][Testing] Add cli and oneshot pathway integration smoke tests (#2235) --- .github/workflows/test-check.yaml | 4 +- .../{ => image_classification}/conftest.py | 0 tests/sparseml/pytorch/oneshot/__init__.py | 13 +++ .../pytorch/oneshot/dataset_processing.py | 90 +++++++++++++++++++ .../oneshot_configs/recipes/recipe.yaml | 10 +++ .../oneshot_configs/tiny_stories_conf1.yaml | 16 ++++ .../oneshot_configs/tiny_stories_conf2.yaml | 6 ++ .../oneshot_configs/tiny_stories_conf3.yaml | 7 ++ .../oneshot_configs/tiny_stories_conf4.yaml | 17 ++++ .../oneshot_configs/tiny_stories_conf5.yaml | 6 ++ .../oneshot_configs/tiny_stories_conf6.yaml | 6 ++ .../pytorch/oneshot/test_api_inputs.py | 87 ++++++++++++++++++ tests/sparseml/pytorch/oneshot/test_cli.py | 72 +++++++++++++++ tests/testing_utils.py | 12 +++ 14 files changed, 344 insertions(+), 2 deletions(-) rename tests/sparseml/pytorch/{ => image_classification}/conftest.py (100%) create mode 100644 tests/sparseml/pytorch/oneshot/__init__.py create mode 100644 tests/sparseml/pytorch/oneshot/dataset_processing.py create mode 100644 tests/sparseml/pytorch/oneshot/oneshot_configs/recipes/recipe.yaml create mode 100644 tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf1.yaml create mode 100644 tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf2.yaml create mode 100644 tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf3.yaml create mode 100644 tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf4.yaml create mode 100644 tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf5.yaml create mode 100644 tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf6.yaml create mode 100644 tests/sparseml/pytorch/oneshot/test_api_inputs.py create mode 100644 tests/sparseml/pytorch/oneshot/test_cli.py diff --git a/.github/workflows/test-check.yaml b/.github/workflows/test-check.yaml index c8ac153b5cb..362fd297321 100644 --- a/.github/workflows/test-check.yaml +++ b/.github/workflows/test-check.yaml @@ -165,7 +165,7 @@ jobs: - name: "Clean sparsezoo directory" run: rm -r sparsezoo/ - name: "⚙️ Install dependencies" - run: pip3 install .[dev,torchvision,onnxruntime] safetensors>=0.4.1 + run: pip3 install .[dev,torchvision,onnxruntime,transformers] - name: "🔬 Running pytorch tests" run: make test TARGETS=pytorch compat-pytorch-1_9-pytorch-tests: @@ -194,7 +194,7 @@ jobs: - name: "Clean sparsezoo directory" run: rm -r sparsezoo/ - name: "⚙️ Install dependencies" - run: pip3 install 
.[dev,torchvision,onnxruntime] torch==1.9.1 safetensors>=0.4.1 + run: pip3 install .[dev,torchvision,onnxruntime,transformers] - name: "🔬 Running pytorch tests" run: make test TARGETS=pytorch compat-pytorch-1_9-onnx-tests: diff --git a/tests/sparseml/pytorch/conftest.py b/tests/sparseml/pytorch/image_classification/conftest.py similarity index 100% rename from tests/sparseml/pytorch/conftest.py rename to tests/sparseml/pytorch/image_classification/conftest.py diff --git a/tests/sparseml/pytorch/oneshot/__init__.py b/tests/sparseml/pytorch/oneshot/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/tests/sparseml/pytorch/oneshot/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sparseml/pytorch/oneshot/dataset_processing.py b/tests/sparseml/pytorch/oneshot/dataset_processing.py new file mode 100644 index 00000000000..6e06e7f906a --- /dev/null +++ b/tests/sparseml/pytorch/oneshot/dataset_processing.py @@ -0,0 +1,90 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict + +from datasets import load_dataset + + +ALPACA_TEMPLATE = { + "prompt_input": "Below is an instruction that describes a task, paired with an " + "input that provides further context. Write a response that appropriately " + "completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n" + "{input}\n\n### Response:\n", + "prompt_no_input": "Below is an instruction that describes a task. 
Write a " + "response that appropriately completes the request.\n\n### Instruction:\n{" + "instruction}\n\n### Response:\n", +} + +GSM_TEMPLATE = "Question: {question}.\nAnswer: " + + +def _fetch_open_platypus_dataset(): + dataset = load_dataset("garage-bAInd/Open-Platypus")["train"] + dataset = dataset.shuffle(seed=42).select(range(256)) + return dataset + + +def _fetch_gsm8k_data(): + dataset = load_dataset("gsm8k", "main")["train"] + dataset = dataset.shuffle(seed=42).select(range(256)) + return dataset + + +def _preprocess_alpaca(sample): + if "input" in sample: + concat_text = ALPACA_TEMPLATE["prompt_input"].format( + instruction=sample["instruction"], input=sample["input"] + ) + else: + concat_text = ALPACA_TEMPLATE["prompt_no_input"].format( + instruction=sample["instruction"] + ) + if "output" in sample: + concat_text += sample["output"] + + return concat_text + + +def _preprocess_gsm(sample): + concat_text = GSM_TEMPLATE.format(question=sample["question"]) + concat_text += sample["answer"] + return concat_text + + +def get_data_utils(dataset_name: str) -> Dict: + """ + Given the name of a dataset, fetch the appropriate set of data processing utils. + Returns a dictionary of data processing utils required to process the data when + providing tokenized data to oneshot. + Includes: + 1. dataload: function to load the dataset + 2. preprocess: preprocessing function to apply to the dataset + 3. remove_columns: specific columns which should be removed from the dataset + + :param dataset_name: the name of the dataset + :returns dictionary of preprocessing functions/utils. + """ + data_mapping = { + "open_platypus": { + "preprocess": _preprocess_alpaca, + "dataload": _fetch_open_platypus_dataset, + "remove_columns": ["input", "output", "instruction", "data_source"], + }, + "gsm8k": { + "preprocess": _preprocess_gsm, + "dataload": _fetch_gsm8k_data, + "remove_columns": ["question", "answer"], + }, + } + return data_mapping.get(dataset_name) diff --git a/tests/sparseml/pytorch/oneshot/oneshot_configs/recipes/recipe.yaml b/tests/sparseml/pytorch/oneshot/oneshot_configs/recipes/recipe.yaml new file mode 100644 index 00000000000..6157f2ec114 --- /dev/null +++ b/tests/sparseml/pytorch/oneshot/oneshot_configs/recipes/recipe.yaml @@ -0,0 +1,10 @@ +test_stage: + obcq_modifiers: + SparseGPTModifier: + sparsity: 0.5 + block_size: 128 + sequential_update: False + quantize: False + targets: [ + 're:model.layers.3.mlp.gate_proj.weight' + ] \ No newline at end of file diff --git a/tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf1.yaml b/tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf1.yaml new file mode 100644 index 00000000000..59379b9aabd --- /dev/null +++ b/tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf1.yaml @@ -0,0 +1,16 @@ +cadence: "commit" +test_type: "smoke" +tokenize: False +model: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" +dataset: open_platypus +recipe: | + test_stage: + obcq_modifiers: + SparseGPTModifier: + sparsity: 0.5 + block_size: 128 + sequential_update: False + quantize: False + targets: [ + 're:model.layers.3.mlp.gate_proj.weight' + ] \ No newline at end of file diff --git a/tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf2.yaml b/tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf2.yaml new file mode 100644 index 00000000000..a1a9df29b1c --- /dev/null +++ b/tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf2.yaml @@ -0,0 +1,6 @@ +cadence: "commit" +test_type: "smoke" +tokenize: 
False +model: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" +dataset: open_platypus +recipe: "tests/sparseml/pytorch/oneshot/oneshot_configs/recipes/recipe.yaml" \ No newline at end of file diff --git a/tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf3.yaml b/tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf3.yaml new file mode 100644 index 00000000000..38ecc948a39 --- /dev/null +++ b/tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf3.yaml @@ -0,0 +1,7 @@ +cadence: "commit" +test_type: "smoke" +tokenize: False +model: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" +dataset: "gsm8k" +dataset_config_name: "main" +recipe: "tests/sparseml/pytorch/oneshot/oneshot_configs/recipes/recipe.yaml" \ No newline at end of file diff --git a/tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf4.yaml b/tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf4.yaml new file mode 100644 index 00000000000..a742208a09a --- /dev/null +++ b/tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf4.yaml @@ -0,0 +1,17 @@ +cadence: "commit" +test_type: "smoke" +tokenize: False +model: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" +dataset: "gsm8k" +dataset_config_name: "main" +recipe: | + test_stage: + obcq_modifiers: + SparseGPTModifier: + sparsity: 0.5 + block_size: 128 + sequential_update: False + quantize: False + targets: [ + 're:model.layers.3.mlp.gate_proj.weight' + ] \ No newline at end of file diff --git a/tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf5.yaml b/tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf5.yaml new file mode 100644 index 00000000000..8e8e86aa2bb --- /dev/null +++ b/tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf5.yaml @@ -0,0 +1,6 @@ +cadence: "commit" +test_type: "smoke" +tokenize: True +model: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" +dataset: open_platypus +recipe: "tests/sparseml/pytorch/oneshot/oneshot_configs/recipes/recipe.yaml" \ No newline at end of file diff --git a/tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf6.yaml b/tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf6.yaml new file mode 100644 index 00000000000..b1eade5ab97 --- /dev/null +++ b/tests/sparseml/pytorch/oneshot/oneshot_configs/tiny_stories_conf6.yaml @@ -0,0 +1,6 @@ +cadence: "commit" +test_type: "smoke" +tokenize: True +model: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" +dataset: "gsm8k" +recipe: "tests/sparseml/pytorch/oneshot/oneshot_configs/recipes/recipe.yaml" \ No newline at end of file diff --git a/tests/sparseml/pytorch/oneshot/test_api_inputs.py b/tests/sparseml/pytorch/oneshot/test_api_inputs.py new file mode 100644 index 00000000000..6f56bade73b --- /dev/null +++ b/tests/sparseml/pytorch/oneshot/test_api_inputs.py @@ -0,0 +1,87 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import shutil +import unittest + +import pytest + +from parameterized import parameterized_class +from tests.sparseml.pytorch.oneshot.dataset_processing import get_data_utils +from tests.testing_utils import parse_params, requires_torch + + +CONFIGS_DIRECTORY = "tests/sparseml/pytorch/oneshot/oneshot_configs" + +# TODO: Seems better to mark test type (smoke, sanity, regression) as a marker as +# opposed to using a field in the config file? + + +@pytest.mark.smoke +@pytest.mark.integration +@requires_torch +@parameterized_class(parse_params(CONFIGS_DIRECTORY)) +class TestOneShotInputs(unittest.TestCase): + model = None + dataset = None + recipe = None + dataset_config_name = None + tokenize = None + + def setUp(self): + from sparseml.transformers import ( + SparseAutoModelForCausalLM, + SparseAutoTokenizer, + ) + + self.tokenizer = SparseAutoTokenizer.from_pretrained(self.model) + self.model = SparseAutoModelForCausalLM.from_pretrained(self.model) + self.output = "./oneshot_output" + self.kwargs = {"dataset_config_name": self.dataset_config_name} + + data_utils = get_data_utils(self.dataset) + + def wrapped_preprocess_func(sample): + preprocess_func = data_utils.get("preprocess") + return self.tokenizer( + preprocess_func(sample), padding=False, max_length=512, truncation=True + ) + + # If `tokenize` is set to True, use the appropriate preprocessing function + # and set self.tokenizer = None. Updates the self.dataset field from the string + # to the loaded dataset. + if self.tokenize: + loaded_dataset = data_utils.get("dataload")() + self.dataset = loaded_dataset.map( + wrapped_preprocess_func, + remove_columns=data_utils.get("remove_columns"), + ) + self.tokenizer = None + + def test_one_shot_inputs(self): + from sparseml.transformers import oneshot + + oneshot( + model=self.model, + tokenizer=self.tokenizer, + dataset=self.dataset, + recipe=self.recipe, + output_dir=self.output, + num_calibration_samples=10, + pad_to_max_length=False, + **self.kwargs, + ) + + def tearDown(self): + shutil.rmtree(self.output) diff --git a/tests/sparseml/pytorch/oneshot/test_cli.py b/tests/sparseml/pytorch/oneshot/test_cli.py new file mode 100644 index 00000000000..809d24f35bf --- /dev/null +++ b/tests/sparseml/pytorch/oneshot/test_cli.py @@ -0,0 +1,72 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import shutil +import unittest + +import pytest + +from parameterized import parameterized_class +from tests.testing_utils import parse_params, requires_torch, run_cli_command + + +CONFIGS_DIRECTORY = "tests/sparseml/pytorch/oneshot/oneshot_configs" + + +@pytest.mark.smoke +@pytest.mark.integration +@requires_torch +@parameterized_class(parse_params(CONFIGS_DIRECTORY)) +class TestOneShotCli(unittest.TestCase): + model = None + dataset = None + recipe = None + dataset_config_name = None + tokenize = None + + def setUp(self): + if self.tokenize: + pytest.skip("Tokenized data input not supported for oneshot cli") + + self.output = "./oneshot_output" + self.additional_args = [] + if self.dataset_config_name: + self.additional_args.append("--dataset_config_name") + self.additional_args.append(self.dataset_config_name) + + def test_one_shot_cli(self): + cmd = [ + "sparseml.transformers.text_generation.oneshot", + "--dataset", + self.dataset, + "--model", + self.model, + "--output_dir", + self.output, + "--recipe", + self.recipe, + "--num_calibration_samples", + "10", + "--pad_to_max_length", + "False", + ] + + if len(self.additional_args) > 0: + cmd.extend(self.additional_args) + res = run_cli_command(cmd) + self.assertEqual(res.returncode, 0) + print(res.stdout) + + def tearDown(self): + shutil.rmtree(self.output) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 81853d0ca03..c42402847af 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -17,6 +17,7 @@ import logging import os import unittest +from subprocess import PIPE, STDOUT, run from typing import List, Optional, Union import yaml @@ -107,3 +108,14 @@ def parse_params( f"Skipping testing model: {file} for cadence: {config['cadence']}" ) return config_dicts + + +def run_cli_command(cmd: List[str]): + """ + Run a cli command and return the response. The cli command is launched through a new + subprocess. 
+ + :param cmd: cli command provided as a list of arguments where each argument + should be a string + """ + return run(cmd, stdout=PIPE, stderr=STDOUT, check=False, encoding="utf-8") From 7cd2febeeb6351b8e97207b8a9264e4808adfa38 Mon Sep 17 00:00:00 2001 From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com> Date: Mon, 29 Apr 2024 19:11:43 +0200 Subject: [PATCH 6/8] Refactor the quantization modification logic (#2233) --- .../quantization/modification/__init__.py | 20 +-- .../modification/modification_objects.py | 2 +- .../modification/modify_model.py | 31 ++-- .../quantization}/modification/registry.py | 19 +-- .../modifiers/quantization/pytorch.py | 6 + .../transformers/sparsification/__init__.py | 1 + .../sparsification/modification/__init__.py | 20 ++- .../modification/modifying_bert.py | 43 ++--- .../modification/modifying_distilbert.py | 42 ++--- .../modification/modifying_llama.py | 52 +++--- .../modification/modifying_mistral.py | 53 +++--- .../modification/modifying_mobilebert.py | 28 ++-- .../modification/modifying_opt.py | 46 +++--- .../sparsification/sparse_model.py | 11 +- .../modification/test_modify_model.py | 8 +- tests/sparseml/transformers/obcq/test_obcq.py | 31 +++- .../sparsification/modification/conftest.py | 151 ++++++++++-------- .../modification/test_modifying_bert.py | 31 ---- .../test_modifying_distillbert.py | 29 ---- .../modification/test_modifying_llama.py | 43 ++--- .../modification/test_modifying_mistral.py | 43 ++--- .../modification/test_modifying_non_causal.py | 43 +++++ .../modification/test_modifying_opt.py | 43 ++--- 23 files changed, 371 insertions(+), 425 deletions(-) rename tests/sparseml/transformers/sparsification/modification/test_modifying_mobilebert.py => src/sparseml/modifiers/quantization/modification/__init__.py (53%) rename src/sparseml/{transformers/sparsification => modifiers/quantization}/modification/modification_objects.py (99%) rename src/sparseml/{transformers/sparsification => modifiers/quantization}/modification/modify_model.py (72%) rename src/sparseml/{transformers/sparsification => modifiers/quantization}/modification/registry.py (53%) rename tests/sparseml/{transformers/sparsification => modifiers/quantization}/modification/test_modify_model.py (91%) delete mode 100644 tests/sparseml/transformers/sparsification/modification/test_modifying_bert.py delete mode 100644 tests/sparseml/transformers/sparsification/modification/test_modifying_distillbert.py create mode 100644 tests/sparseml/transformers/sparsification/modification/test_modifying_non_causal.py diff --git a/tests/sparseml/transformers/sparsification/modification/test_modifying_mobilebert.py b/src/sparseml/modifiers/quantization/modification/__init__.py similarity index 53% rename from tests/sparseml/transformers/sparsification/modification/test_modifying_mobilebert.py rename to src/sparseml/modifiers/quantization/modification/__init__.py index 0013cf61fe4..7669c7b26fc 100644 --- a/tests/sparseml/transformers/sparsification/modification/test_modifying_mobilebert.py +++ b/src/sparseml/modifiers/quantization/modification/__init__.py @@ -11,21 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-
-from copy import deepcopy
-
-from torch import nn
-
-from sparseml.transformers.sparsification.modification import modify_model
-from sparseml.transformers.sparsification.modification.modification_objects import (
-    QATLinear,
-)
-
-
-def test_modifying_mobilebert(mobilebert_model):
-
-    mobilebert_ = deepcopy(mobilebert_model)
-    mobilebert = modify_model(mobilebert_model)
-
-    assert isinstance(mobilebert_.embeddings.embedding_transformation, nn.Linear)
-    assert isinstance(mobilebert.embeddings.embedding_transformation, QATLinear)
+# flake8: noqa
+from .modify_model import modify_model
diff --git a/src/sparseml/transformers/sparsification/modification/modification_objects.py b/src/sparseml/modifiers/quantization/modification/modification_objects.py
similarity index 99%
rename from src/sparseml/transformers/sparsification/modification/modification_objects.py
rename to src/sparseml/modifiers/quantization/modification/modification_objects.py
index a96d45b983f..ce46bbfd986 100644
--- a/src/sparseml/transformers/sparsification/modification/modification_objects.py
+++ b/src/sparseml/modifiers/quantization/modification/modification_objects.py
@@ -14,7 +14,7 @@
 
 """
 Set of helper objects that are used to modify
-the HuggingFace transformer models
+the quantized models
 """
 
 import torch
diff --git a/src/sparseml/transformers/sparsification/modification/modify_model.py b/src/sparseml/modifiers/quantization/modification/modify_model.py
similarity index 72%
rename from src/sparseml/transformers/sparsification/modification/modify_model.py
rename to src/sparseml/modifiers/quantization/modification/modify_model.py
index 41447b94944..1fee2d70c3c 100644
--- a/src/sparseml/transformers/sparsification/modification/modify_model.py
+++ b/src/sparseml/modifiers/quantization/modification/modify_model.py
@@ -15,31 +15,34 @@
 import logging
 import os
 
-import torch
-
-from sparseml.transformers.sparsification.modification.registry import (
-    ModificationRegistry,
-)
+from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
 
 
 _LOGGER = logging.getLogger(__name__)
 
 
-def modify_model(model: torch.nn.Module, disable: int = False) -> torch.nn.Module:
+def modify_model(
+    model: "torch.nn.Module", disable: bool = False  # noqa: F821
+) -> "torch.nn.Module":  # noqa: F821
     """
-    Modify the original transformers model so that it is
-    compatible with the SparseML library.
+    Modify the original model so that it is
+    compatible with the quantization format required by the
+    SparseML library.
+
     The model will be modified, if there exists a modification
     function for the model in the registry of modifications.
     Otherwise, the original model will be returned.
 
- :param model: The original HuggingFace transformers model - :return: The potentially modified model + :param model: The original model to be modified + :param disable: If True, the modification will be disabled + :return: The potentially modified model to support + SparseML quantization """ model_name = model.__class__.__name__ - NM_DISABLE_TRANSFORMERS_MODIFICATION = os.environ.get( - "NM_DISABLE_TRANSFORMERS_MODIFICATION", "False" + NM_DISABLE_QUANTIZATION_MODIFICATION = os.environ.get( + "NM_DISABLE_QUANTIZATION_MODIFICATION", "False" ).lower() in ["true", "1"] + try: modification_func = ModificationRegistry.get_value_from_registry(model_name) except KeyError: @@ -50,7 +53,7 @@ def modify_model(model: torch.nn.Module, disable: int = False) -> torch.nn.Modul ) return model - if NM_DISABLE_TRANSFORMERS_MODIFICATION: + if NM_DISABLE_QUANTIZATION_MODIFICATION: _LOGGER.debug( "Application of the modification function to model " "disabled through the environment variable." @@ -65,6 +68,6 @@ def modify_model(model: torch.nn.Module, disable: int = False) -> torch.nn.Modul return model _LOGGER.info( - f"Modifying the model {model_name} to be compatible with SparseML library" + f"Modifying the model {model_name} to be compatible with SparseML quantization" ) return modification_func(model) diff --git a/src/sparseml/transformers/sparsification/modification/registry.py b/src/sparseml/modifiers/quantization/modification/registry.py similarity index 53% rename from src/sparseml/transformers/sparsification/modification/registry.py rename to src/sparseml/modifiers/quantization/modification/registry.py index 894d78077d5..5deabfa23db 100644 --- a/src/sparseml/transformers/sparsification/modification/registry.py +++ b/src/sparseml/modifiers/quantization/modification/registry.py @@ -11,27 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from sparseml.transformers.sparsification.modification.base import ( - check_transformers_version, -) from sparsezoo.utils.registry import RegistryMixin class ModificationRegistry(RegistryMixin): """ A registry for modification functions that can be applied to models - so that they can be used in the context of sparseml.transformers + so that they can be compatible with the quantization format required by the + SparseML library. """ - - @classmethod - def get_value_from_registry(cls, name: str): - """ - Extends the base class method to check the transformers version after - successfully retrieving the value from the registry. The motivation is - to ensure that the transformers version falls within the supported range - before we proceed with model modification. 
- """ - retrieved_value = super().get_value_from_registry(name) - check_transformers_version() - return retrieved_value diff --git a/src/sparseml/modifiers/quantization/pytorch.py b/src/sparseml/modifiers/quantization/pytorch.py index 4c2d99cefa8..927d8db79d3 100644 --- a/src/sparseml/modifiers/quantization/pytorch.py +++ b/src/sparseml/modifiers/quantization/pytorch.py @@ -20,6 +20,7 @@ from sparseml.core import Event, EventType, State from sparseml.modifiers.quantization.base import QuantizationModifier +from sparseml.modifiers.quantization.modification import modify_model from sparseml.modifiers.quantization.utils.helpers import ( configure_module_bn_wrappers, freeze_bn_stats, @@ -73,11 +74,16 @@ def __init__(self, **kwargs): def on_initialize_structure(self, state: State, **kwargs): module = state.model.model + # before the structure is modified to support quantization, + # we need to potentially modify the model architecture + module = modify_model(module) self._enable_module_qat(module) state.model.model.apply(torch.quantization.disable_observer) def on_initialize(self, state: State, **kwargs) -> bool: raise_if_torch_quantization_not_available() + module = state.model.model + module = modify_model(module) if self.end and self.end != -1: raise ValueError( "end_epoch is disabled for QuantizationModifier and can only be set to" diff --git a/src/sparseml/transformers/sparsification/__init__.py b/src/sparseml/transformers/sparsification/__init__.py index ac78a780487..82d880423bf 100644 --- a/src/sparseml/transformers/sparsification/__init__.py +++ b/src/sparseml/transformers/sparsification/__init__.py @@ -19,6 +19,7 @@ # flake8: noqa +from .modification import * from .question_answering import * from .sparse_config import * from .sparse_model import * diff --git a/src/sparseml/transformers/sparsification/modification/__init__.py b/src/sparseml/transformers/sparsification/modification/__init__.py index d064f1b6bf7..69f95566d0d 100644 --- a/src/sparseml/transformers/sparsification/modification/__init__.py +++ b/src/sparseml/transformers/sparsification/modification/__init__.py @@ -11,11 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ # flake8: noqa -from .modify_model import modify_model -from .modifying_bert import * -from .modifying_distilbert import * -from .modifying_llama import * -from .modifying_mistral import * -from .modifying_mobilebert import * -from .modifying_opt import * +# isort:skip_file + +# the modification module that adds modifications +# for transformers models to enable quantization + +# import all the modification functions for the different models +from .modifying_bert import modify +from .modifying_llama import modify +from .modifying_mistral import modify +from .modifying_distilbert import modify +from .modifying_mobilebert import modify +from .modifying_opt import modify diff --git a/src/sparseml/transformers/sparsification/modification/modifying_bert.py b/src/sparseml/transformers/sparsification/modification/modifying_bert.py index 20e2e8ded4e..b1c273999ba 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_bert.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_bert.py @@ -14,68 +14,59 @@ """ Modification to the original Bert model required in the -context of SparseML +context of SparseML quantization """ -import logging import math from typing import Optional, Tuple import torch from torch import nn -from transformers.models.bert.modeling_bert import BertAttention, BertSelfAttention +from transformers.models.bert.modeling_bert import BertSelfAttention +from sparseml.modifiers.quantization.modification.modification_objects import QATMatMul +from sparseml.modifiers.quantization.modification.registry import ModificationRegistry from sparseml.pytorch.utils.helpers import swap_modules -from sparseml.transformers.sparsification.modification.modification_objects import ( - QATMatMul, +from sparseml.transformers.sparsification.modification.base import ( + check_transformers_version, ) -from sparseml.transformers.sparsification.modification.registry import ( - ModificationRegistry, -) - - -_LOGGER = logging.getLogger(__name__) @ModificationRegistry.register(name="BertModel", alias=["BertForQuestionAnswering"]) def modify(model: nn.Module) -> nn.Module: """ Modify the Bert model to be compatible with SparseML + quantization - 1. 
Replaces the MultiHeadSelfAttention modules with
-       MultiHeadSelfAttentionWithQuantizableMatmuls modules
-
-    Note: This function will not alter any of the alternatives
-    to the MultiHeadSelfAttention module such as BertAttention
-
+    Replaces the attention modules with
+    BertSelfAttentionWithQuantizableMatmuls modules
 
     :param model: the original Bert model
     :return: the modified Bert model
     """
+    check_transformers_version()
     for name, submodule in model.named_modules():
-        if isinstance(submodule, BertSelfAttention):
+        if isinstance(submodule, BertSelfAttention) and not isinstance(
+            submodule, BertSelfAttentionWithQuantizableMatmuls
+        ):
             swap_modules(
                 model, name, BertSelfAttentionWithQuantizableMatmuls(submodule)
             )
-        elif isinstance(submodule, BertAttention):
-            _LOGGER.debug(
-                f"The model contains {submodule.__class__.__name__} "
-                "module, which will not be modified"
-            )
     return model
 
 
 class BertSelfAttentionWithQuantizableMatmuls(BertSelfAttention):
     """
-    Wrapper around the original BertSelfAttention module to replace the
+    Wrapper around the original attention module to replace the
     matmul operations with quantizable matmul operations
 
-    :param bert_self_attention: the original BertSelfAttention module
+    :param bert_self_attention: the original attention module to be
+        wrapped and modified
     """
 
     def __init__(self, bert_self_attention: BertSelfAttention):
         self.__class__ = type(
-            bert_self_attention.__class__.__name__,
+            self.__class__.__name__,
             (self.__class__, bert_self_attention.__class__),
             {},
         )
diff --git a/src/sparseml/transformers/sparsification/modification/modifying_distilbert.py b/src/sparseml/transformers/sparsification/modification/modifying_distilbert.py
index c37da2cbdd0..2cc9915b900 100644
--- a/src/sparseml/transformers/sparsification/modification/modifying_distilbert.py
+++ b/src/sparseml/transformers/sparsification/modification/modifying_distilbert.py
@@ -14,10 +14,9 @@
 
 """
 Modification to the original DistilBert model required in the
-context of SparseML
+context of SparseML quantization
 """
 
-import logging
 import math
 from typing import Optional, Tuple
 
@@ -28,56 +27,49 @@
     MultiHeadSelfAttention,
 )
 
+from sparseml.modifiers.quantization.modification.modification_objects import QATMatMul
+from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
 from sparseml.pytorch.utils.helpers import swap_modules
-from sparseml.transformers.sparsification.modification.modification_objects import (
-    QATMatMul,
+from sparseml.transformers.sparsification.modification.base import (
+    check_transformers_version,
 )
-from sparseml.transformers.sparsification.modification.registry import (
-    ModificationRegistry,
-)
-
-
-_LOGGER = logging.getLogger(__name__)
 
 
 @ModificationRegistry.register(name="DistilBertModel")
 def modify(model: nn.Module) -> nn.Module:
     """
     Modify the DistilBert model to be compatible with SparseML
+    quantization
 
-    1. 
Replaces the MultiHeadSelfAttention modules with - MultiHeadSelfAttentionWithQuantizableMatmuls modules - - Note: This function will not alter any of the alternatives - to the MultiHeadSelfAttention module such as DistilBertFlashAttention2 + Replaces the attention modules with + MultiHeadSelfAttentionWithQuantizableMatmuls modules :param model: the original DistilBert model :return: the modified DistilBert model """ + check_transformers_version() for name, submodule in model.named_modules(): - if isinstance(submodule, MultiHeadSelfAttention): + if isinstance( + submodule, (MultiHeadSelfAttention, DistilBertFlashAttention2) + ) and not isinstance(submodule, MultiHeadSelfAttentionWithQuantizableMatmuls): swap_modules( model, name, MultiHeadSelfAttentionWithQuantizableMatmuls(submodule) ) - if isinstance(submodule, DistilBertFlashAttention2): - _LOGGER.debug( - f"The model contains {submodule.__class__.__name__} " - "module, which will not be modified" - ) return model class MultiHeadSelfAttentionWithQuantizableMatmuls(MultiHeadSelfAttention): """ - Wrapper around the original MultiHeadSelfAttention module to replace the - matmul operations with quantizable matmul operations + Wrapper around the original attention module to introduce + MultiHeadSelfAttention with quantizable matmul operations - :param mhs_attention: the original MultiHeadSelfAttention module + :param mhs_attention: the original attention module to be + wrapped and modified """ def __init__(self, mhs_attention: MultiHeadSelfAttention): self.__class__ = type( - mhs_attention.__class__.__name__, + self.__class__.__name__, (self.__class__, mhs_attention.__class__), {}, ) diff --git a/src/sparseml/transformers/sparsification/modification/modifying_llama.py b/src/sparseml/transformers/sparsification/modification/modifying_llama.py index 6c89469f524..d51827fc8f3 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_llama.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_llama.py @@ -14,12 +14,11 @@ """ Modification to the original LLaMa model required in the -context of SparseML +context of SparseML quantization """ -import logging import math -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch import torch.nn.functional as F @@ -33,42 +32,35 @@ repeat_kv, ) -from sparseml.pytorch.utils.helpers import swap_modules -from sparseml.transformers.sparsification.modification.modification_objects import ( +from sparseml.modifiers.quantization.modification.modification_objects import ( QuantizableIdentity, QuantizableMatMul, ) -from sparseml.transformers.sparsification.modification.registry import ( - ModificationRegistry, +from sparseml.modifiers.quantization.modification.registry import ModificationRegistry +from sparseml.pytorch.utils.helpers import swap_modules +from sparseml.transformers.sparsification.modification.base import ( + check_transformers_version, ) -_LOGGER = logging.getLogger(__name__) - - @ModificationRegistry.register(name="LlamaModel", alias=["LlamaForCausalLM"]) def modify(model: nn.Module) -> nn.Module: """ Modify the LLaMa model to be compatible with SparseML + quantization - 1. 
Replaces the LlamaAttention modules with - LlamaAttentionWithQuantizableMatmuls modules - - Note: This function will not alter any of the alternatives - to the LlamaAttention module such as LlamaFlashAttention2 - or LlamaSdpaAttention + Replaces the attention modules with + LlamaAttentionWithQuantizableMatmuls modules :param model: the original LLaMa model :return: the modified LLaMa model """ + check_transformers_version() for name, submodule in model.named_modules(): - if isinstance(submodule, LlamaAttention): + if isinstance( + submodule, (LlamaAttention, LlamaFlashAttention2, LlamaSdpaAttention) + ) and not isinstance(submodule, LlamaAttentionWithQuantizableMatmuls): swap_modules(model, name, LlamaAttentionWithQuantizableMatmuls(submodule)) - elif isinstance(submodule, (LlamaSdpaAttention, LlamaFlashAttention2)): - _LOGGER.debug( - f"The model contains {submodule.__class__.__name__} " - "module, which will not be modified" - ) return model @@ -98,15 +90,21 @@ class MatMulOutput_PV(QuantizableIdentity): class LlamaAttentionWithQuantizableMatmuls(LlamaAttention): """ - Wrapper around the original LlamaAttention module to replace the - matmul operations with quantizable matmul operations + Wrapper around the original attention module to introduce + LlamaAttention with quantizable matmul operations - :param llama_attention: the original LlamaAttention module + :param llama_attention: the original attention module to be + wrapped and modified """ - def __init__(self, llama_attention: LlamaAttention): + def __init__( + self, + llama_attention: Union[ + LlamaAttention, LlamaFlashAttention2, LlamaSdpaAttention + ], + ): self.__class__ = type( - llama_attention.__class__.__name__, + self.__class__.__name__, (self.__class__, llama_attention.__class__), {}, ) diff --git a/src/sparseml/transformers/sparsification/modification/modifying_mistral.py b/src/sparseml/transformers/sparsification/modification/modifying_mistral.py index 28d9d7f109f..1a03d635027 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_mistral.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_mistral.py @@ -14,12 +14,12 @@ """ Modification to the original Mistral model required in the -context of SparseML +context of SparseML quantization """ -import logging + import math import warnings -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch from torch import nn @@ -32,42 +32,35 @@ repeat_kv, ) -from sparseml.pytorch.utils.helpers import swap_modules -from sparseml.transformers.sparsification.modification.modification_objects import ( +from sparseml.modifiers.quantization.modification.modification_objects import ( QuantizableIdentity, QuantizableMatMul, ) -from sparseml.transformers.sparsification.modification.registry import ( - ModificationRegistry, +from sparseml.modifiers.quantization.modification.registry import ModificationRegistry +from sparseml.pytorch.utils.helpers import swap_modules +from sparseml.transformers.sparsification.modification.base import ( + check_transformers_version, ) -_LOGGER = logging.getLogger(__name__) - - @ModificationRegistry.register(name="MistralModel", alias=["MistralForCausalLM"]) def modify(model: torch.nn.Module) -> torch.nn.Module: """ Modify the Mistral model to be compatible with SparseML + quantization - 1. 
Replaces the MistralAttention modules with - MistralAttentionWithQuantizableMatmuls modules - - Note: This function will not alter any of the alternatives - to the MistralAttention module such as MistralFlashAttention2 - or MistralSdpaAttention + Replaces the attention modules with + MistralAttentionWithQuantizableMatmuls modules :param model: the original Mistral model :return: the modified Mistral model """ + check_transformers_version() for name, submodule in model.named_modules(): - if isinstance(submodule, MistralAttention): + if isinstance( + submodule, (MistralAttention, MistralFlashAttention2, MistralSdpaAttention) + ) and not isinstance(submodule, MistralAttentionWithQuantizableMatmuls): swap_modules(model, name, MistralAttentionWithQuantizableMatmuls(submodule)) - if isinstance(submodule, (MistralSdpaAttention, MistralFlashAttention2)): - _LOGGER.debug( - f"The model contains {submodule.__class__.__name__} " - "module, which will not be modified" - ) return model @@ -89,16 +82,22 @@ class MatMulRightInput_PV(QuantizableIdentity): class MistralAttentionWithQuantizableMatmuls(MistralAttention): """ - Wrapper around the original MistralAttention module to replace the - matmul operations with quantizable matmul operations + Wrapper around the original attention module to introduce + MistralAttention with quantizable matmul operations - :param mistral_attention: the original MistralAttention module + :param mistral_attention: the original attention module to be + wrapped and modified """ - def __init__(self, mistral_attention: MistralAttention): + def __init__( + self, + mistral_attention: Union[ + MistralAttention, MistralFlashAttention2, MistralSdpaAttention + ], + ): self.__class__ = type( - mistral_attention.__class__.__name__, + self.__class__.__name__, (self.__class__, mistral_attention.__class__), {}, ) diff --git a/src/sparseml/transformers/sparsification/modification/modifying_mobilebert.py b/src/sparseml/transformers/sparsification/modification/modifying_mobilebert.py index 57a5c6d83e4..469ca36a736 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_mobilebert.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_mobilebert.py @@ -14,39 +14,37 @@ """ Modification to the original MobileBert model required in the -context of SparseML +context of SparseML quantization """ -import logging - from torch import nn from transformers.models.mobilebert.modeling_mobilebert import MobileBertEmbeddings +from sparseml.modifiers.quantization.modification.modification_objects import QATLinear +from sparseml.modifiers.quantization.modification.registry import ModificationRegistry from sparseml.pytorch.utils.helpers import swap_modules -from sparseml.transformers.sparsification.modification.modification_objects import ( - QATLinear, -) -from sparseml.transformers.sparsification.modification.registry import ( - ModificationRegistry, +from sparseml.transformers.sparsification.modification.base import ( + check_transformers_version, ) -_LOGGER = logging.getLogger(__name__) - - @ModificationRegistry.register(name="MobileBertModel") def modify(model: nn.Module) -> nn.Module: """ Modify the MobileBert model to be compatible with SparseML + quantization - 1. 
Replaces the MobileBertEmbeddings modules with
-       MobileBertEmbeddingsWithQuantizableMatmuls modules
+    Replaces the MobileBertEmbeddings modules with
+    MobileBertEmbeddingsWithQuantizableLinear modules
 
     :param model: the original MobileBert model
     :return: the modified MobileBert model
     """
+    check_transformers_version()
     for name, submodule in model.named_modules():
-        if isinstance(submodule, MobileBertEmbeddings):
+        if isinstance(submodule, MobileBertEmbeddings) and not isinstance(
+            submodule, MobileBertEmbeddingsWithQuantizableLinear
+        ):
             swap_modules(
                 model, name, MobileBertEmbeddingsWithQuantizableLinear(submodule)
             )
@@ -63,7 +61,7 @@ class MobileBertEmbeddingsWithQuantizableLinear(MobileBertEmbeddings):
 
     def __init__(self, mobilebert_emb: MobileBertEmbeddings):
         self.__class__ = type(
-            mobilebert_emb.__class__.__name__,
+            self.__class__.__name__,
             (self.__class__, mobilebert_emb.__class__),
             {},
         )
diff --git a/src/sparseml/transformers/sparsification/modification/modifying_opt.py b/src/sparseml/transformers/sparsification/modification/modifying_opt.py
index 373f6fbd467..5f696ee36c7 100644
--- a/src/sparseml/transformers/sparsification/modification/modifying_opt.py
+++ b/src/sparseml/transformers/sparsification/modification/modifying_opt.py
@@ -14,51 +14,44 @@
 
 """
 Modification to the original OPT model required in the
-context of SparseML
+context of SparseML quantization
 """
 
-import logging
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 from torch import nn
 from transformers.models.opt.modeling_opt import OPTAttention, OptFlashAttention2
 
-from sparseml.pytorch.utils.helpers import swap_modules
-from sparseml.transformers.sparsification.modification.modification_objects import (
+from sparseml.modifiers.quantization.modification.modification_objects import (
     QuantizableBatchMatmul,
     QuantizableIdentity,
 )
-from sparseml.transformers.sparsification.modification.registry import (
-    ModificationRegistry,
+from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
+from sparseml.pytorch.utils.helpers import swap_modules
+from sparseml.transformers.sparsification.modification.base import (
+    check_transformers_version,
 )
 
 
-_LOGGER = logging.getLogger(__name__)
-
-
 @ModificationRegistry.register(name="OPTModel", alias=["OPTForCausalLM"])
 def modify(model: nn.Module) -> nn.Module:
     """
     Modify the OPT model to be compatible with SparseML
+    quantization
 
-    1. 
Replaces the OPTAttention modules with
-        OPTAttentionWithQuantizableMatmuls modules
-
-    Note: This function will not alter any of the alternatives
-    to the OPTAttention module such as OptFlashAttention2
+    Replaces the OPT attention modules with
+    OPTAttentionWithQuantizableMatmuls modules
 
     :param model: the original OPT model
     :return: the modified OPT model
     """
+    check_transformers_version()
     for name, submodule in model.named_modules():
-        if isinstance(submodule, OPTAttention):
+        if isinstance(submodule, (OPTAttention, OptFlashAttention2)) and not isinstance(
+            submodule, OPTAttentionWithQuantizableMatmuls
+        ):
             swap_modules(model, name, OPTAttentionWithQuantizableMatmuls(submodule))
-        elif isinstance(submodule, OptFlashAttention2):
-            _LOGGER.debug(
-                f"The model contains {submodule.__class__.__name__} "
-                "module, which will not be modified"
-            )
     return model
 
 
@@ -88,15 +81,16 @@ class BMMOutput_PV(QuantizableIdentity):
 
 class OPTAttentionWithQuantizableMatmuls(OPTAttention):
     """
-    Wrapper around the original OPTAttention module to replace the
-    matmul operations with quantizable matmul operations
+    Wrapper around the original attention module that replaces its
+    matmul operations with quantizable matmul operations
 
-    :param opt_attention: the original OPTAttention module
+    :param opt_attention: the original attention module to be
+        wrapped and modified
     """
 
-    def __init__(self, opt_attention: OPTAttention):
+    def __init__(self, opt_attention: Union[OptFlashAttention2, OPTAttention]):
         self.__class__ = type(
-            opt_attention.__class__.__name__,
+            self.__class__.__name__,
             (self.__class__, opt_attention.__class__),
             {},
         )
diff --git a/src/sparseml/transformers/sparsification/sparse_model.py b/src/sparseml/transformers/sparsification/sparse_model.py
index 88f90de65d9..22ea6a21848 100644
--- a/src/sparseml/transformers/sparsification/sparse_model.py
+++ b/src/sparseml/transformers/sparsification/sparse_model.py
@@ -30,6 +30,7 @@
 )
 from transformers.file_utils import WEIGHTS_NAME
 
+from sparseml.modifiers.quantization.modification import modify_model
 from sparseml.pytorch.model_load.helpers import (
     apply_recipe_structure_to_model,
     log_model_load,
@@ -39,7 +40,6 @@
     infer_compressor_from_model_config,
     modify_save_pretrained,
 )
-from sparseml.transformers.sparsification.modification import modify_model
 from sparseml.transformers.utils.helpers import download_model_directory, resolve_recipe
 
 
@@ -59,11 +59,9 @@ class SparseAutoModelForCausalLM(AutoModelForCausalLM):
        of the model will be retrieved
     2. The original model definition will be loaded, without
        the model weights
-    3. The model will be potentially modifier by `modify_model`
-       function, so that is compatible with SparseML
-    4. The appropriate recipy will be applied to the model
+    3. The appropriate recipe will be applied to the model
        if requested or required
-    5. The appropriate set of weights will be loaded into the model
+    4. 
The appropriate set of weights will be loaded into the model """ @classmethod @@ -115,7 +113,6 @@ def skip(*args, **kwargs): pretrained_model_name_or_path, *model_args, **kwargs ) logger.setLevel(level=restore_log_level) - model = modify_model(model) # override the PreTrainedModel instance with compression save function modify_save_pretrained(model) @@ -144,6 +141,8 @@ class SparseAutoModel: Factory class for creating sparse models using transformers AutoModel classes """ + from sparseml.modifiers.quantization.modification import modify_model + @staticmethod def masked_language_modeling_from_pretrained( model_name_or_path: str, diff --git a/tests/sparseml/transformers/sparsification/modification/test_modify_model.py b/tests/sparseml/modifiers/quantization/modification/test_modify_model.py similarity index 91% rename from tests/sparseml/transformers/sparsification/modification/test_modify_model.py rename to tests/sparseml/modifiers/quantization/modification/test_modify_model.py index 1b4da67a553..2bde19a5757 100644 --- a/tests/sparseml/transformers/sparsification/modification/test_modify_model.py +++ b/tests/sparseml/modifiers/quantization/modification/test_modify_model.py @@ -17,10 +17,8 @@ import pytest -from sparseml.transformers.sparsification.modification import modify_model -from sparseml.transformers.sparsification.modification.registry import ( - ModificationRegistry, -) +from sparseml.modifiers.quantization.modification import modify_model +from sparseml.modifiers.quantization.modification.registry import ModificationRegistry from sparsezoo.utils.registry import _ALIAS_REGISTRY, _REGISTRY, standardize_lookup_name @@ -88,7 +86,7 @@ def dummy_modification(model): return model is_modified = copy(model.modified) - monkeypatch.setenv("NM_DISABLE_TRANSFORMERS_MODIFICATION", "1") + monkeypatch.setenv("NM_DISABLE_QUANTIZATION_MODIFICATION", "1") model = modify_model(model) assert model.modified == is_modified == False # noqa E712 monkeypatch.undo() diff --git a/tests/sparseml/transformers/obcq/test_obcq.py b/tests/sparseml/transformers/obcq/test_obcq.py index f61ac1c2567..6f0f0108db2 100644 --- a/tests/sparseml/transformers/obcq/test_obcq.py +++ b/tests/sparseml/transformers/obcq/test_obcq.py @@ -25,6 +25,9 @@ from sparseml.pytorch.model_load.helpers import get_session_model from sparseml.pytorch.utils.helpers import tensor_sparsity from sparseml.transformers import SparseAutoModelForCausalLM, oneshot +from sparseml.transformers.sparsification.modification.modifying_llama import ( + LlamaAttentionWithQuantizableMatmuls, +) @pytest.mark.parametrize( @@ -35,20 +38,36 @@ "tests/sparseml/transformers/obcq/quant_and_sparse.yaml", ], ) -def test_obcq_tinystories(recipe_file_path): +def test_obcq_tinystories(tmp_path, recipe_file_path): tiny_model_path = "Xenova/llama2.c-stories15M" - device = "cuda:0" - if not torch.cuda.is_available(): - device = "cpu" + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model = SparseAutoModelForCausalLM.from_pretrained( + tiny_model_path, device_map=device + ) oneshot( - model=tiny_model_path, + model=model, dataset="open_platypus", oneshot_device=device, recipe=recipe_file_path, max_seq_length=128, num_calibration_samples=64, pad_to_max_length=False, + output_dir=tmp_path / "temp_output", + ) + + is_model_quantized = "quant" in recipe_file_path + # if quantization recipe has been applied to the model, + # assert that the attention modules + # (6 of them for the tested tiny llama model), + # have been swapped for LlamaAttentionWithQuantizableMatmuls 
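+    # (sanity note: for this tiny stories model, config.num_hidden_layers
+    # is 6, which is where the hard-coded count below comes from)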
+    assert is_model_quantized == (
+        sum(
+            module.__class__.__name__
+            == LlamaAttentionWithQuantizableMatmuls.__name__  # noqa E501
+            for module in model.modules()
+        )
+        == 6
     )
 
 
@@ -89,6 +108,8 @@ def test_lm_head_target():
     layers_head = len(sparsegpt_modifier_head.compressible_layers_)
     assert layers_head == layers_no_head + 1
 
+    # check that the lm_head target adds exactly one extra compressible layer
+
 
 
 def test_sparsities():
    tiny_model_path = "Xenova/llama2.c-stories15M"
diff --git a/tests/sparseml/transformers/sparsification/modification/conftest.py b/tests/sparseml/transformers/sparsification/modification/conftest.py
index f5257f1e9b3..d6a9fd1c0ad 100644
--- a/tests/sparseml/transformers/sparsification/modification/conftest.py
+++ b/tests/sparseml/transformers/sparsification/modification/conftest.py
@@ -18,80 +18,81 @@
 from transformers import AutoConfig, AutoModel
 
 from accelerate import init_empty_weights
+from sparseml.modifiers.quantization.modification import modify_model
+from sparseml.pytorch.model_load.helpers import apply_recipe_structure_to_model
 from sparseml.transformers import SparseAutoConfig, SparseAutoModelForCausalLM
-from sparseml.transformers.sparsification.modification import modify_model
 
 
 @pytest.fixture
-def mistral_zoo_model():
-    stub = "zoo:mistral-7b-evolcodealpaca_mistral_pretrain-pruned50_quantized"
-    config = SparseAutoConfig.from_pretrained(stub)
+def bert_model():
+    config = AutoConfig.from_pretrained("bert-base-uncased")
     with init_empty_weights():
-        model = SparseAutoModelForCausalLM.from_config(config)
+        model = AutoModel.from_config(config)
     return model
 
 
 @pytest.fixture
-def opt_zoo_model():
-    stub = "zoo:opt-1.3b-opt_pretrain-quantW8A8"
-    config = SparseAutoConfig.from_pretrained(stub)
+def distilbert_model():
+    config = AutoConfig.from_pretrained("distilbert/distilbert-base-uncased")
     with init_empty_weights():
-        model = SparseAutoModelForCausalLM.from_config(config)
+        model = AutoModel.from_config(config)
     return model
 
 
 @pytest.fixture
-def llama_zoo_model():
-    stub = "zoo:llama2-7b-llama2_chat_llama2_pretrain-base_quantized"
-    config = SparseAutoConfig.from_pretrained(stub)
+def mobilebert_model():
+    config = AutoConfig.from_pretrained("google/mobilebert-uncased")
     with init_empty_weights():
-        model = SparseAutoModelForCausalLM.from_config(config)
+        model = AutoModel.from_config(config)
     return model
 
 
 @pytest.fixture
-def bert_model():
-    config = AutoConfig.from_pretrained("bert-base-uncased")
+def opt_zoo_model():
+    stub = "zoo:opt-1.3b-opt_pretrain-quantW8A8"
+    config = SparseAutoConfig.from_pretrained(stub)
     with init_empty_weights():
-        model = AutoModel.from_config(config)
+        model = SparseAutoModelForCausalLM.from_config(config)
     return model
 
 
 @pytest.fixture
-def distilbert_model():
-    config = AutoConfig.from_pretrained("distilbert/distilbert-base-uncased")
+def opt_model():
+    config = AutoConfig.from_pretrained("facebook/opt-1.3b")
     with init_empty_weights():
         model = AutoModel.from_config(config)
     return model
 
 
 @pytest.fixture
-def mistral_model():
-    config = AutoConfig.from_pretrained("NousResearch/Hermes-2-Pro-Mistral-7B")
+def mistral_zoo_model():
+    stub = "zoo:mistral-7b-evolcodealpaca_mistral_pretrain-pruned50_quantized"
+    config = SparseAutoConfig.from_pretrained(stub)
     with init_empty_weights():
-        model = AutoModel.from_config(config)
+        model = SparseAutoModelForCausalLM.from_config(config)
     return model
 
 
 @pytest.fixture
-def mobilebert_model():
-    config = AutoConfig.from_pretrained("google/mobilebert-uncased")
+def llama_zoo_model():
+    stub = "zoo:llama2-7b-llama2_chat_llama2_pretrain-base_quantized"
+    config = 
SparseAutoConfig.from_pretrained(stub) with init_empty_weights(): - model = AutoModel.from_config(config) + model = SparseAutoModelForCausalLM.from_config(config) return model @pytest.fixture -def llama_model(): - config = AutoConfig.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") +def mistral_model(): + config = AutoConfig.from_pretrained("NousResearch/Hermes-2-Pro-Mistral-7B") with init_empty_weights(): model = AutoModel.from_config(config) return model @pytest.fixture -def opt_model(): - config = AutoConfig.from_pretrained("facebook/opt-1.3b") +def llama_model(): + config = AutoConfig.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") with init_empty_weights(): model = AutoModel.from_config(config) return model @@ -104,54 +105,72 @@ def shared_helper_functions(): class SharedHelperFunctions: @staticmethod - def check_model_modified( - original_model_, module_to_replace, func_to_validate_replacement + def check_model_modified_non_causal( + original_model_, modified_module, num_modified_modules=None ): - num_attn_blocks = original_model_.config.num_hidden_layers + num_attn_blocks = original_model_.config.num_hidden_layers + num_modified_modules = num_modified_modules or num_attn_blocks original_model = deepcopy(original_model_) modified_model = modify_model(original_model_) - modified_modules_original_model = [ - module - for module in original_model.modules() - if func_to_validate_replacement(module) - and isinstance(module, module_to_replace) - ] - - modified_modules_modified_model = [ - module - for module in modified_model.modules() - if func_to_validate_replacement(module) - and isinstance(module, module_to_replace) - ] - - original_modules_original_model = [ - module - for module in original_model.modules() - if not func_to_validate_replacement(module) - and isinstance(module, module_to_replace) - ] - - original_modules_modified_model = [ - module - for module in modified_model.modules() - if not func_to_validate_replacement(module) - and isinstance(module, module_to_replace) - ] - - # make sure that the original model has no modified modules - # and that the modified model has no original modules + # make sure that the original model has 0 modified modules + # and that the modified model has N modified modules + # where N is the number of transformer's attention blocks assert ( - len(modified_modules_original_model) - == len(original_modules_modified_model) + sum( + [ + module.__class__.__name__ == modified_module.__name__ + for module in modified_model.modules() + ] + ) + == num_modified_modules + ) + assert ( + sum( + [ + module.__class__.__name__ == modified_module.__name__ + for module in original_model.modules() + ] + ) == 0 ) - # make sure that the original model has N original modules + + @staticmethod + def check_model_modified_causal( + original_model_, + modified_module, + recipe, + ): + num_attn_blocks = original_model_.config.num_hidden_layers + + original_model = deepcopy(original_model_) + modified_model = original_model_ + + apply_recipe_structure_to_model( + model=modified_model, + model_path=None, + recipe_path=recipe, + ) + + # make sure that the original model has 0 modified modules # and that the modified model has N modified modules # where N is the number of transformer's attention blocks assert ( - len(modified_modules_modified_model) - == len(original_modules_original_model) + sum( + [ + module.__class__.__name__ == modified_module.__name__ + for module in modified_model.modules() + ] + ) == num_attn_blocks ) + assert ( + sum( + [ + 
module.__class__.__name__ == modified_module.__name__ + for module in original_model.modules() + ] + ) + == 0 + ) diff --git a/tests/sparseml/transformers/sparsification/modification/test_modifying_bert.py b/tests/sparseml/transformers/sparsification/modification/test_modifying_bert.py deleted file mode 100644 index edda094fc63..00000000000 --- a/tests/sparseml/transformers/sparsification/modification/test_modifying_bert.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from transformers.models.bert.modeling_bert import BertSelfAttention - - -def test_modifying_bert(bert_model, shared_helper_functions): - - shared_helper_functions.check_model_modified( - bert_model, - module_to_replace=BertSelfAttention, - func_to_validate_replacement=_is_bert_attention_modified, - ) - - -def _is_bert_attention_modified(module): - # only the modified "BertSelfAttention" modules have the - # modules have the "attention_scores_matmul" attribute - return hasattr(module, "attention_scores_matmul") diff --git a/tests/sparseml/transformers/sparsification/modification/test_modifying_distillbert.py b/tests/sparseml/transformers/sparsification/modification/test_modifying_distillbert.py deleted file mode 100644 index 999b0dec938..00000000000 --- a/tests/sparseml/transformers/sparsification/modification/test_modifying_distillbert.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention - - -def test_modifying_distilbert(distilbert_model, shared_helper_functions): - shared_helper_functions.check_model_modified( - distilbert_model, - module_to_replace=MultiHeadSelfAttention, - func_to_validate_replacement=_is_distilbert_attention_modified, - ) - - -def _is_distilbert_attention_modified(module): - # only the modified "MultiHeadSelfAttention" modules have the - # modules have the "attention_scores_matmul" attribute - return hasattr(module, "attention_scores_matmul") diff --git a/tests/sparseml/transformers/sparsification/modification/test_modifying_llama.py b/tests/sparseml/transformers/sparsification/modification/test_modifying_llama.py index 7a20145df62..9091d28b29e 100644 --- a/tests/sparseml/transformers/sparsification/modification/test_modifying_llama.py +++ b/tests/sparseml/transformers/sparsification/modification/test_modifying_llama.py @@ -13,10 +13,10 @@ # limitations under the License. import pytest -from transformers.models.llama.modeling_llama import LlamaAttention -from sparseml.pytorch.model_load.helpers import apply_recipe_structure_to_model -from sparseml.transformers.sparsification.modification import modify_model +from sparseml.transformers.sparsification.modification.modifying_llama import ( + LlamaAttentionWithQuantizableMatmuls, +) @pytest.fixture @@ -39,31 +39,22 @@ def llama_recipe(): symmetric: False""" -def test_modifying_llama(llama_model, shared_helper_functions): - - shared_helper_functions.check_model_modified( +def test_modify_with_quantization_recipe( + llama_model, llama_recipe, shared_helper_functions +): + shared_helper_functions.check_model_modified_causal( llama_model, - module_to_replace=LlamaAttention, - func_to_validate_replacement=_is_llama_attention_modified, + recipe=llama_recipe, + modified_module=LlamaAttentionWithQuantizableMatmuls, ) -def test_apply_recipe_fail(llama_recipe, llama_zoo_model): - - with pytest.raises(Exception): - apply_recipe_structure_to_model( - model=llama_zoo_model, model_path=None, recipe_path=llama_recipe - ) - - -def test_apply_recipe(llama_recipe, llama_zoo_model): - apply_recipe_structure_to_model( - model=modify_model(llama_zoo_model), model_path=None, recipe_path=llama_recipe +def test_modify_with_quantization_recipe_sparsezoo( + llama_zoo_model, llama_recipe, shared_helper_functions +): + # TODO: Improve that + shared_helper_functions.check_model_modified_causal( + llama_zoo_model, + recipe=llama_recipe, + modified_module=LlamaAttentionWithQuantizableMatmuls, ) - assert True - - -def _is_llama_attention_modified(module): - # only the modified "LlamaAttention" - # modules have the "attn_output_matmul" attribute - return hasattr(module, "attn_output_matmul") diff --git a/tests/sparseml/transformers/sparsification/modification/test_modifying_mistral.py b/tests/sparseml/transformers/sparsification/modification/test_modifying_mistral.py index 7d241fc3de1..e71364a53e7 100644 --- a/tests/sparseml/transformers/sparsification/modification/test_modifying_mistral.py +++ b/tests/sparseml/transformers/sparsification/modification/test_modifying_mistral.py @@ -13,10 +13,10 @@ # limitations under the License. 
import pytest -from transformers.models.mistral.modeling_mistral import MistralAttention -from sparseml.pytorch.model_load.helpers import apply_recipe_structure_to_model -from sparseml.transformers.sparsification.modification import modify_model +from sparseml.transformers.sparsification.modification.modifying_mistral import ( + MistralAttentionWithQuantizableMatmuls, +) @pytest.fixture @@ -37,32 +37,21 @@ def mistral_recipe(): symmetric: False""" -def test_modifying_mistral(mistral_model, shared_helper_functions): - shared_helper_functions.check_model_modified( +def test_modify_with_quantization_recipe( + mistral_model, mistral_recipe, shared_helper_functions +): + shared_helper_functions.check_model_modified_causal( mistral_model, - module_to_replace=MistralAttention, - func_to_validate_replacement=_is_mistral_attention_modified, + recipe=mistral_recipe, + modified_module=MistralAttentionWithQuantizableMatmuls, ) -def test_apply_recipe_fail(mistral_recipe, mistral_zoo_model): - with pytest.raises(Exception): - apply_recipe_structure_to_model( - model=mistral_zoo_model, model_path=None, recipe_path=mistral_recipe - ) - - -def test_apply_recipe(mistral_recipe, mistral_zoo_model): - - apply_recipe_structure_to_model( - model=modify_model(mistral_zoo_model), - model_path=None, - recipe_path=mistral_recipe, +def test_modify_with_quantization_recipe_sparsezoo( + mistral_zoo_model, mistral_recipe, shared_helper_functions +): + shared_helper_functions.check_model_modified_causal( + mistral_zoo_model, + recipe=mistral_recipe, + modified_module=MistralAttentionWithQuantizableMatmuls, ) - assert True - - -def _is_mistral_attention_modified(module): - # only the modified "MistralAttention" - # modules have the "attn_output_matmul" attribute - return hasattr(module, "attn_output_matmul") diff --git a/tests/sparseml/transformers/sparsification/modification/test_modifying_non_causal.py b/tests/sparseml/transformers/sparsification/modification/test_modifying_non_causal.py new file mode 100644 index 00000000000..003f036be5b --- /dev/null +++ b/tests/sparseml/transformers/sparsification/modification/test_modifying_non_causal.py @@ -0,0 +1,43 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
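+
+# NOTE: the model fixtures (bert_model, distilbert_model, mobilebert_model)
+# and the shared_helper_functions fixture used below are defined in the
+# package-level conftest.py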
+ +from sparseml.transformers.sparsification.modification.modifying_bert import ( + BertSelfAttentionWithQuantizableMatmuls, +) +from sparseml.transformers.sparsification.modification.modifying_distilbert import ( + MultiHeadSelfAttentionWithQuantizableMatmuls, +) +from sparseml.transformers.sparsification.modification.modifying_mobilebert import ( + MobileBertEmbeddingsWithQuantizableLinear, +) + + +def test_modify_distilbert(distilbert_model, shared_helper_functions): + shared_helper_functions.check_model_modified_non_causal( + distilbert_model, modified_module=MultiHeadSelfAttentionWithQuantizableMatmuls + ) + + +def test_modify_bert(bert_model, shared_helper_functions): + shared_helper_functions.check_model_modified_non_causal( + bert_model, modified_module=BertSelfAttentionWithQuantizableMatmuls + ) + + +def test_modify_mobilebert(mobilebert_model, shared_helper_functions): + shared_helper_functions.check_model_modified_non_causal( + mobilebert_model, + modified_module=MobileBertEmbeddingsWithQuantizableLinear, + num_modified_modules=1, + ) diff --git a/tests/sparseml/transformers/sparsification/modification/test_modifying_opt.py b/tests/sparseml/transformers/sparsification/modification/test_modifying_opt.py index 6870665f6b6..411371b0bbf 100644 --- a/tests/sparseml/transformers/sparsification/modification/test_modifying_opt.py +++ b/tests/sparseml/transformers/sparsification/modification/test_modifying_opt.py @@ -14,10 +14,10 @@ import pytest -from transformers.models.opt.modeling_opt import OPTAttention -from sparseml.pytorch.model_load.helpers import apply_recipe_structure_to_model -from sparseml.transformers.sparsification.modification import modify_model +from sparseml.transformers.sparsification.modification.modifying_opt import ( + OPTAttentionWithQuantizableMatmuls, +) @pytest.fixture @@ -40,32 +40,21 @@ def opt_recipe(): symmetric: False""" -def test_modifying_opt(opt_model, shared_helper_functions): - - shared_helper_functions.check_model_modified( +def test_modify_with_quantization_recipe( + opt_model, opt_recipe, shared_helper_functions +): + shared_helper_functions.check_model_modified_causal( opt_model, - module_to_replace=OPTAttention, - func_to_validate_replacement=_is_opt_attention_modified, + recipe=opt_recipe, + modified_module=OPTAttentionWithQuantizableMatmuls, ) -def test_apply_recipe_fail(opt_recipe, opt_zoo_model): - - with pytest.raises(Exception): - apply_recipe_structure_to_model( - model=opt_zoo_model, model_path=None, recipe_path=opt_recipe - ) - - -def test_apply_recipe(opt_recipe, opt_zoo_model): - - apply_recipe_structure_to_model( - model=modify_model(opt_zoo_model), model_path=None, recipe_path=opt_recipe +def test_modify_with_quantization_recipe_sparsezoo( + opt_zoo_model, opt_recipe, shared_helper_functions +): + shared_helper_functions.check_model_modified_causal( + opt_zoo_model, + recipe=opt_recipe, + modified_module=OPTAttentionWithQuantizableMatmuls, ) - assert True - - -def _is_opt_attention_modified(module): - # only the modified "OPTAttention" - # modules have the "attn_output_bmm" attribute - return hasattr(module, "attn_output_bmm") From a7315e48d3848c467c41b39d7678e443b6a3304e Mon Sep 17 00:00:00 2001 From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com> Date: Wed, 1 May 2024 17:49:50 +0200 Subject: [PATCH 7/8] SparseML dependency on `compressed-tensors` (#2229) * initial commit * update setup.py * Update setup.py * fix setup.py * move all config to sparsetensors * cleanup class name and comments * refactor to 
compressed-tensors * all tests passing * remove unused config * bring back SparsityConfigMetadata * Update setup.py Co-authored-by: Rahul Tuli * update setup * fix import problem * fix clearml test * compressed-tensors are transformers dep * update compatibility to 0.3.2 --------- Co-authored-by: Sara Adkins Co-authored-by: Rahul Tuli --- setup.py | 1 + .../transformers/compression/README.md | 162 ------------ .../transformers/compression/__init__.py | 5 - .../compression/compressors/__init__.py | 19 -- .../compression/compressors/base.py | 83 ------ .../compression/compressors/dense.py | 32 --- .../compression/compressors/sparse_bitmask.py | 237 ------------------ .../compression/config/__init__.py | 19 -- .../transformers/compression/config/dense.py | 36 --- .../compression/config/sparse_bitmask.py | 36 --- .../{config/base.py => sparsity_config.py} | 34 +-- .../compression/utils/__init__.py | 19 -- .../transformers/compression/utils/helpers.py | 46 ---- .../compression/utils/safetensors_load.py | 196 --------------- .../compressed_tensors_utils.py} | 17 +- .../sparsification/sparse_model.py | 7 +- src/sparseml/transformers/utils/helpers.py | 2 +- .../transformers/compression/test_bitmask.py | 122 --------- .../compression/test_registries.py | 50 ---- .../test_compress_tensor_utils.py} | 17 +- 20 files changed, 37 insertions(+), 1103 deletions(-) delete mode 100644 src/sparseml/transformers/compression/README.md delete mode 100644 src/sparseml/transformers/compression/compressors/__init__.py delete mode 100644 src/sparseml/transformers/compression/compressors/base.py delete mode 100644 src/sparseml/transformers/compression/compressors/dense.py delete mode 100644 src/sparseml/transformers/compression/compressors/sparse_bitmask.py delete mode 100644 src/sparseml/transformers/compression/config/__init__.py delete mode 100644 src/sparseml/transformers/compression/config/dense.py delete mode 100644 src/sparseml/transformers/compression/config/sparse_bitmask.py rename src/sparseml/transformers/compression/{config/base.py => sparsity_config.py} (80%) delete mode 100644 src/sparseml/transformers/compression/utils/__init__.py delete mode 100644 src/sparseml/transformers/compression/utils/helpers.py delete mode 100644 src/sparseml/transformers/compression/utils/safetensors_load.py rename src/sparseml/transformers/{compression/utils/compress_save.py => sparsification/compressed_tensors_utils.py} (90%) delete mode 100644 tests/sparseml/transformers/compression/test_bitmask.py delete mode 100644 tests/sparseml/transformers/compression/test_registries.py rename tests/sparseml/transformers/{compression/test_sparse_auto.py => sparsification/test_compress_tensor_utils.py} (88%) diff --git a/setup.py b/setup.py index 65bc1738a99..281aa1d9ded 100644 --- a/setup.py +++ b/setup.py @@ -87,6 +87,7 @@ "evaluate>=0.4.1", "accelerate>=0.20.3", "safetensors>=0.4.1", + "compressed-tensors", ] _llm_deps = _transformers_deps + ["sentencepiece"] _yolov5_deps = _pytorch_vision_deps + [ diff --git a/src/sparseml/transformers/compression/README.md b/src/sparseml/transformers/compression/README.md deleted file mode 100644 index 51d49adecc5..00000000000 --- a/src/sparseml/transformers/compression/README.md +++ /dev/null @@ -1,162 +0,0 @@ -# Save/Load Compressed SafeTensors - -## Motivation - -* Reduce disk space by saving in a compressed format for sparse models. 
Models in this compressed format will be loaded by vLLM for more efficient inference -* Set up the save/load architecture such that we can easily expand to additional compression formats in the future. The config should be human readable so users can understand the compression format at a quick glance - -## SafeTensors File Format - -For each parameter in the uncompressed state_dict, we store the following attributes -needed for decompression in the compressed state_dict: - -* compressed tensor -* bitmask -* uncompressed shape -* row offsets - -```python -# dense -{ - PARAM_NAME: uncompressed_tensor -} - -# compressed -{ - PARAM_NAME.compressed: compressed_tensor # 1d tensor - PARAM_NAME.bitmask: value # 2d bitmask tensor (nrows x (ncols / 8)) - PARAM_NAME.shape: value # uncompressed shape tensor - PARAM_NAME.row_offsets: value # 1d offsets tensor -} -``` - -Config information gets stored in the HF config file -```json -// config.json -{ - "sparsity_config": { - "format": "sparse_bitmask", // "dense_sparsity" for original tensor format - - // informational - "sparsity_structure": "unstructured", // or 2:4, 8:16 etc... - "global_sparsity": "0.5" - } -} -``` - -## Saving/Loading Interface - -Loading in a compressed model requires no interface changes - -```python -from sparseml.transformers.utils import SparseAutoModelForCausalLM - -# should contain model.safetensors or model.safetensors.index.json -model_path = "/PATH/TO/COMPRESSED_MODEL" - -model = SparseAutoModelForCausalLM.from_pretrained( - model_name_or_path=model_path, - **model_kwargs, -) -``` - -Saving a compressed model with an explicitly provided compression config. The config -is saved to the model's `config.json` file. **Note:** the model must have been -initialized with SparseAutoModelForCausalLM.from_pretrained() - -```python -from sparseml.transformers.compression import BitmaskConfig - -output_dir = "/PATH/TO/SAVE/COMPRESSED_MODEL" -sparsity_config = BitmaskConfig() - -model.save_pretrained( - save_directory=output_dir, - sparsity_config=sparsity_config, -) -``` - -Saving a compressed model, inferring the config from the model attributes - -```python -model.save_pretrained( - save_directory=output_dir, - save_compressed=True -) -``` - -Saving a model in the dense format. If the model has at least 5% global sparsity a -sparsity config will still be included in `config.json` with format `dense_sparsity` - -```python -model.save_pretrained( - save_directory=output_dir -) -``` - -Saving a model in the dense format, bypassing the sparsity config calculation. When the -`skip_compression_stats` flag is set, no sparsity config will be written to -`config.json` - -```python -model.save_pretrained( - save_directory=output_dir - skip_compression_stats=True -) -``` - -## Enable Compression During One-Shot and Sparse Finetunining -Models that are saved in a supported compressed format on disk will automatically be -decompressed when loaded as input to `sparseml.transformers.oneshot` or -`sparseml.transformers.train` - -To enable compression on save after oneshot or finetuning simply add the -`save_compressed=True` argument to `sparseml.transformers.oneshot` or -`sparseml.transformers.train` - -```python -from sparseml.transformers import train - -train( - save_compressed=True, - model="neuralmagic/TinyLlama-1.1B-Chat-v1.0-pruned2.4", - recipe=RECIPE, - dataset=DATASET -) -``` - - -## Example Code - -Loads a 60% sparse model, compresses it using the inferred bitmask compression, then -reloads the compressed model. 
- -```python -from sparseml.transformers import SparseAutoModelForCausalLM -from sparseml.utils.pytorch.utils import measure_cuda_memory -import torch - -MODEL_PATH = "zoo:llama2-7b-open_platypus_orca_llama2_pretrain-pruned60" -OUTPUT_PATH = "./test_compress_output" -RECIPE = "zoo:llama2-7b-open_platypus_orca_llama2_pretrain-pruned60" - -torch.cuda.set_device(0) -with measure_cuda_memory() as m: - model = SparseAutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="cuda:0") -print(f"Load dense model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB") - -sparsity_config = getattr(model,"sparsity_config", None) -print(f"Sparsity config before compression: {sparsity_config}") -with measure_cuda_memory() as m: - model.save_pretrained(OUTPUT_PATH, save_compressed=True) -print(f"Save compressed model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB") - -torch.cuda.set_device(1) -with measure_cuda_memory() as m: - model_again = SparseAutoModelForCausalLM.from_pretrained( - OUTPUT_PATH, device_map="cuda:1" - ) -print(f"Load compressed model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB") -sparsity_config = getattr(model_again,"sparsity_config", None) -print(f"Sparsity config after compression: {sparsity_config}") -``` diff --git a/src/sparseml/transformers/compression/__init__.py b/src/sparseml/transformers/compression/__init__.py index ca37b25df52..0c44f887a47 100644 --- a/src/sparseml/transformers/compression/__init__.py +++ b/src/sparseml/transformers/compression/__init__.py @@ -11,8 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# flake8: noqa - -from .compressors import * -from .config import * diff --git a/src/sparseml/transformers/compression/compressors/__init__.py b/src/sparseml/transformers/compression/compressors/__init__.py deleted file mode 100644 index e8a36527c04..00000000000 --- a/src/sparseml/transformers/compression/compressors/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# flake8: noqa - -from .base import ModelCompressor -from .dense import DenseCompressor -from .sparse_bitmask import BitmaskCompressor diff --git a/src/sparseml/transformers/compression/compressors/base.py b/src/sparseml/transformers/compression/compressors/base.py deleted file mode 100644 index 2a1a37d9196..00000000000 --- a/src/sparseml/transformers/compression/compressors/base.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import operator -from typing import Dict, Generator, Tuple - -from torch import Tensor -from torch.nn import Module, Parameter -from tqdm import tqdm - -from sparseml.transformers.compression.config import CompressionConfig -from sparseml.transformers.utils.helpers import SPARSITY_CONFIG_NAME -from sparseml.utils.pytorch.module import set_layer -from sparsezoo.utils.registry import RegistryMixin - - -__all__ = ["ModelCompressor"] - - -class ModelCompressor(RegistryMixin): - """ - Base class representing a model compression algorithm. - - :param config: config specifying compression parameters - """ - - def __init__(self, config: CompressionConfig): - self.config = config - - def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]: - """ - Compresses a dense state dict - - :param model_state: state dict of uncompressed model - :return: compressed state dict - """ - raise NotImplementedError() - - def decompress(self, model_path: str) -> Generator[Tuple[str, Tensor], None, None]: - """ - Reads a compressed state dict located at model_path and returns a - generator for sequentially decompressing back to a dense state dict - - :param model_path: path to compressed safetensors model - :return: compressed state dict - """ - raise NotImplementedError() - - @staticmethod - def replace_layer(param_name: str, data: Tensor, model: Module): - """ - Overwrites a parameterized layer with a new tensor, maintaining the device of - the original parameter - - :param param_name: name of parameterized layer to replace - :param data: tensor to insert into model - :param model: pytorch model to insert data into - """ - model_device = operator.attrgetter(param_name)(model).device - set_layer(param_name, Parameter(data.to(model_device)), model) - - def overwrite_weights(self, model_path: str, model: Module): - """ - Overwrites the weights in model with weights decompressed from model_path - - :param model_path: path to compressed weights - :param model: pytorch model to load decompressed weights into - """ - dense_gen = self.decompress(model_path) - for name, data in tqdm(dense_gen, desc="Decompressing model"): - ModelCompressor.replace_layer(name, data, model) - setattr(model, SPARSITY_CONFIG_NAME, self.config) diff --git a/src/sparseml/transformers/compression/compressors/dense.py b/src/sparseml/transformers/compression/compressors/dense.py deleted file mode 100644 index e40ea92e6c6..00000000000 --- a/src/sparseml/transformers/compression/compressors/dense.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, Generator, Tuple - -from torch import Tensor - -from sparseml.transformers.compression.compressors import ModelCompressor - - -@ModelCompressor.register(name="dense_sparsity") -class DenseCompressor(ModelCompressor): - """ - Identity compressor for dense models, returns the original state_dict - """ - - def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]: - return model_state - - def decompress(self, model_path: str) -> Generator[Tuple[str, Tensor], None, None]: - return iter([]) diff --git a/src/sparseml/transformers/compression/compressors/sparse_bitmask.py b/src/sparseml/transformers/compression/compressors/sparse_bitmask.py deleted file mode 100644 index 1c6f35c7171..00000000000 --- a/src/sparseml/transformers/compression/compressors/sparse_bitmask.py +++ /dev/null @@ -1,237 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import Dict, Generator, List, Tuple, Union - -import numpy -import torch -from torch import Tensor -from tqdm import tqdm - -from safetensors import safe_open -from sparseml.transformers.compression.compressors import ModelCompressor -from sparseml.transformers.compression.utils import ( - get_nested_weight_mappings, - merge_names, -) - - -__all__ = [ - "BitmaskCompressor", - "BitmaskTensor", - "bitmask_compress", - "bitmask_decompress", - "pack_bitmasks", - "unpack_bitmasks", -] - -_LOGGER: logging.Logger = logging.getLogger(__name__) - - -@ModelCompressor.register(name="sparse_bitmask") -class BitmaskCompressor(ModelCompressor): - """ - Compression for sparse models using bitmasks. Non-zero weights are stored in a 1d - values tensor, with their locations stored in a 2d bitmask - """ - - COMPRESSION_PARAM_NAMES = ["shape", "compressed", "bitmask", "row_offsets"] - - def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]: - """ - Compresses a dense state dict using bitmask compression - - :param model_state: state dict of uncompressed model - :return: compressed state dict - """ - compressed_dict = {} - _LOGGER.debug( - f"Compressing model with {len(model_state)} parameterized layers..." - ) - for name, value in tqdm(model_state.items(), desc="Compressing model"): - bitmask_tensor = BitmaskTensor.from_dense(value) - bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu") - for key in bitmask_dict.keys(): - if key in compressed_dict: - _LOGGER.warn( - f"Expected all compressed state_dict keys to be unique, but " - f"found an existing entry for {key}. The existing entry will " - "be replaced." 
- ) - compressed_dict.update(bitmask_dict) - - return compressed_dict - - def decompress(self, model_path: str) -> Generator[Tuple[str, Tensor], None, None]: - """ - Reads a bitmask compressed state dict located at model_path and returns a - generator for sequentially decompressing back to a dense state dict - - :param model_path: path to compressed safetensors model - :return: iterator for generating decompressed weights - """ - weight_mappings = get_nested_weight_mappings( - model_path, self.COMPRESSION_PARAM_NAMES - ) - for weight_name in weight_mappings.keys(): - weight_data = {} - for param_name, safe_path in weight_mappings[weight_name].items(): - full_name = merge_names(weight_name, param_name) - with safe_open(safe_path, framework="pt", device="cpu") as f: - weight_data[param_name] = f.get_tensor(full_name) - data = BitmaskTensor(**weight_data) - decompressed = data.decompress() - yield weight_name, decompressed - - -class BitmaskTensor: - """ - Owns compressions and decompression for a single bitmask compressed tensor. - Adapted from: https://github.com/mgoin/torch_bitmask/tree/main - - :param shape: shape of dense tensor - :compressed: flat tensor of non-zero values - :bitmask: 2d bitmask of non-zero values - :row_offsets: flat tensor indicating what index in values each dense row starts at - """ - - def __init__( - self, - shape: Union[torch.Size, List], - compressed: Tensor, - bitmask: Tensor, - row_offsets: Tensor, - ): - self.shape = list(shape) - self.compressed = compressed - self.bitmask = bitmask - self.row_offsets = row_offsets - - @staticmethod - def from_dense(tensor: Tensor) -> "BitmaskTensor": - """ - :param tensor: dense tensor to compress - :return: instantiated compressed tensor - """ - shape = tensor.shape - compressed, bitmask, row_offsets = bitmask_compress(tensor.cpu()) - return BitmaskTensor( - shape=shape, compressed=compressed, bitmask=bitmask, row_offsets=row_offsets - ) - - def decompress(self) -> Tensor: - """ - :return: reconstructed dense tensor - """ - return bitmask_decompress(self.compressed, self.bitmask, self.shape) - - def curr_memory_size_bytes(self): - """ - :return: size in bytes required to store compressed tensor on disk - """ - - def sizeof_tensor(a): - return a.element_size() * a.nelement() - - return ( - sizeof_tensor(self.compressed) - + sizeof_tensor(self.bitmask) - + sizeof_tensor(self.row_offsets) - ) - - def dict(self, name_prefix: str, device: str = "cpu") -> Dict[str, Tensor]: - """ - :name_prefix: name of original tensor to store compressed weight as - :return: dict of compressed data for the stored weight - """ - return { - merge_names(name_prefix, "shape"): torch.tensor(self.shape, device=device), - merge_names(name_prefix, "compressed"): self.compressed.to(device), - merge_names(name_prefix, "bitmask"): self.bitmask.to(device), - merge_names(name_prefix, "row_offsets"): self.row_offsets.to(device), - } - - def __repr__(self): - return f"BitmaskTensor(shape={self.shape}, compressed=True)" - - -def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]: - """ - Compresses a dense tensor using bitmask compression - - :param tensor: dense tensor to compress - :return: tuple of compressed data representing tensor - """ - bytemasks = tensor != 0 - row_counts = bytemasks.sum(dim=-1) - row_offsets = torch.cumsum(row_counts, 0) - row_counts - values = tensor[bytemasks] - bitmasks_packed = pack_bitmasks(bytemasks) - - return values, bitmasks_packed, row_offsets - - -def bitmask_decompress( - values: Tensor, bitmasks: Tensor, 
original_shape: torch.Size -) -> Tensor: - """ - Reconstructs a dense tensor from a compressed one - - :param values: 1d tensor of non-zero values - :param bitmasks: 2d int8 tensor flagging locations of non-zero values in the - tensors original shape - :param original_shape: shape of the dense tensor - :return: decompressed dense tensor - """ - bytemasks_unpacked = unpack_bitmasks(bitmasks, original_shape) - - decompressed_tensor = torch.zeros(original_shape, dtype=values.dtype) - decompressed_tensor[bytemasks_unpacked] = values - - return decompressed_tensor - - -def pack_bitmasks(bytemasks: Tensor) -> Tensor: - """ - Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be - compressed to R x ceil(C/8) - :param bytemasks: mask tensor where each byte corresponds to a weight - :return: mask tensor where each bit corresounds to a weight - """ - packed_bits_numpy = numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little") - packed_bits_torch = torch.from_numpy(packed_bits_numpy) - - return packed_bits_torch - - -def unpack_bitmasks(packed_bitmasks: Tensor, original_shape: torch.Size) -> Tensor: - """ - Converts a bitmask tensor back to a bytemask tensor for use during decompression - - :param packed_bitmasks: mask tensor where each bit corresponds to a weight - :param original_shape: dense shape to decompress to - :return: boolean mask of weights in the original dense shape - """ - # Unpack the bits - unpacked_bits = numpy.unpackbits( - packed_bitmasks.numpy(), axis=-1, count=original_shape[-1], bitorder="little" - ) - - # Reshape to match the original shape - unpacked_bitmasks_torch = torch.from_numpy( - unpacked_bits.reshape(original_shape).astype(bool) - ) - - return unpacked_bitmasks_torch diff --git a/src/sparseml/transformers/compression/config/__init__.py b/src/sparseml/transformers/compression/config/__init__.py deleted file mode 100644 index 6465c3c6d1b..00000000000 --- a/src/sparseml/transformers/compression/config/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# flake8: noqa - -from .base import CompressionConfig -from .dense import DenseSparsityConfig -from .sparse_bitmask import BitmaskConfig diff --git a/src/sparseml/transformers/compression/config/dense.py b/src/sparseml/transformers/compression/config/dense.py deleted file mode 100644 index e9903c4fdb2..00000000000 --- a/src/sparseml/transformers/compression/config/dense.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -from sparseml.transformers.compression.config import CompressionConfig - - -__all__ = ["DenseSparsityConfig"] - - -@CompressionConfig.register(name="dense_sparsity") -class DenseSparsityConfig(CompressionConfig): - """ - Identity configuration for storing a sparse model in - an uncompressed dense format - - :param global_sparsity: average sparsity of the entire model - :param sparsity_structure: structure of the sparsity, such as - "unstructured", "2:4", "8:16" etc - """ - - format: str = "dense_sparsity" - global_sparsity: Optional[float] = 0.0 - sparsity_structure: Optional[str] = "unstructured" diff --git a/src/sparseml/transformers/compression/config/sparse_bitmask.py b/src/sparseml/transformers/compression/config/sparse_bitmask.py deleted file mode 100644 index dfd71711104..00000000000 --- a/src/sparseml/transformers/compression/config/sparse_bitmask.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Optional - -from sparseml.transformers.compression.config import CompressionConfig - - -__all__ = ["BitmaskConfig"] - - -@CompressionConfig.register(name="sparse_bitmask") -class BitmaskConfig(CompressionConfig): - """ - Configuration for storing a sparse model using - bitmask compression - - :param global_sparsity: average sparsity of the entire model - :param sparsity_structure: structure of the sparsity, such as - "unstructured", "2:4", "8:16" etc - """ - - format: str = "sparse_bitmask" - global_sparsity: Optional[float] = 0.0 - sparsity_structure: Optional[str] = "unstructured" diff --git a/src/sparseml/transformers/compression/config/base.py b/src/sparseml/transformers/compression/sparsity_config.py similarity index 80% rename from src/sparseml/transformers/compression/config/base.py rename to src/sparseml/transformers/compression/sparsity_config.py index 071a8718f5a..b04edf333c3 100644 --- a/src/sparseml/transformers/compression/config/base.py +++ b/src/sparseml/transformers/compression/sparsity_config.py @@ -14,32 +14,20 @@ from typing import Dict, Optional -from pydantic import BaseModel from torch import Tensor from torch.nn import Module import sparseml.core.session as session_manager +from compressed_tensors import CompressionConfig from sparseml.pytorch.utils import ModuleSparsificationInfo -from sparsezoo.utils.registry import RegistryMixin -__all__ = ["CompressionConfig"] - - -class CompressionConfig(RegistryMixin, BaseModel): +class SparsityConfigMetadata: """ - Base data class for storing compression parameters - - :param format: name of compression format - :param global_sparsity: average sparsity of the entire model - :param sparsity_structure: structure of the sparsity, such as - "unstructured", "2:4", "8:16" etc + Class of helper functions for filling out a CompressionConfig with readable + metadata from the model """ - format: str - global_sparsity: Optional[float] = 0.0 - sparsity_structure: Optional[str] = "unstructured" - @staticmethod def infer_global_sparsity( model: Module, state_dict: Optional[Dict[str, Tensor]] = None @@ -95,14 +83,14 @@ def infer_config_from_model( :return: compression config inferred from the model """ - global_sparsity = CompressionConfig.infer_global_sparsity( + global_sparsity = SparsityConfigMetadata.infer_global_sparsity( model, state_dict=state_dict ) if global_sparsity < 0.05: return None - sparsity_structure = CompressionConfig.infer_sparsity_structure() + sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure() if compress: format = "sparse_bitmask" else: @@ -114,17 +102,21 @@ def infer_config_from_model( sparsity_structure=sparsity_structure, ) + @staticmethod def fill_config_details( - self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None + config: CompressionConfig, + model: Module, + state_dict: Optional[Dict[str, Tensor]] = None, ): """ Fills in informational sparsity parameters from a given model + :param config: sparsity config to fill in :param model: pytorch model to infer config parameters from :param state_dict: optional state_dict to replace that in model, used for gathering global FSDP model info """ - self.global_sparsity = CompressionConfig.infer_global_sparsity( + config.global_sparsity = SparsityConfigMetadata.infer_global_sparsity( model, state_dict=state_dict ) - self.sparsity_structure = CompressionConfig.infer_sparsity_structure() + config.sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure() diff --git 
a/src/sparseml/transformers/compression/utils/__init__.py b/src/sparseml/transformers/compression/utils/__init__.py deleted file mode 100644 index 560435126ad..00000000000 --- a/src/sparseml/transformers/compression/utils/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# flake8: noqa - -from .compress_save import * -from .helpers import * -from .safetensors_load import * diff --git a/src/sparseml/transformers/compression/utils/helpers.py b/src/sparseml/transformers/compression/utils/helpers.py deleted file mode 100644 index 4d96fa66cf3..00000000000 --- a/src/sparseml/transformers/compression/utils/helpers.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Optional - -from transformers import AutoConfig - -from sparseml.transformers.compression.compressors import ModelCompressor -from sparseml.transformers.compression.config import CompressionConfig -from sparseml.transformers.utils.helpers import SPARSITY_CONFIG_NAME - - -__all__ = ["infer_compressor_from_model_config"] - - -def infer_compressor_from_model_config( - pretrained_model_name_or_path: str, -) -> Optional[ModelCompressor]: - """ - Given a path to a model config, extract a sparsity config if it exists and return - the associated ModelCompressor - - :param pretrained_model_name_or_path: path to model config on disk or HF hub - :return: matching compressor if config contains a sparsity config - """ - config = AutoConfig.from_pretrained(pretrained_model_name_or_path) - sparsity_config = getattr(config, SPARSITY_CONFIG_NAME, None) - if sparsity_config is None: - return None - - format = sparsity_config.get("format") - sparsity_config = CompressionConfig.load_from_registry(format, **sparsity_config) - compressor = ModelCompressor.load_from_registry(format, config=sparsity_config) - return compressor diff --git a/src/sparseml/transformers/compression/utils/safetensors_load.py b/src/sparseml/transformers/compression/utils/safetensors_load.py deleted file mode 100644 index 4d71482a8e9..00000000000 --- a/src/sparseml/transformers/compression/utils/safetensors_load.py +++ /dev/null @@ -1,196 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. 
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import json
-import os
-import re
-import struct
-from typing import Dict, List, Optional
-
-from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, cached_file
-
-
-__all__ = [
-    "get_safetensors_folder",
-    "get_safetensors_header",
-    "match_param_name",
-    "merge_names",
-    "get_weight_mappings",
-    "get_nested_weight_mappings",
-]
-
-
-def get_safetensors_folder(
-    pretrained_model_name_or_path: str, cache_dir: Optional[str] = None
-) -> str:
-    """
-    Given a Hugging Face stub or a local path, return the folder containing the
-    safetensors weight files
-
-    :param pretrained_model_name_or_path: local path to model or HF stub
-    :param cache_dir: optional cache dir to search through, if none is specified the
-    model will be searched for in the default TRANSFORMERS_CACHE
-    :return: local folder containing model data
-    """
-    if os.path.exists(pretrained_model_name_or_path):
-        # argument is a path to a local folder
-        return pretrained_model_name_or_path
-
-    safetensors_path = cached_file(
-        pretrained_model_name_or_path,
-        SAFE_WEIGHTS_NAME,
-        cache_dir=cache_dir,
-        _raise_exceptions_for_missing_entries=False,
-    )
-    index_path = cached_file(
-        pretrained_model_name_or_path,
-        SAFE_WEIGHTS_INDEX_NAME,
-        cache_dir=cache_dir,
-        _raise_exceptions_for_missing_entries=False,
-    )
-    if safetensors_path is not None:
-        # found a single cached safetensors file
-        return os.path.split(safetensors_path)[0]
-    if index_path is not None:
-        # found a cached safetensors weight index file
-        return os.path.split(index_path)[0]
-
-    # model weights could not be found locally or cached from HF Hub
-    raise ValueError(
-        "Could not locate safetensors weight or index file from "
-        f"{pretrained_model_name_or_path}."
-    )
-
-
-def get_safetensors_header(safetensors_path: str) -> Dict[str, str]:
-    """
-    Extracts the metadata from a safetensors file as JSON
-
-    :param safetensors_path: path to a safetensors file
-    :return: dictionary of metadata extracted from the safetensors file
-    """
-    with open(safetensors_path, "rb") as f:
-        length_of_header = struct.unpack("<Q", f.read(8))[0]
-        header_data = f.read(length_of_header)
-        header = json.loads(header_data)
-
-    return header
-
-
-def match_param_name(full_name: str, param_name: str) -> str:
-    """
-    Helper function extracting the uncompressed parameterized layer name from a
-    compressed name. Assumes the compressed name was merged using merge_names.
-
-    :param full_name: full name of parameter in compressed model
-    :param param_name: compression paramater name
-    :return: uncompressed name of the uncompressed parameterized layer
-    """
-    pattern = r"^(.*)\." + param_name + r"$"
-    regex = re.findall(pattern, full_name)
-    if len(regex) == 0:
-        return None
-    return regex[0]
-
-
-def merge_names(parent_name: str, child_name: str) -> str:
-    """
-    Helper function for merging an uncompressed parameterized layer name with a
-    compression parameter. Names merged with this function can then be parsed by
-    match_param_name.
-
-    :param parent_name: uncompressed parameterized layer name
-    :param child_name: compression parameter name
-    :return: merged compressed name
-    """
-    return parent_name + "." 
+ child_name - - -def get_weight_mappings(model_path: str) -> Dict[str, str]: - """ - Takes a path to a state dict saved in safetensors format and returns a mapping - from parameterized layer name to file location. - - { - layer.weight.bitmask: file_location, - layer.weight.row_offsets: file_location, - layer.weight.shape: file_location, - layer.weight.compressed: file_location - } - - This generalizes to cases where the model is split into multiple safetensors files - - :param model_path: path to safetensors state dict, must contain either a single - safetensors file or multiple files with an index - :return: mapping of parameterized layer name to file location - """ - safetensors_path = os.path.join(model_path, SAFE_WEIGHTS_NAME) - index_path = os.path.join(model_path, SAFE_WEIGHTS_INDEX_NAME) - if os.path.exists(safetensors_path): - # we have a single safetensors file to read - header = get_safetensors_header(safetensors_path) - for key in header.keys(): - header[key] = SAFE_WEIGHTS_NAME - header.pop("__metadata__", None) - elif os.path.exists(index_path): - # we have multiple safetensors file, read from index - with open(index_path, "r", encoding="utf-8") as f: - index = json.load(f) - header = index["weight_map"] - else: - raise ValueError( - f"Could not find a safetensors weight or index file at {model_path}" - ) - - # convert weight locations to full paths - for key, value in header.items(): - header[key] = os.path.join(model_path, value) - - return header - - -def get_nested_weight_mappings( - model_path: str, params_to_nest: List[str] -) -> Dict[str, Dict[str, str]]: - """ - Takes a path to a state dict saved in safetensors format and returns a nested - mapping from uncompressed parameterized layer names to the file locations of each - of the layers compression parameters. 
- - layer.weight: { - bitmask: file_location, - row_offsets: file_location, - shape: file_location, - compressed: file_location - } - - This generalizes to cases where the model is split into multiple safetensors files - - :param model_path: path to safetensors state dict, must contain either a single - safetensors file or multiple files with an index - :return: nested mapping of parameterized layer name to file location - """ - weight_mappings = get_weight_mappings(model_path) - - nested_weight_mappings = {} - for key in weight_mappings.keys(): - for param_name in params_to_nest: - maybe_match = match_param_name(key, param_name) - if maybe_match is not None: - dense_param = maybe_match - if dense_param not in nested_weight_mappings: - nested_weight_mappings[dense_param] = {} - nested_weight_mappings[dense_param][param_name] = weight_mappings[key] - - return nested_weight_mappings diff --git a/src/sparseml/transformers/compression/utils/compress_save.py b/src/sparseml/transformers/sparsification/compressed_tensors_utils.py similarity index 90% rename from src/sparseml/transformers/compression/utils/compress_save.py rename to src/sparseml/transformers/sparsification/compressed_tensors_utils.py index 96315fd1685..ab9a7f5f5fc 100644 --- a/src/sparseml/transformers/compression/utils/compress_save.py +++ b/src/sparseml/transformers/sparsification/compressed_tensors_utils.py @@ -22,9 +22,8 @@ from transformers import PreTrainedModel from transformers.file_utils import CONFIG_NAME -from sparseml.transformers.compression.compressors import ModelCompressor -from sparseml.transformers.compression.config import CompressionConfig -from sparseml.transformers.utils.helpers import SPARSITY_CONFIG_NAME +from compressed_tensors import SPARSITY_CONFIG_NAME, CompressionConfig, ModelCompressor +from sparseml.transformers.compression.sparsity_config import SparsityConfigMetadata from sparseml.utils.pytorch import qat_active @@ -88,7 +87,15 @@ def save_pretrained_wrapper( ) if sparsity_config is not None: - sparsity_config.fill_config_details(model, state_dict=state_dict) + sparsity_config.global_sparsity = ( + SparsityConfigMetadata.infer_global_sparsity( + model, state_dict=state_dict + ) + ) + sparsity_config.sparsity_structure = ( + SparsityConfigMetadata.infer_sparsity_structure() + ) + elif not skip_compression_stats: # try to infer a sparsity config from the model if none is provided _LOGGER.info( @@ -97,7 +104,7 @@ def save_pretrained_wrapper( "calculation of compression statistics set " "skip_compression_stats=True" ) - sparsity_config = CompressionConfig.infer_config_from_model( + sparsity_config = SparsityConfigMetadata.infer_config_from_model( model, state_dict=state_dict, compress=save_compressed ) diff --git a/src/sparseml/transformers/sparsification/sparse_model.py b/src/sparseml/transformers/sparsification/sparse_model.py index 22ea6a21848..a22316ca179 100644 --- a/src/sparseml/transformers/sparsification/sparse_model.py +++ b/src/sparseml/transformers/sparsification/sparse_model.py @@ -30,14 +30,13 @@ ) from transformers.file_utils import WEIGHTS_NAME +from compressed_tensors import ModelCompressor, get_safetensors_folder from sparseml.modifiers.quantization.modification import modify_model from sparseml.pytorch.model_load.helpers import ( apply_recipe_structure_to_model, log_model_load, ) -from sparseml.transformers.compression.utils import ( - get_safetensors_folder, - infer_compressor_from_model_config, +from sparseml.transformers.sparsification.compressed_tensors_utils import ( 
modify_save_pretrained, ) from sparseml.transformers.utils.helpers import download_model_directory, resolve_recipe @@ -102,7 +101,7 @@ def skip(*args, **kwargs): ) # determine compression format, if any, from the model config - compressor = infer_compressor_from_model_config(pretrained_model_name_or_path) + compressor = ModelCompressor.from_pretrained(pretrained_model_name_or_path) # temporarily set the log level to error, to ignore printing out long missing # and unexpected key error messages (these are EXPECTED for quantized models) diff --git a/src/sparseml/transformers/utils/helpers.py b/src/sparseml/transformers/utils/helpers.py index 944d0bd32ff..cb95d376a75 100644 --- a/src/sparseml/transformers/utils/helpers.py +++ b/src/sparseml/transformers/utils/helpers.py @@ -75,7 +75,7 @@ class TaskNames(Enum): ALL_TASK_NAMES = list(set.union(*[task_names.value for task_names in TaskNames])) ONNX_MODEL_NAME_INTERMEDIATE = "model-orig.onnx" RECIPE_NAME = "recipe.yaml" -SPARSITY_CONFIG_NAME = "sparsity_config" + MANDATORY_DEPLOYMENT_FILES = { ONNX_MODEL_NAME, "tokenizer_config.json", diff --git a/tests/sparseml/transformers/compression/test_bitmask.py b/tests/sparseml/transformers/compression/test_bitmask.py deleted file mode 100644 index 40d683cb468..00000000000 --- a/tests/sparseml/transformers/compression/test_bitmask.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import math -import shutil - -import pytest -import torch - -from safetensors.torch import save_file -from sparseml.transformers.compression import BitmaskCompressor, BitmaskConfig -from sparseml.transformers.compression.compressors.sparse_bitmask import BitmaskTensor - - -@pytest.mark.parametrize( - "shape,sparsity,dtype", - [ - [(512, 1024), 0.5, torch.float32], - [(830, 545), 0.8, torch.float32], - [(342, 512), 0.3, torch.bfloat16], - [(256, 700), 0.9, torch.float16], - ], -) -def test_bitmask_sizes(shape, sparsity, dtype): - test_tensor = torch.rand(shape, dtype=dtype) - mask = (test_tensor.abs() < (1 - sparsity)).int() - test_tensor *= mask - dense_state_dict = {"dummy.weight": test_tensor} - - sparsity_config = BitmaskConfig() - compressor = BitmaskCompressor(config=sparsity_config) - sparse_state_dict = compressor.compress(dense_state_dict) - - # each dense tensor has 4 parameters for compression - assert len(dense_state_dict) * 4 == len(sparse_state_dict) - - # bitmask should be 1 bit per dense element, rounded up to nearest int8 - sparse_shape = sparse_state_dict["dummy.weight.shape"] - assert torch.all(torch.eq(sparse_shape, torch.tensor(shape))) - bitmask_shape = sparse_state_dict["dummy.weight.bitmask"].shape - assert bitmask_shape[0] == sparse_shape[0] - assert bitmask_shape[1] == int(math.ceil(sparse_shape[1] / 8.0)) - - # one value for each non-zero weight - values_shape = sparse_state_dict["dummy.weight.compressed"].shape - assert values_shape[0] == torch.sum(test_tensor != 0) - row_offsets_shape = sparse_state_dict["dummy.weight.row_offsets"].shape - assert row_offsets_shape[0] == test_tensor.shape[0] - - -@pytest.mark.parametrize( - "shape,sparsity,dtype", - [ - [(256, 512), 0.5, torch.float32], - [(128, 280), 0.8, torch.float32], - [(1024, 256), 0.3, torch.bfloat16], - [(511, 350), 0.7, torch.float16], - ], -) -def test_match(shape, sparsity, dtype): - test_tensor1 = torch.rand(shape, dtype=dtype) - mask = (test_tensor1.abs() < (1 - sparsity)).int() - test_tensor1 *= mask - - test_tensor2 = torch.rand(shape, dtype=dtype) - mask = (test_tensor2.abs() < (1 - sparsity)).int() - test_tensor2 *= mask - - dense_state_dict = {"dummy.weight": test_tensor1, "dummy2.weight": test_tensor2} - - for key in dense_state_dict.keys(): - dense_tensor = dense_state_dict[key] - sparse_tensor = BitmaskTensor.from_dense(dense_tensor) - decompressed = sparse_tensor.decompress() - assert decompressed.dtype == dense_tensor.dtype == dtype - assert torch.equal(dense_tensor, decompressed) - - -@pytest.mark.parametrize( - "sparsity,dtype", - [ - [0.5, torch.float32], - [0.8, torch.float32], - [0.3, torch.bfloat16], - [0.7, torch.float16], - ], -) -def test_reload_match(sparsity, dtype, tmp_path): - test_tensor1 = torch.rand((256, 512), dtype=dtype) - mask = (test_tensor1.abs() < (1 - sparsity)).int() - test_tensor1 *= mask - - test_tensor2 = torch.rand((360, 720), dtype=dtype) - mask = (test_tensor2.abs() < (1 - sparsity)).int() - test_tensor2 *= mask - - dense_state_dict = {"dummy.weight": test_tensor1, "dummy2.weight": test_tensor2} - - sparsity_config = BitmaskConfig() - compressor = BitmaskCompressor(config=sparsity_config) - - sparse_state_dict = compressor.compress(dense_state_dict) - save_file(sparse_state_dict, tmp_path / "model.safetensors") - reconstructed_dense = compressor.decompress(tmp_path) - - for key, reconstructed_tensor in reconstructed_dense: - dense_tensor = dense_state_dict[key] - assert dense_tensor.dtype == reconstructed_tensor.dtype == dtype - assert 
torch.equal(dense_tensor, reconstructed_tensor) - - shutil.rmtree(tmp_path) diff --git a/tests/sparseml/transformers/compression/test_registries.py b/tests/sparseml/transformers/compression/test_registries.py deleted file mode 100644 index fb1ba37d3d0..00000000000 --- a/tests/sparseml/transformers/compression/test_registries.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest - -from sparseml.transformers.compression import ( - BitmaskCompressor, - BitmaskConfig, - CompressionConfig, - DenseCompressor, - DenseSparsityConfig, - ModelCompressor, -) - - -@pytest.mark.parametrize( - "name,type", - [ - ["sparse_bitmask", BitmaskConfig], - ["dense_sparsity", DenseSparsityConfig], - ], -) -def test_configs(name, type): - config = CompressionConfig.load_from_registry(name) - assert isinstance(config, type) - assert config.format == name - - -@pytest.mark.parametrize( - "name,type", - [["sparse_bitmask", BitmaskCompressor], ["dense_sparsity", DenseCompressor]], -) -def test_compressors(name, type): - compressor = ModelCompressor.load_from_registry( - name, config=CompressionConfig(format="none") - ) - assert isinstance(compressor, type) - assert isinstance(compressor.config, CompressionConfig) - assert compressor.config.format == "none" diff --git a/tests/sparseml/transformers/compression/test_sparse_auto.py b/tests/sparseml/transformers/sparsification/test_compress_tensor_utils.py similarity index 88% rename from tests/sparseml/transformers/compression/test_sparse_auto.py rename to tests/sparseml/transformers/sparsification/test_compress_tensor_utils.py index 7a1fdec0266..38369617ed7 100644 --- a/tests/sparseml/transformers/compression/test_sparse_auto.py +++ b/tests/sparseml/transformers/sparsification/test_compress_tensor_utils.py @@ -20,13 +20,10 @@ from transformers import AutoConfig import sparseml.core.session as session_manager +from compressed_tensors import SPARSITY_CONFIG_NAME +from compressed_tensors.config import BitmaskConfig, DenseSparsityConfig from sparseml.transformers import SparseAutoModelForCausalLM, oneshot -from sparseml.transformers.compression import ( - BitmaskConfig, - CompressionConfig, - DenseSparsityConfig, -) -from sparseml.transformers.utils.helpers import SPARSITY_CONFIG_NAME +from sparseml.transformers.compression.sparsity_config import SparsityConfigMetadata @pytest.mark.parametrize( @@ -68,9 +65,9 @@ def test_sparse_model_reload(compressed, config, dtype, tmp_path): tmp_path / "oneshot_out", torch_dtype=dtype ) - inferred_global_sparsity = CompressionConfig.infer_global_sparsity(model) + inferred_global_sparsity = SparsityConfigMetadata.infer_global_sparsity(model) assert math.isclose(inferred_global_sparsity, 19.6562, rel_tol=1e-3) - inferred_structure = CompressionConfig.infer_sparsity_structure() + inferred_structure = SparsityConfigMetadata.infer_sparsity_structure() assert inferred_structure == "0:0" model.save_pretrained( @@ -115,9 +112,9 @@ 
def test_dense_model_save(tmp_path, skip_compression_stats, save_compressed): model_path = "Xenova/llama2.c-stories15M" model = SparseAutoModelForCausalLM.from_pretrained(model_path) - inferred_global_sparsity = CompressionConfig.infer_global_sparsity(model) + inferred_global_sparsity = SparsityConfigMetadata.infer_global_sparsity(model) assert math.isclose(inferred_global_sparsity, 0.0, rel_tol=1e-3) - inferred_structure = CompressionConfig.infer_sparsity_structure() + inferred_structure = SparsityConfigMetadata.infer_sparsity_structure() assert inferred_structure == "unstructured" model.save_pretrained( From 90795bda37c891be62dbeeea7f108050461bdeaa Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Wed, 1 May 2024 16:21:01 +0000 Subject: [PATCH 8/8] quality --- .../transformers/sparsification/compressed_tensors_utils.py | 1 - src/sparseml/transformers/sparsification/sparse_model.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/sparseml/transformers/sparsification/compressed_tensors_utils.py b/src/sparseml/transformers/sparsification/compressed_tensors_utils.py index e32b48bafdc..b6852535a2c 100644 --- a/src/sparseml/transformers/sparsification/compressed_tensors_utils.py +++ b/src/sparseml/transformers/sparsification/compressed_tensors_utils.py @@ -122,7 +122,6 @@ def save_pretrained_wrapper( return - if sparsity_config is not None: sparsity_config.global_sparsity = ( SparsityConfigMetadata.infer_global_sparsity( diff --git a/src/sparseml/transformers/sparsification/sparse_model.py b/src/sparseml/transformers/sparsification/sparse_model.py index e7852bc23bc..995b349f513 100644 --- a/src/sparseml/transformers/sparsification/sparse_model.py +++ b/src/sparseml/transformers/sparsification/sparse_model.py @@ -36,6 +36,7 @@ apply_quantization_config, load_pretrained_quantization, ) +from sparseml.modifiers.quantization.modification import modify_model from sparseml.pytorch.model_load.helpers import ( apply_recipe_structure_to_model, log_model_load, @@ -152,8 +153,6 @@ class SparseAutoModel: Factory class for creating sparse models using transformers AutoModel classes """ - from sparseml.modifiers.quantization.modification import modify_model - @staticmethod def masked_language_modeling_from_pretrained( model_name_or_path: str,
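
For orientation, the following minimal sketch shows how the sparsity helpers relocated by this series compose after the move to the external compressed-tensors package. It is illustrative only, not part of any patch above: it assumes sparseml and compressed-tensors are both installed, the model stub is the small test model already used in the series' tests, and compress=True is an arbitrary choice for demonstration.

# Minimal usage sketch under the assumptions stated above; not part of the patches.
from sparseml.transformers import SparseAutoModelForCausalLM
from sparseml.transformers.compression.sparsity_config import SparsityConfigMetadata

model = SparseAutoModelForCausalLM.from_pretrained("Xenova/llama2.c-stories15M")

# Sparsity statistics now come from static helpers on SparsityConfigMetadata
# rather than from methods on sparseml's old CompressionConfig class.
global_sparsity = SparsityConfigMetadata.infer_global_sparsity(model)
structure = SparsityConfigMetadata.infer_sparsity_structure()

# Returns None when the inferred global sparsity falls under the 0.05 cutoff,
# otherwise a compressed-tensors CompressionConfig (the "sparse_bitmask"
# format when compress=True).
config = SparsityConfigMetadata.infer_config_from_model(model, compress=True)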