Add support for copy_ for plain layout and tensor core tiled layout
Summary:
As titled: only supports copy_ from one AffineQuantizedTensor (AQT) to another AQT with the same metadata (shapes, quantization parameters, etc.).

Tested with int4wo, int8wo, and int8dq.
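A minimal usage sketch (assuming the int4 weight-only config from torchao.quantization; any of the tested configs should behave the same way, provided both modules were quantized with the same one):

import torch
from torchao.quantization import quantize_, int4_weight_only

# Two identically shaped modules quantized with the same config,
# so their AQT weights share metadata.
m_src = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
m_dst = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(m_src, int4_weight_only())
quantize_(m_dst, int4_weight_only())

# In-place copy between matching AQTs, e.g. for loading quantized
# weights into an already-quantized module.
m_dst.weight.copy_(m_src.weight)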

Test Plan:
python test/dtypes/test_affine_quantized.py -k test_copy_

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Feb 27, 2025
1 parent e1cb44a commit 277728c
Showing 5 changed files with 129 additions and 0 deletions.
18 changes: 18 additions & 0 deletions test/dtypes/test_affine_quantized.py
@@ -175,6 +175,24 @@ def test_print_quantized_module(self, apply_quant):
        ql = apply_quant(linear)
        assert "AffineQuantizedTensor" in str(ql)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @common_utils.parametrize(
        "apply_quant", get_quantization_functions(False, True, "cuda", False)
    )
    def test_copy_(self, apply_quant):
        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
        ql = apply_quant(linear)
        linear2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
        ql2 = apply_quant(linear2)

        example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
        output = ql(example_input)
        # copy_ the quantized weight; the bias is not quantized, so share it directly.
        ql2.weight.copy_(ql.weight)
        ql2.bias = ql.bias
        output2 = ql2(example_input)
        self.assertEqual(output, output2)
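        # Hedged sketch (not part of the committed test): a shape mismatch
        # should take the ValueError path added in this commit.
        linear3 = torch.nn.Linear(64, 256, dtype=torch.bfloat16, device="cuda")
        ql3 = apply_quant(linear3)
        with self.assertRaises(ValueError):
            ql2.weight.copy_(ql3.weight)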


class TestAffineQuantizedBasic(TestCase):
    COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
35 changes: 35 additions & 0 deletions torchao/dtypes/affine_quantized_tensor_ops.py
@@ -91,6 +91,27 @@ def deregister_aqt_quantized_linear_dispatch(dispatch_condition):
)


def _same_metadata(self: AffineQuantizedTensor, src: AffineQuantizedTensor) -> bool:
    return (
        isinstance(self, AffineQuantizedTensor)
        and isinstance(src, AffineQuantizedTensor)
        and all(
            getattr(self, attr) == getattr(src, attr)
            for attr in [
                "block_size",
                "shape",
                "quant_min",
                "quant_max",
                "zero_point_domain",
                "dtype",
            ]
        )
        and type(self.tensor_impl) == type(src.tensor_impl)
    )


class QuantizedLinearNotImplementedError(NotImplementedError):
    """Thin wrapper around NotImplementedError to make it easier to catch this error in the dispatch table"""

@@ -317,6 +338,20 @@ def _(func, types, args, kwargs):
)


@implements(aten.copy_.default)
def _(func, types, args, kwargs):
    self = args[0]
    src = args[1]
    if _same_metadata(self, src):
        # Metadata already matches, so copy each inner tensor in place.
        self_tensors = self.__tensor_flatten__()[0]
        for tensor_name in self_tensors:
            getattr(self, tensor_name).copy_(getattr(src, tensor_name))
        return
    raise ValueError(
        f"copy_ not supported between {args[0]} and {args[1]} due to metadata mismatch"
    )


@implements(aten.t.default)
def _(func, types, args, kwargs):
    block_size = args[0].block_size
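The shared copy_ handler above leans on the traceable tensor-subclass contract: __tensor_flatten__() returns the names of the inner plain tensors plus flattening context, so the handler can copy each constituent in place without knowing the concrete layout. A rough sketch of the assumed contract (the attribute names are illustrative, not guaranteed):

# __tensor_flatten__()[0] lists the inner tensor attribute names, e.g.
# ["tensor_impl"] for AffineQuantizedTensor, or
# ["int_data", "scale", "zero_point"] for a plain-layout impl.
tensor_names, ctx = self.__tensor_flatten__()
for name in tensor_names:
    getattr(self, name).copy_(getattr(src, name))  # in-place copy, keeps storage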
23 changes: 23 additions & 0 deletions torchao/dtypes/uintx/plain_layout.py
@@ -22,6 +22,17 @@
aten = torch.ops.aten


def _same_metadata(self: "PlainAQTTensorImpl", src: "PlainAQTTensorImpl") -> bool:
return (
isinstance(self, PlainAQTTensorImpl)
and isinstance(src, PlainAQTTensorImpl)
and self.int_data.shape == src.int_data.shape
and self.scale.shape == src.scale.shape
and self.zero_point.shape == src.zero_point.shape
and type(self._layout) == type(src._layout)
)


@register_layout(PlainLayout)
class PlainAQTTensorImpl(AQTTensorImpl):
    """
@@ -113,6 +124,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
            )

        if func is aten.copy_.default:
            self = args[0]
            src = args[1]
            if _same_metadata(self, src):
                self_tensors = self.__tensor_flatten__()[0]
                for tensor_name in self_tensors:
                    getattr(self, tensor_name).copy_(getattr(src, tensor_name))
                return
            raise ValueError(
                f"copy_ not supported between {args[0]} and {args[1]} due to metadata mismatch"
            )

        elif func is aten.t.default:
            tensor = args[0]
            new = tensor.__class__(
28 changes: 28 additions & 0 deletions torchao/dtypes/uintx/tensor_core_tiled_layout.py
@@ -33,6 +33,22 @@ def _aqt_is_tensor_core_tile_uint4(aqt):
)


def _same_metadata(
    self: "TensorCoreTiledAQTTensorImpl", src: "TensorCoreTiledAQTTensorImpl"
) -> bool:
    return (
        isinstance(self, TensorCoreTiledAQTTensorImpl)
        and isinstance(src, TensorCoreTiledAQTTensorImpl)
        and self.packed_weight.shape == src.packed_weight.shape
        and self.scale_and_zero.shape == src.scale_and_zero.shape
        and self.transposed == src.transposed
        and type(self._layout) == type(src._layout)
    )


def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias):
    return (
        # input is native bfloat16 tensor
@@ -296,6 +312,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
            )

        if func is aten.copy_.default:
            self = args[0]
            src = args[1]
            if _same_metadata(self, src):
                self_tensors = self.__tensor_flatten__()[0]
                for tensor_name in self_tensors:
                    getattr(self, tensor_name).copy_(getattr(src, tensor_name))
                return
            raise ValueError(
                f"copy_ not supported between {args[0]} and {args[1]} due to metadata mismatch"
            )

        if func is aten.t.default:
            """we don't need to repack the weight and just rely on external
            shape being changed and record the status of transpose/no-transpose
25 changes: 25 additions & 0 deletions torchao/quantization/linear_activation_quantized_tensor.py
@@ -112,6 +112,17 @@ def to(self, *args, **kwargs):
)


def _same_metadata(
    self: LinearActivationQuantizedTensor, src: LinearActivationQuantizedTensor
) -> bool:
    return (
        isinstance(self, LinearActivationQuantizedTensor)
        and isinstance(src, LinearActivationQuantizedTensor)
        and self.input_quant_func == src.input_quant_func
        and self.quant_kwargs == src.quant_kwargs
    )


implements = LinearActivationQuantizedTensor.implements


@@ -191,6 +202,20 @@ def _(func, types, args, kwargs):
)


@implements(aten.copy_.default)
def _(func, types, args, kwargs):
    self = args[0]
    src = args[1]
    if _same_metadata(self, src):
        self_tensors = self.__tensor_flatten__()[0]
        for tensor_name in self_tensors:
            getattr(self, tensor_name).copy_(getattr(src, tensor_name))
        return
    raise ValueError(
        f"copy_ not supported between {args[0]} and {args[1]} due to metadata mismatch"
    )


@implements(aten.t.default)
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
