Skip to content

Commit

Permalink
add metadata mismatch test
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryzh168 committed Feb 27, 2025
1 parent b5c8acb commit 04c26a2
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 0 deletions.
16 changes: 16 additions & 0 deletions test/dtypes/test_affine_quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,22 @@ def test_copy_(self, apply_quant):
output2 = ql2(example_input)
self.assertEqual(output, output2)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize(
"apply_quant", get_quantization_functions(False, True, "cuda", False)
)
def test_copy__mismatch_metadata(self, apply_quant):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(linear)
linear2 = torch.nn.Linear(128, 512, dtype=torch.bfloat16, device="cuda")
ql2 = apply_quant(linear2)

# copy should fail due to shape mismatch
with self.assertRaisesRegex(
ValueError, "Not supported args for copy_ due to metadata mistach:"
):
ql2.weight.copy_(ql.weight)


class TestAffineQuantizedBasic(TestCase):
COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
Expand Down
1 change: 1 addition & 0 deletions torchao/dtypes/uintx/plain_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def _same_metadata(self: "PlainAQTTensorImpl", src: "PlainAQTTensorImpl") -> boo
return (
isinstance(self, PlainAQTTensorImpl)
and isinstance(src, PlainAQTTensorImpl)
and self.shape == src.shape
and self.int_data.shape == src.int_data.shape
and self.scale.shape == src.scale.shape
and self.zero_point.shape == src.zero_point.shape
Expand Down
1 change: 1 addition & 0 deletions torchao/dtypes/uintx/tensor_core_tiled_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def _same_metadata(
return (
isinstance(self, TensorCoreTiledAQTTensorImpl)
and isinstance(src, TensorCoreTiledAQTTensorImpl)
and self.shape == src.shape
and self.packed_weight.shape == src.packed_weight.shape
and self.scale_and_zero.shape == src.scale_and_zero.shape
and self.transposed == src.transposed
Expand Down
1 change: 1 addition & 0 deletions torchao/quantization/linear_activation_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def _same_metadata(
return (
isinstance(self, LinearActivationQuantizedTensor)
and isinstance(src, LinearActivationQuantizedTensor)
and self.shape == src.shape
and self.input_quant_func == src.input_quant_func
and self.quant_kwargs == src.quant_kwargs
)
Expand Down

0 comments on commit 04c26a2

Please sign in to comment.