From d08f60a767e1edad021a949d41a72c11715bc850 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Mon, 25 Nov 2024 12:59:06 -0800 Subject: [PATCH] Update bitsandbytes import --- test/quantization/test_galore_quant.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/quantization/test_galore_quant.py b/test/quantization/test_galore_quant.py index 1eabf479ce..37709c4128 100644 --- a/test/quantization/test_galore_quant.py +++ b/test/quantization/test_galore_quant.py @@ -8,7 +8,7 @@ except ImportError: pytest.skip("triton is not installed", allow_module_level=True) -import bitsandbytes.functional as F +from bitsandbytes.functional import create_dynamic_map, quantize_blockwise, dequantize_blockwise import torch from torchao.prototype.galore.kernels import ( @@ -36,9 +36,9 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize): g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01 - qmap = F.create_dynamic_map(signed).to(g.device) + qmap = create_dynamic_map(signed).to(g.device) - ref_bnb, qstate = F.quantize_blockwise(g, code=qmap, blocksize=blocksize) + ref_bnb, qstate = quantize_blockwise(g, code=qmap, blocksize=blocksize) bnb_norm = (g.reshape(-1, blocksize) / qstate.absmax[:, None]).reshape(g.shape) tt_q, tt_norm, tt_absmax = triton_quantize_blockwise( @@ -82,10 +82,10 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize): def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize): g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01 - qmap = F.create_dynamic_map(signed).to(g.device) + qmap = create_dynamic_map(signed).to(g.device) - q, qstate = F.quantize_blockwise(g, code=qmap, blocksize=blocksize) + q, qstate = quantize_blockwise(g, code=qmap, blocksize=blocksize) - dq_ref = F.dequantize_blockwise(q, qstate) + dq_ref = dequantize_blockwise(q, qstate) dq = triton_dequant_blockwise(q, qmap, qstate.absmax, group_size=blocksize) assert torch.allclose(dq, dq_ref)