[PyTorch] Save and load compressed model #3050

Draft · wants to merge 1 commit into develop
3 changes: 2 additions & 1 deletion nncf/torch/layer_utils.py
@@ -44,7 +44,8 @@ def get_config(self) -> Dict[str, Any]:
Returns the compression module config.
"""

-@abstractclassmethod
+@classmethod
+@abstractmethod
def from_config(cls, state: Dict[str, Any]) -> object:
"""
Creates a compression module instance from the given config.
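For orientation, this hunk completes the serialization contract that `from_config` belongs to: `get_config()` emits metadata, and the paired `from_config()` classmethod rebuilds a module skeleton from it. Here is a minimal sketch of that roundtrip; `ToyDecompressor` and its fields are hypothetical, and only the `get_config`/`from_config` pair mirrors the interface changed above:

```python
# A minimal sketch of the get_config()/from_config() contract. Configs carry
# shapes/metadata only, while tensor values travel separately in state_dict.
from typing import Any, Dict

import torch
from torch import nn


class ToyDecompressor(nn.Module):
    def __init__(self, scale: torch.Tensor):
        super().__init__()
        self.register_buffer("_scale", scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self._scale

    def get_config(self) -> Dict[str, Any]:
        # Serialize only the metadata needed to rebuild the module skeleton.
        return {"scale_shape": tuple(self._scale.shape)}

    @classmethod
    def from_config(cls, state: Dict[str, Any]) -> "ToyDecompressor":
        # Placeholder values; the real ones arrive via load_state_dict().
        return cls(torch.ones(state["scale_shape"]))


# Roundtrip: config rebuilds the skeleton, state_dict restores the values.
original = ToyDecompressor(torch.full((4, 1), 0.5))
clone = ToyDecompressor.from_config(original.get_config())
clone.load_state_dict(original.state_dict())
assert torch.equal(clone._scale, original._scale)
```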
82 changes: 73 additions & 9 deletions nncf/torch/quantization/layers.py
@@ -1049,7 +1049,7 @@ def get_scale_shape(input_shape: List[int], is_weights: bool, per_channel: bool,
return get_per_channel_scale_shape(input_shape, is_weights, channel_idx)


-class BaseWeightsDecompressor(nn.Module, ABC):
+class BaseWeightsDecompressor(nn.Module, StatefullModuleInterface, ABC):
"""
Base class for implementing weights decompression modules within NNCF.

@@ -1081,6 +1081,7 @@ def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
"""


@COMPRESSION_MODULES.register()
class INT8AsymmetricWeightsDecompressor(BaseWeightsDecompressor):
"""
Applies asymmetric decompression of compressed weights in the forward pass
@@ -1103,17 +1104,32 @@ def quantization_mode(self) -> QuantizationMode:

def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
if torch.is_floating_point(weight):
raise ValueError(f"Invalid weight dtype {weight.type}. Integer types are supported.")
raise nncf.ValidationError(f"Invalid weight dtype {weight.type}. Integer types are supported.")
if torch.any((weight < 0) | (weight > 255)):
raise ValueError("Weight values are not in [0, 255].")
raise nncf.ValidationError("Weight values are not in [0, 255].")
return weight.type(dtype=torch.uint8)

def forward(self, x) -> torch.Tensor:
result = decompress_asymmetric(x, self._scale, self._zero_point)
result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result
return result

def get_config(self) -> Dict[str, Any]:
return {
"scale_shape": self._scale.shape,
"zero_point_shape": self._zero_point.shape,
"result_dtype": self.result_dtype if self.result_dtype is not None else "",
}

@classmethod
def from_config(cls, state: Dict[str, Any]) -> object:
scale = torch.ones(state["scale_shape"], dtype=torch.float16)
zero_point = torch.zeros(state["zero_point_shape"], dtype=torch.uint8)
result_dtype = state["result_dtype"] if state["result_dtype"] else None
return cls(scale, zero_point, result_dtype)


@COMPRESSION_MODULES.register()
class INT8SymmetricWeightsDecompressor(BaseWeightsDecompressor):
"""
Applies symmetric decompression of compressed weights in the forward pass
@@ -1134,17 +1150,30 @@ def quantization_mode(self) -> QuantizationMode:

def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
if torch.is_floating_point(weight):
raise ValueError(f"Invalid weight dtype {weight.type}. Integer types are supported.")
raise nncf.ValidationError(f"Invalid weight dtype {weight.type}. Integer types are supported.")
if torch.any((weight < -128) | (weight > 127)):
raise ValueError("Weight values are not in [-128, 127].")
raise nncf.ValidationError("Weight values are not in [-128, 127].")
return weight.type(dtype=torch.int8)

def forward(self, x) -> torch.Tensor:
result = decompress_symmetric(x, self._scale)
result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result
return result

def get_config(self) -> Dict[str, Any]:
return {
"scale_shape": self._scale.shape,
"result_dtype": self.result_dtype if self.result_dtype is not None else "",
}

@classmethod
def from_config(cls, state: Dict[str, Any]) -> object:
scale = torch.ones(state["scale_shape"], dtype=torch.float16)
result_dtype = state["result_dtype"] if state["result_dtype"] else None
return cls(scale, result_dtype)


@COMPRESSION_MODULES.register()
class INT4AsymmetricWeightsDecompressor(BaseWeightsDecompressor):
def __init__(
self,
@@ -1177,9 +1206,9 @@ def quantization_mode(self) -> QuantizationMode:

def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
if torch.is_floating_point(weight):
raise ValueError(f"Invalid weight dtype {weight.type}. Integer types are supported.")
raise nncf.ValidationError(f"Invalid weight dtype {weight.type}. Integer types are supported.")
if torch.any((weight < 0) | (weight > 15)):
raise ValueError("Weight values are not in [0, 15].")
raise nncf.ValidationError("Weight values are not in [0, 15].")
return pack_uint4(weight.type(dtype=torch.uint8))

def forward(self, x):
@@ -1194,7 +1223,26 @@ def forward(self, x):
result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result
return result

def get_config(self) -> Dict[str, Any]:
return {
"scale_shape": self._scale.shape,
"zero_point_shape": self.zero_point_shape,
"compressed_weight_shape": self.compressed_weight_shape,
"result_shape": self.result_shape if self.result_shape is not None else "",
"result_dtype": self.result_dtype if self.result_dtype is not None else "",
}

@classmethod
def from_config(cls, state: Dict[str, Any]) -> object:
scale = torch.ones(state["scale_shape"], dtype=torch.float16)
zero_point = torch.zeros(state["zero_point_shape"], dtype=torch.uint8)
compressed_weight_shape = state["compressed_weight_shape"]
result_shape = state["result_shape"] if state["result_shape"] else None
result_dtype = state["result_dtype"] if state["result_dtype"] else None
return cls(scale, zero_point, compressed_weight_shape, result_shape, result_dtype)


@COMPRESSION_MODULES.register()
class INT4SymmetricWeightsDecompressor(BaseWeightsDecompressor):
def __init__(
self,
@@ -1222,9 +1270,9 @@ def quantization_mode(self) -> QuantizationMode:

def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
if torch.is_floating_point(weight):
raise ValueError(f"Invalid weight dtype {weight.type}. Integer types are supported.")
raise nncf.ValidationError(f"Invalid weight dtype {weight.type}. Integer types are supported.")
if torch.any((weight < -8) | (weight > 7)):
raise ValueError("Tensor values are not in [-8, 7].")
raise nncf.ValidationError("Tensor values are not in [-8, 7].")
return pack_int4(weight.type(dtype=torch.int8))

def forward(self, x):
@@ -1235,3 +1283,19 @@ def forward(self, x):
result = result.reshape(self.result_shape) if self.result_shape is not None else result
result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result
return result

def get_config(self) -> Dict[str, Any]:
return {
"scale_shape": self._scale.shape,
"compressed_weight_shape": self.compressed_weight_shape,
"result_shape": self.result_shape if self.result_shape is not None else "",
"result_dtype": self.result_dtype if self.result_dtype is not None else "",
}

@classmethod
def from_config(cls, state: Dict[str, Any]) -> object:
scale = torch.ones(state["scale_shape"], dtype=torch.float16)
compressed_weight_shape = state["compressed_weight_shape"]
result_shape = state["result_shape"] if state["result_shape"] else None
result_dtype = state["result_dtype"] if state["result_dtype"] else None
return cls(scale, compressed_weight_shape, result_shape, result_dtype)
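Taken together, each decompressor's `get_config()` stores only shape and dtype metadata, while `from_config()` rebuilds a skeleton with `torch.ones`/`torch.zeros` placeholders; the real tensor values are restored afterwards from the `state_dict`. A minimal sketch of that roundtrip, assuming the `INT8SymmetricWeightsDecompressor(scale, result_dtype)` constructor implied by `from_config` above and that the scale is held in a registered buffer:

```python
# Sketch of the intended save/load roundtrip for one decompressor.
# Assumptions (not confirmed by this diff): the constructor signature is
# INT8SymmetricWeightsDecompressor(scale, result_dtype), and the scale is a
# registered buffer, so load_state_dict() can restore its values.
import torch

from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor

scale = torch.rand((16, 1), dtype=torch.float16)
module = INT8SymmetricWeightsDecompressor(scale, torch.float32)

config = module.get_config()  # shapes/dtypes only, e.g. {"scale_shape": ...}
state = module.state_dict()   # the actual scale values

# from_config() builds placeholders of the right shape (torch.ones above);
# load_state_dict() then overwrites them with the saved values.
restored = INT8SymmetricWeightsDecompressor.from_config(config)
restored.load_state_dict(state)

compressed = torch.randint(-128, 127, (16, 32), dtype=torch.int8)
assert torch.allclose(module(compressed), restored(compressed))
```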
36 changes: 36 additions & 0 deletions tests/torch/ptq/test_weights_compression.py
@@ -333,3 +333,39 @@ def test_pack_int4():
assert packed_w.numel() * 2 == w_int8.numel()
unpacked_w = unpack_int4(packed_w).reshape(w_int8.shape)
assert torch.all(unpacked_w == w_int8)


@pytest.mark.parametrize("mode", SUPPORTED_MODES)
def test_save_load(mode, tmp_path):
model = ShortTransformer(8, 16)
input_ids = torch.randint(0, 10, (8,))
wrapped_model = wrap_model(model, example_input=input_ids, trace_parameters=True)

kwargs = {}
if mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]:
kwargs["group_size"] = 4
compressed_model = compress_weights(wrapped_model, mode=mode, **kwargs)

state_dict = compressed_model.state_dict()
compression_config = compressed_model.nncf.get_config()

ckpt_path = tmp_path / f"{mode}_model.pt"
torch.save(
{
"model_state_dict": state_dict,
"compression_config": compression_config,
},
ckpt_path,
)

compressed_result = compressed_model(input_ids)

restored_model = ShortTransformer(8, 16)

ckpt = torch.load(ckpt_path)
restored_model = nncf.torch.load_from_config(restored_model, ckpt["compression_config"], input_ids)
restored_model.load_state_dict(ckpt["model_state_dict"])

restored_compressed_result = restored_model(input_ids)

assert torch.allclose(compressed_result, restored_compressed_result)
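For context, `nncf.torch.load_from_config` is what turns the saved `compression_config` back into live modules before `load_state_dict` runs. Conceptually, the per-module restore step looks something like the simplified, hypothetical sketch below, using the `COMPRESSION_MODULES` registry the decompressors were registered into above; the real function does considerably more (it also rebuilds the traced model and re-inserts each module at its recorded call site):

```python
# Hypothetical, simplified view of the per-module restore step performed by
# nncf.torch.load_from_config(); not NNCF's actual implementation.
from typing import Any, Dict

from nncf.torch.layers import COMPRESSION_MODULES


def rebuild_compression_module(cls_name: str, state: Dict[str, Any]):
    # Look up the class registered via @COMPRESSION_MODULES.register() and
    # build a skeleton instance; load_state_dict() fills in real values later.
    module_cls = COMPRESSION_MODULES.get(cls_name)
    return module_cls.from_config(state)
```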