Deprecation updates
matthewdouglas committed Mar 8, 2025
1 parent 6172770 commit 242c602
Showing 6 changed files with 18 additions and 433 deletions.
30 changes: 10 additions & 20 deletions bitsandbytes/autograd/_functions.py
@@ -49,6 +49,10 @@ def get_current_outlier_idx(self):
return torch.Tensor(list(self.outliers)).to(torch.int64)


@deprecated(
"This function is deprecated and will be removed in a future release. Consider using `int8_vectorwise_dequant` instead.",
category=FutureWarning,
)
def get_inverse_transform_indices(
transform_tile: Callable[[torch.Tensor], torch.Tensor],
tile_size: Tuple[int, int],
@@ -80,6 +84,10 @@ def get_inverse_transform_indices(
return permuted_tile_indices
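
For context: the `@deprecated(...)` markers added throughout this commit match the `typing_extensions.deprecated` API, which wraps a callable so that every call emits the given warning. A minimal sketch of that behavior, assuming that decorator is the one in use; `old_func` is an illustrative name:

import warnings

from typing_extensions import deprecated


@deprecated("Use `new_func` instead.", category=FutureWarning)
def old_func(x):
    return x + 1


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_func(1)  # emits a FutureWarning with the message above

assert issubclass(caught[0].category, FutureWarning)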


@deprecated(
"This function is deprecated and will be removed in a future release. Consider using `int8_vectorwise_dequant` instead.",
category=FutureWarning,
)
def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
"""
Undo a tiled permutation such as turing or ampere layout
Expand Down Expand Up @@ -225,25 +233,9 @@ def supports_igemmlt(device: torch.device) -> bool:
return True


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def _get_tile_size(format):
assert format in (
"col_turing",
"col_ampere",
), f"please find this assert and manually enter tile size for {format}"
return (8, 32) if format == "col_turing" else (32, 32)


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def get_tile_inds(format, device):
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device)
with torch.no_grad():
return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device)
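
The removed helpers above rest on one idea: feed the positions 0..n-1 through the layout transform and read back the permutation, which can then be inverted to undo the layout. A toy sketch of that idea, with a random permutation standing in for the tiled Turing/Ampere transforms:

import torch


def inverse_indices(transform, n: int) -> torch.Tensor:
    # Push positions 0..n-1 through the (bijective) transform, then invert.
    permuted = transform(torch.arange(n))
    inv = torch.empty_like(permuted)
    inv[permuted] = torch.arange(n)
    return inv


perm = torch.randperm(8)
layout = lambda x: x[perm]  # toy stand-in for a tiled layout transform
inv = inverse_indices(layout, 8)

x = torch.arange(8)
assert torch.equal(layout(x)[inv], x)  # undoing the layout recovers x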


@dataclass
class MatmulLtState:
_tile_indices: Optional[torch.Tensor] = None
_tile_indices: Optional[torch.Tensor] = None # TODO: remove

force_no_igemmlt: bool = False

@@ -279,9 +271,7 @@ def reset_grads(self):

@property
def tile_indices(self):
if self._tile_indices is None:
self._tile_indices = get_tile_inds(self.formatB, self.CxB.device)
return self._tile_indices
raise ValueError("tile_indices is no longer supported.")


class MatMul8bitLt(torch.autograd.Function):
195 changes: 6 additions & 189 deletions bitsandbytes/functional.py
@@ -182,13 +182,6 @@ def get_instance(cls):
return cls._instance


dtype2bytes = {}
dtype2bytes[torch.float32] = 4
dtype2bytes[torch.float16] = 2
dtype2bytes[torch.bfloat16] = 2
dtype2bytes[torch.uint8] = 1
dtype2bytes[torch.int8] = 1

FIRST_CUDA_DEVICE = torch.device("cuda", index=0)

# When multiple GPUs are present, we use a context manager to
@@ -207,7 +200,7 @@ def _cuda_device_of(a: torch.Tensor):


def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
num_bytes = dtype2bytes[dtype] * prod(shape)
num_bytes = dtype.itemsize * prod(shape)
cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int))
new_array = np.ctypeslib.as_array(c_ptr, shape=shape)
@@ -217,15 +210,14 @@ def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
return out


def prefetch_tensor(A, to_cpu=False):
def prefetch_tensor(A: torch.Tensor, to_cpu=False):
assert A.is_paged, "Only paged tensors can be prefetched!"
if to_cpu:
deviceid = -1
else:
deviceid = A.page_deviceid

num_bytes = dtype2bytes[A.dtype] * A.numel()
lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid))
lib.cprefetch(get_ptr(A), ct.c_size_t(A.nbytes), ct.c_int32(deviceid))
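
The `get_paged` and `prefetch_tensor` changes drop the hand-maintained `dtype2bytes` table in favor of what PyTorch already exposes: `torch.dtype.itemsize` (available since PyTorch 2.1) and `Tensor.nbytes`. A quick check of the equivalence:

import torch

assert torch.float32.itemsize == 4
assert torch.float16.itemsize == 2
assert torch.bfloat16.itemsize == 2
assert torch.int8.itemsize == 1

t = torch.zeros(3, 4, dtype=torch.float16)
assert t.nbytes == t.numel() * t.dtype.itemsize  # 12 elements * 2 bytes = 24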


def elementwise_func(func_name, A, B, value, prefetch=True):
@@ -499,106 +491,6 @@ def post_call(prev_device):
torch.cuda.set_device(prev_device)


@deprecated(
"The layout transformation operations will be removed in a future release. Please use row-major layout only.",
category=FutureWarning,
)
def get_transform_func(dtype, orderA, orderOut, transpose=False):
name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}'
if not hasattr(lib, name):
print(name)
raise ValueError(
f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}",
)
else:
return getattr(lib, name)


@deprecated(
"The layout transformation operations will be removed in a future release. Please use row-major layout only.",
category=FutureWarning,
)
def get_transform_buffer(shape, dtype, device, to_order, from_order="row", transpose=False):
# init_func = torch.empty
init_func = torch.zeros
dims = len(shape)

if dims == 2:
rows = shape[0]
elif dims == 3:
rows = shape[0] * shape[1]
cols = shape[-1]

state = (shape, to_order)
if transpose:
# swap dims
tmp = rows
rows = cols
cols = tmp
state = (shape[::-1], to_order)

if to_order == "row" or to_order == "col":
return init_func(shape, dtype=dtype, device=device), state
elif to_order == "col32":
# blocks of 32 columns (padded)
cols = 32 * ((cols + 31) // 32)
return init_func((rows, cols), dtype=dtype, device=device), state
elif to_order == "col_turing":
# blocks of 32 columns and 8 rows
cols = 32 * ((cols + 31) // 32)
rows = 8 * ((rows + 7) // 8)
return init_func((rows, cols), dtype=dtype, device=device), state
elif to_order == "col_ampere":
# blocks of 32 columns and 32 rows
cols = 32 * ((cols + 31) // 32)
rows = 32 * ((rows + 31) // 32)
return init_func((rows, cols), dtype=dtype, device=device), state
else:
raise NotImplementedError(f"To_order not supported: {to_order}")


@deprecated(
"The layout transformation operations will be removed in a future release. Please use row-major layout only.",
category=FutureWarning,
)
def nvidia_transform(
A,
to_order,
from_order="row",
out=None,
transpose=False,
state=None,
ld=None,
):
if state is None:
state = (A.shape, from_order)
else:
from_order = state[1]
if out is None:
out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1])
else:
new_state = (state[1], to_order)
func = get_transform_func(A.dtype, from_order, to_order, transpose)

shape = state[0]
if len(shape) == 2:
dim1 = ct.c_int32(shape[0])
dim2 = ct.c_int32(shape[1])
elif ld is not None:
n = prod(shape)
dim1 = prod([shape[i] for i in ld])
dim2 = ct.c_int32(n // dim1)
dim1 = ct.c_int32(dim1)
else:
dim1 = ct.c_int32(shape[0] * shape[1])
dim2 = ct.c_int32(shape[2])

ptr = CUBLAS_Context.get_instance().get_context(A.device)
func(ptr, get_ptr(A), get_ptr(out), dim1, dim2)

return out, new_state


def estimate_quantiles(
A: Tensor,
out: Optional[torch.Tensor] = None,
@@ -1715,6 +1607,7 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
return current_gnorm, clip_value, gnorm_scale


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor):
assert len(histogram.shape) == 2
assert histogram.dtype == torch.float32
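
The deprecated `histogram_scatter_add_2d` scatter-adds `source` into `histogram` at the paired positions `(index1, index2)`. In plain PyTorch the same result comes from `Tensor.index_put_` with `accumulate=True`; a sketch of the semantics, not the CUDA kernel:

import torch

hist = torch.zeros(4, 4)
index1 = torch.tensor([0, 0, 3])
index2 = torch.tensor([1, 1, 2])
source = torch.tensor([1.0, 2.0, 5.0])

hist.index_put_((index1, index2), source, accumulate=True)
assert hist[0, 1] == 3.0  # duplicate indices accumulate: 1.0 + 2.0
assert hist[3, 2] == 5.0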
@@ -2105,6 +1998,7 @@ def int8_mm_dequant(
return result


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def get_colrow_absmax(
A: torch.Tensor,
row_stats: Optional[torch.Tensor] = None,
@@ -2162,6 +2056,7 @@ def get_colrow_absmax(
return row_stats, col_stats, outlier_mask


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def get_row_absmax(A: torch.Tensor, threshold=0.0):
"""Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
@@ -2366,58 +2261,6 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
return torch.ops.bitsandbytes.int8_vectorwise_quant.default(A, threshold)
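
For reference, the core math behind row-wise ("vectorwise") int8 quantization can be written roughly as below. This is a sketch only: it skips the `threshold` outlier path and does not match the op's exact return signature, and `int8_vectorwise_quant_ref` is an illustrative name.

import torch


def int8_vectorwise_quant_ref(A: torch.Tensor):
    # One scale per row: absmax / 127 maps values into [-127, 127].
    row_absmax = A.abs().amax(dim=1, keepdim=True).float()
    scale = row_absmax.clamp(min=1e-12) / 127.0  # guard all-zero rows
    quantized = torch.round(A / scale).to(torch.int8)
    return quantized, row_absmax.squeeze(1)


A = torch.randn(2, 4, dtype=torch.float16)
q, row_stats = int8_vectorwise_quant_ref(A)
assert q.dtype == torch.int8 and row_stats.shape == (2,)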


@deprecated(
"The layout transformation operations will be removed in a future release. Please use row-major layout only.",
category=FutureWarning,
)
def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
prev_device = pre_call(A.device)
if state is None:
state = (A.shape, from_order)
else:
from_order = state[1]
if out is None:
out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
else:
new_state = (state[0], to_order) # (shape, order)

shape = state[0]
if len(shape) == 2:
dim1 = ct.c_int32(shape[0])
dim2 = ct.c_int32(shape[1])
else:
dim1 = ct.c_int32(shape[0] * shape[1])
dim2 = ct.c_int32(shape[2])

is_on_gpu([A, out])
if to_order == "col32":
if transpose:
lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
else:
lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2)
elif to_order == "col_turing":
if transpose:
lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2)
else:
lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2)
elif to_order == "col_ampere":
if transpose:
lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2)
else:
lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2)
elif to_order == "row":
if from_order == "col_turing":
lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2)
elif from_order == "col_ampere":
lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
else:
raise NotImplementedError(f"Transform function not implemented: From {from_order} to {to_order}")

post_call(prev_device)

return out, new_state


def spmm_coo(
cooA: Union[COOSparseTensor, torch.Tensor],
B: torch.Tensor,
@@ -2692,29 +2535,3 @@ def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
return x.to(dtype)
else:
return None


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def extract_outliers(A, SA, idx):
shapeA = SA[0]
formatA = SA[1]
assert formatA in ["col_turing", "col_ampere"]
assert A.device.type == "cuda"

out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)

idx_size = ct.c_int32(idx.numel())
rows = ct.c_int32(shapeA[0])
cols = ct.c_int32(shapeA[1])
ptrA = get_ptr(A)
ptrIdx = get_ptr(idx)
ptrOut = get_ptr(out)

prev_device = pre_call(A.device)
if formatA == "col_turing":
lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
elif formatA == "col_ampere":
lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
post_call(prev_device)

return out
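
With the tiled layouts removed, the deprecated `extract_outliers` reduces to ordinary column indexing on a row-major tensor. A sketch:

import torch

A = torch.randint(-128, 128, (4, 16), dtype=torch.int8)
idx = torch.tensor([1, 7, 12])

outliers = A[:, idx]  # the outlier columns, shape (rows, idx.numel())
assert outliers.shape == (4, 3)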
4 changes: 1 addition & 3 deletions bitsandbytes/nn/modules.py
@@ -11,7 +11,6 @@
import torch.nn.functional as F

import bitsandbytes as bnb
from bitsandbytes.autograd._functions import get_tile_inds, undo_layout
from bitsandbytes.functional import QuantState
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import (
@@ -654,8 +653,7 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k
weight_format = INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weight_format]

if weight_format != "row":
tile_indices = get_tile_inds(weight_format, weight.device)
state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
raise ValueError(f"Only 'row' weight format is supported, got {weight_format}")


class Embedding8bit(nn.Embedding):