diff --git a/.github/workflows/regression_test_rocm.yml b/.github/workflows/regression_test_rocm.yml new file mode 100644 index 0000000000..9a9a6c0071 --- /dev/null +++ b/.github/workflows/regression_test_rocm.yml @@ -0,0 +1,49 @@ +name: Run Regression Tests on ROCm + +on: + push: + branches: + - main + tags: + - ciflow/rocm/* + +concurrency: + group: regression_test-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} + cancel-in-progress: true + +env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + +jobs: + test-nightly: + strategy: + fail-fast: false + matrix: + include: + - name: ROCM Nightly + runs-on: linux.rocm.gpu.torchao + torch-spec: '--pre torch==2.7.0.dev20250122 --index-url https://download.pytorch.org/whl/nightly/rocm6.3' + gpu-arch-type: "rocm" + gpu-arch-version: "6.3" + + permissions: + id-token: write + contents: read + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + with: + timeout: 120 + no-sudo: ${{ matrix.gpu-arch-type == 'rocm' }} + runner: ${{ matrix.runs-on }} + gpu-arch-type: ${{ matrix.gpu-arch-type }} + gpu-arch-version: ${{ matrix.gpu-arch-version }} + submodules: recursive + script: | + conda create -n venv python=3.9 -y + conda activate venv + python -m pip install --upgrade pip + pip install ${{ matrix.torch-spec }} + pip install -r dev-requirements.txt + pip install . + export CONDA=$(dirname $(dirname $(which conda))) + export LD_LIBRARY_PATH=$CONDA/lib/:$LD_LIBRARY_PATH + pytest test --verbose -s diff --git a/benchmarks/float8/bench_linear_float8.py b/benchmarks/float8/bench_linear_float8.py index d160d7241d..a7b1e17934 100644 --- a/benchmarks/float8/bench_linear_float8.py +++ b/benchmarks/float8/bench_linear_float8.py @@ -23,10 +23,6 @@ ScalingType, ) from torchao.float8.float8_linear import Float8Linear -from torchao.float8.float8_linear_utils import ( - linear_requires_sync, - sync_float8_amax_and_scale_history, -) from torchao.float8.float8_tensor import ScaledMMConfig # estimating TOPs for matmuls in fp32, fp16, fp8 @@ -122,39 +118,18 @@ def main( scaling_type_grad_output = ScalingType(scaling_type_grad_output) scaling_granularity = ScalingGranularity(scaling_granularity) - if scaling_type_input is ScalingType.STATIC: - cast_config_input = CastConfig( - scaling_type=scaling_type_input, - static_scale=torch.tensor([1.0], device="cuda"), - scaling_granularity=scaling_granularity, - ) - else: - cast_config_input = CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity, - ) - if scaling_type_weight is ScalingType.STATIC: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - static_scale=torch.tensor([1.0], device="cuda"), - scaling_granularity=scaling_granularity, - ) - else: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity, - ) - if scaling_type_grad_output is ScalingType.STATIC: - cast_config_grad_output = CastConfig( - scaling_type=scaling_type_grad_output, - static_scale=torch.tensor([1.0], device="cuda"), - scaling_granularity=scaling_granularity, - ) - else: - cast_config_grad_output = CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity, - ) + cast_config_input = CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) + cast_config_weight = CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) + cast_config_grad_output = CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, @@ -185,7 +160,7 @@ def main( copy.deepcopy(linear_ref), config=config, ) - scaling_repr = f"{linear_float8.scaling_type_repr()},{linear_float8.scaling_granularity_repr()}" + scaling_repr = linear_float8.extra_repr() if fast_accum: linear_float8.forward_config = ScaledMMConfig(False, True, False) @@ -196,8 +171,6 @@ def main( ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward() def float8_forw_backward(): - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(linear_float8) linear_float8(input_tensor).sum().backward() def n_times(n, fn, *args, **kwargs): diff --git a/benchmarks/float8/bench_multi_gpu.py b/benchmarks/float8/bench_multi_gpu.py deleted file mode 100644 index 34a690edbe..0000000000 --- a/benchmarks/float8/bench_multi_gpu.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. - -import os -from typing import Callable - -import fire -import torch -import torch.distributed as dist -import torch.multiprocessing as mp -import torch.nn as nn -import torch.utils.benchmark as benchmark -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType -from torchao.float8.float8_linear_utils import ( - convert_to_float8_training, - sync_float8_amax_and_scale_history, -) - -torch.manual_seed(0) - -# TODO: Add more shapes for the benchmark -B, M, K, N = 32, 1024, 1024, 1024 -lr = 0.01 - -config = Float8LinearConfig( - cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), -) - - -def benchmark_torch_function_in_microseconds( - func: Callable, - *args, - **kwargs, -) -> float: - t0 = benchmark.Timer( - stmt="func(*args, **kwargs)", - globals={"args": args, "kwargs": kwargs, "func": func}, - ) - return t0.blocked_autorange().median * 1e6 - - -def setup(rank, world_size): - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12355" - - # initialize the process group - dist.init_process_group("nccl", rank=rank, world_size=world_size) - - -def cleanup(): - dist.destroy_process_group() - - -def get_model(K, N, is_fp8, base_dtype=torch.float32): - modules = [ - nn.Linear(K, N, dtype=base_dtype), - nn.ReLU(), - ] - N_LAYERS = 20 - # N linear layers - for _ in range(N_LAYERS - 1): - modules.append(nn.Linear(N, N, dtype=base_dtype)) - modules.append(nn.ReLU()) - m = nn.Sequential(*modules) - if is_fp8: - convert_to_float8_training( - m, - config=config, - ) - return m - - -def fsdp_main(rank, world_size, args): - setup(rank, world_size) - torch.cuda.set_device(rank) - - base_dtype, input_global, compile = args - - # basic distributed data sampling - assert B % world_size == 0 - bsz_local_start = int(rank / world_size * B) - bsz_local_end = int((rank + 1) / world_size * B) - input_tensor = input_global[bsz_local_start:bsz_local_end].to(rank) - - fp8_model = get_model(K, N, is_fp8=True, base_dtype=base_dtype).to(rank) - # Need use_orig_params=True to compile FSDP - fp8_model = FSDP(fp8_model, use_orig_params=True) - fp8_optimizer = torch.optim.SGD(fp8_model.parameters(), lr=lr * world_size) - - # Run one iteration to make compile work, see experiments doc for more context of this issue. - fp8_optimizer.zero_grad() - y_local = fp8_model(input_tensor) - y_local.sum().backward() - fp8_optimizer.step() - sync_float8_amax_and_scale_history(fp8_model) - - sync_float8_func = sync_float8_amax_and_scale_history - if compile: - # TODO: Need to fix issues with compile - fp8_model = torch.compile(fp8_model) - sync_float8_func = torch.compile(sync_float8_amax_and_scale_history) - - def float8_forw_backward(): - fp8_optimizer.zero_grad() - y_local = fp8_model(input_tensor) - y_local.sum().backward() - fp8_optimizer.step() - sync_float8_func(fp8_model) - - ref_model = get_model(K, N, is_fp8=False, base_dtype=base_dtype).to(rank) - ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=lr * world_size) - if compile: - ref_model = torch.compile(ref_model) - - ref_model = FSDP(ref_model, use_orig_params=True) - - def ref_forw_backward(): - ref_optimizer.zero_grad() - ref_model(input_tensor).sum().backward() - ref_optimizer.step() - - def run_n_iterations(n, fn): - for _ in range(n): - fn() - # make sure training is done on all ranks - dist.barrier() - - # warmup - run_n_iterations(50, ref_forw_backward) - run_n_iterations(50, float8_forw_backward) - - N_ITER = 50 - ref_time = ( - benchmark_torch_function_in_microseconds( - run_n_iterations, N_ITER, ref_forw_backward - ) - * 1e-6 - / N_ITER - ) - float8_time = ( - benchmark_torch_function_in_microseconds( - run_n_iterations, N_ITER, float8_forw_backward - ) - * 1e-6 - / N_ITER - ) - - if rank == 0: - print("ref_time", ref_time) - print("float8_time", float8_time) - print("float8 speedup", ref_time / float8_time) - - cleanup() - - -def run(compile: bool): - base_dtype = torch.bfloat16 - WORLD_SIZE = torch.cuda.device_count() - print(f"{base_dtype = }") - print(f"{compile = }") - print(f"{WORLD_SIZE = }") - - # generate input data - ref_input = torch.randn(B, M, K).cuda().to(base_dtype) - # run fsdp model - args = (base_dtype, ref_input, compile) - mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) - - -# Usgae: -# CUDA_VISIBLE_DEVICES=0,1 python benchmarks/bench_multi_gpu.py -if __name__ == "__main__": - fire.Fire(run) diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py index 684ed0af2a..6f30e5eff7 100644 --- a/benchmarks/float8/float8_roofline.py +++ b/benchmarks/float8/float8_roofline.py @@ -58,9 +58,7 @@ ) from torchao.float8 import ( - CastConfig, Float8LinearConfig, - ScalingType, convert_to_float8_training, ) from torchao.float8.roofline_utils import ( @@ -219,24 +217,6 @@ def run( scaling_type_weight="dynamic", scaling_type_grad_output="dynamic", ) - fp8_mem_time_sympy_del_limit = get_float8_mem_sympy( - M, - K, - N, - model_torch_compile_limitations=True, - scaling_type_input="delayed", - scaling_type_weight="delayed", - scaling_type_grad_output="delayed", - ) - fp8_mem_time_sympy_del_nolimit = get_float8_mem_sympy( - M, - K, - N, - model_torch_compile_limitations=False, - scaling_type_input="delayed", - scaling_type_weight="delayed", - scaling_type_grad_output="delayed", - ) if gemm_time_strategy == "roofline": bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16) @@ -258,16 +238,12 @@ def run( # roofline memory overhead estimates "fp8_oh_dyn_limit", "fp8_oh_dyn_nolimit", - "fp8_oh_del_limit", - "fp8_oh_del_nolimit", # actual e2e measurements "bf16_s", "fp8_dyn_s", - "fp8_del_s", "fp8_dyn_axs_s", # 'fp8_lw_s', "fp8_dyn_sp", - "fp8_del_sp", "fp8_dyn_axs_sp", # 'fp8_lw_sp', ] @@ -309,12 +285,6 @@ def run( fp8_mem_time_dyn_nolimit_s = ( fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val) ) - fp8_mem_time_del_limit_s = ( - fp8_mem_time_sympy_del_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val) - ) - fp8_mem_time_del_nolimit_s = ( - fp8_mem_time_sympy_del_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val) - ) # create the model m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16() @@ -333,19 +303,6 @@ def run( m_fp8_dyn = torch.compile(m_fp8_dyn) fp8_dyn_time_actual_s = get_gpu_kernel_time(m_fp8_dyn, x) - # get the float8 delayed scaling gpu kernel time - torch._dynamo.reset() - config = Float8LinearConfig( - enable_amax_init=False, - enable_pre_and_post_forward=False, - cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), - ) - m_fp8_del = convert_to_float8_training(copy.deepcopy(m_orig), config=config) - m_fp8_del = torch.compile(m_fp8_del) - fp8_del_time_actual_s = get_gpu_kernel_time(m_fp8_del, x) - # get the float8 dynamic axiswise scaling gpu kernel time torch._dynamo.reset() config = Float8LinearConfig.from_recipe_name("rowwise") @@ -374,16 +331,12 @@ def run( # roofline overhead estimates fp8_mem_time_dyn_limit_s, fp8_mem_time_dyn_nolimit_s, - fp8_mem_time_del_limit_s, - fp8_mem_time_del_nolimit_s, # e2e numbers bf16_time_actual_s, fp8_dyn_time_actual_s, - fp8_del_time_actual_s, fp8_dyn_axs_time_actual_s, # fp8_lw_time_actual_s, bf16_time_actual_s / fp8_dyn_time_actual_s, - bf16_time_actual_s / fp8_del_time_actual_s, bf16_time_actual_s / fp8_dyn_axs_time_actual_s, # bf16_time_actual_s / fp8_lw_time_actual_s, ] diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_lowp_training.py similarity index 74% rename from benchmarks/float8/profile_linear_float8.py rename to benchmarks/float8/profile_lowp_training.py index 687684d4e2..dd629e7f95 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_lowp_training.py @@ -4,6 +4,9 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +# This is a convenience script to profile fwd+bwd of individual layers with +# float8 training or mx training on a single GPU. + import copy import functools import io @@ -33,21 +36,19 @@ kernel_name_to_category, parse_bw_and_kernel_name, profiler_output_to_filtered_time_by_kernel_name, - profiler_output_to_gpu_time_for_key, update_triton_kernels_in_prof_chome_trace_with_torch_logs, ) -from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes from torchao.float8.config import ( Float8LinearConfig, - ScalingType, ) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, - linear_requires_sync, - sync_float8_amax_and_scale_history, ) -from torchao.testing.float8.test_utils import get_test_float8_linear_config +from torchao.prototype.mx_formats.config import MXLinearConfig +from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear +from torchao.prototype.mx_formats.mx_tensor import MXTensor +from torchao.prototype.mx_formats.utils import to_blocked # don't truncate long kernel names pd.options.display.max_colwidth = 100 @@ -261,7 +262,6 @@ def profile_function( # set up AC for max(abs(tensor)) # context: https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.create_selective_checkpoint_contexts ops_to_save = [ - torch.ops.aten.abs.default, torch.ops.aten.max.default, ] @@ -279,16 +279,14 @@ def policy_fn(ctx, op, *args, **kwargs): def main( profile_path_prefix: pathlib.Path, compile: bool = True, - scaling_type_input: str = "dynamic", - scaling_type_weight: str = "dynamic", - scaling_type_grad_output: str = "dynamic", - recipe_name: Optional[str] = None, + float8_recipe_name: Optional[str] = None, + mx_recipe_name: Optional[str] = None, model_type: str = "linear", - dtype_filter: str = "both", - add_inductor_metadata_to_trace: bool = True, - enable_sync_amax_history: bool = True, + experiment_filter: str = "both", + add_inductor_metadata_to_trace: bool = False, enable_activation_checkpointing: bool = False, - enable_float8_delayed_scaling_inductor_passes: bool = False, + mode_filter: str = "fwd_bwd", + forward_only: bool = False, ): assert model_type in ( "linear", @@ -296,41 +294,41 @@ def main( "norm_ffn_norm", "norm_ffn_norm_small", ), "unsupported" - assert dtype_filter in ("both", "float8", "bfloat16") - - scaling_type_input = ScalingType(scaling_type_input) - scaling_type_weight = ScalingType(scaling_type_weight) - scaling_type_grad_output = ScalingType(scaling_type_grad_output) - - if recipe_name is None: - config = get_test_float8_linear_config( - scaling_type_input, - scaling_type_weight, - scaling_type_grad_output, - emulate=False, + assert experiment_filter in ( + "both", + "lowp", + "ref", + ), "experiment_filter must be one of `both`, `lowp`, `ref`" + assert ( + mode_filter + in ( + "fwd_bwd", + "fwd", + "cast_only", + "cast_with_to_blocked", ) - elif recipe_name is not None: - config = Float8LinearConfig.from_recipe_name(recipe_name) - - scaling_repr = "_".join( - [ - s.short_str() - for s in (scaling_type_input, scaling_type_weight, scaling_type_grad_output) - ] - ) + ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`, `cast_with_to_blocked`" + if mode_filter == "cast_only": + assert experiment_filter == "lowp", "unsupported" + + assert not ( + float8_recipe_name is not None and mx_recipe_name is not None + ), "either float8_recipe_name or mx_recipe_name can be specified, but not both" + + if float8_recipe_name is None and mx_recipe_name is None: + config = Float8LinearConfig() + elif float8_recipe_name is not None: + config = Float8LinearConfig.from_recipe_name(float8_recipe_name) + elif mx_recipe_name is not None: + config = MXLinearConfig.from_recipe_name(mx_recipe_name) print(f"Compile is set to | {compile}") print(f"model_type is set to | {model_type}") - print(f"scaling_repr is set to | {scaling_repr}") print( f"enable_activation_checkpointing is set to {enable_activation_checkpointing}" ) - print( - f"enable_float8_delayed_scaling_inductor_passes is set to {enable_float8_delayed_scaling_inductor_passes}" - ) - - if enable_float8_delayed_scaling_inductor_passes: - _prototype_register_float8_delayed_scaling_inductor_passes() + print(f"mode_filter is set to {mode_filter}") + print(f"config: {config}") device = "cuda" ref_dtype = torch.bfloat16 @@ -371,49 +369,74 @@ def main( m_ref = m_ref.to(device).to(ref_dtype) - m_float8 = copy.deepcopy(m_ref) - convert_to_float8_training(m_float8, config=config) + # get gradient shape + with torch.no_grad(): + _ = m_ref(input_tensor) + grad_output = torch.ones_like(_) + + m_lowp = copy.deepcopy(m_ref) + if mx_recipe_name is None: + convert_to_float8_training(m_lowp, config=config) + else: + swap_linear_with_mx_linear(m_lowp, config=config) + + # this function is only used for cast_only + to_mx_func = MXTensor.to_mx + + # this function is used for cast_with_to_blocked + def cast_with_to_blocked(x_hp): + x_mx = MXTensor.to_mx( + x_hp, + config.elem_dtype, + config.block_size, + gemm_kernel_choice=config.gemm_kernel_choice, + ) + m, k = x_hp.shape + scale_blocked = to_blocked(x_mx._scale_e8m0.reshape(m, k // config.block_size)) + return x_mx._data, scale_blocked + + print("m_ref", m_ref) + print("m_lowp", m_lowp) + print("input_tensor.shape", input_tensor.shape) + print("grad_output.shape", grad_output.shape) + print() def ref_forw_backward(x): + assert mode_filter not in ("cast_only", "cast_with_to_blocked"), "unsupported" if enable_activation_checkpointing: out = checkpoint(m_ref, x, use_reentrant=False, context_fn=context_fn) else: out = m_ref(x) - out.sum().backward() + if mode_filter == "fwd_bwd": + out.backward(grad_output) + + def lowp_forw_backward_wrapper(x): + if mode_filter == "cast_only": + # just cast and return early + _input_tensor_mx = to_mx_func( + input_tensor, + config.elem_dtype, + config.block_size, + gemm_kernel_choice=config.gemm_kernel_choice, + ) + return + elif mode_filter == "cast_with_to_blocked": + _input_tensor_mx, scale = cast_with_to_blocked(input_tensor) + return - def float8_forw(x): if enable_activation_checkpointing: - out = checkpoint(m_float8, x, use_reentrant=False, context_fn=context_fn) + out = checkpoint(m_lowp, x, use_reentrant=False, context_fn=context_fn) else: - out = m_float8(x) - return out - - sync_amax_history = sync_float8_amax_and_scale_history - - def float8_forw_backward_wrapper(x): - # sync_float8_amax_and_scale_history is not full graph torch - # compile friendly, so we add a high level wrapper to allow - # inspection of the fw+bw torch.compile without the scale - # syncing code - # TODO(future): make this better - if linear_requires_sync(config) and enable_sync_amax_history: - with record_function("scale_amax_and_scales"): - sync_amax_history(m_float8) - out = float8_forw(x) - - # out.sum().backward() is also not torch.compile fullgraph - # friendly - with record_function("backward"): - out.sum().backward() + out = m_lowp(x) + if mode_filter == "fwd_bwd": + with record_function("backward"): + out.backward(grad_output) if compile: m_ref = torch.compile(m_ref, fullgraph=True) - float8_forw = torch.compile(float8_forw, fullgraph=True) - # Note: it's faster to compile the combination of sync_amax_history wit - # forward because we only look up from dynamo cache once. - # However, compiling the sync function separately makes it more - # convenient to analyze the total time spent on it. - sync_amax_history = torch.compile(sync_amax_history) + m_lowp = torch.compile(m_lowp, fullgraph=True) + to_mx_func = torch.compile(to_mx_func, fullgraph=True) + cast_with_to_blocked = torch.compile(cast_with_to_blocked, fullgraph=True) # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output # to populate triton kernel bandwidth further down in the script @@ -423,15 +446,21 @@ def float8_forw_backward_wrapper(x): else: f = io.StringIO() context = redirect_stdout(f) + + # if we are skipping forward, enable torch.no_grad() + maybe_no_grad_context = ( + torch.no_grad() if mode_filter != "fwd_bwd" else nullcontext() + ) + try: - with context: + with context, maybe_no_grad_context: profile_iters = 5 - ref_times, float8_times = None, None + ref_times, lowp_times = None, None data = [] num_leaf_tensors = 1 + len(list(m_ref.parameters())) - if dtype_filter != "float8": + if experiment_filter != "lowp": # Profile Reference Model print("profiling ref") ref_trace_suffix = f"_{model_type}_ref_compile_{compile}.json" @@ -477,50 +506,46 @@ def float8_forw_backward_wrapper(x): ] ) - if dtype_filter != "bfloat16": - # Profile Float8 Model - print("profiling float8") - float8_trace_suffix = ( - f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json" - ) - float8_log_suffix = ( - f"_{model_type}_float8_compile_{compile}_{scaling_repr}.txt" - ) - trace_float8_path = profile_path_prefix + float8_trace_suffix - log_float8_path = profile_path_prefix + float8_log_suffix - trace_float8_modified_path = trace_float8_path.replace( + if experiment_filter != "ref": + # Profile lowp Model + print("profiling lowp") + lowp_trace_suffix = f"_{model_type}_lowp_compile_{compile}.json" + lowp_log_suffix = f"_{model_type}_lowp_compile_{compile}.txt" + trace_lowp_path = profile_path_prefix + lowp_trace_suffix + log_lowp_path = profile_path_prefix + lowp_log_suffix + trace_lowp_modified_path = trace_lowp_path.replace( ".json", "_modified.json" ) profile_config = ProfileConfig( - trace_float8_path, - log_float8_path, - trace_float8_modified_path, - float8_trace_suffix, + trace_lowp_path, + log_lowp_path, + trace_lowp_modified_path, + lowp_trace_suffix, iters=profile_iters, warmup_iters=2, sync=True, ) p = profile_function( profile_config, - float8_forw_backward_wrapper, + lowp_forw_backward_wrapper, add_inductor_metadata_to_trace, input_tensor, ) - print(f"saved profiling trace to {trace_float8_path}") + print(f"saved profiling trace to {trace_lowp_path}") if add_inductor_metadata_to_trace: - print(f"saved torch logs to {log_float8_path}") - print(f"saved modified trace to {trace_float8_modified_path}") - float8_times = profiler_output_to_filtered_time_by_kernel_name( + print(f"saved torch logs to {log_lowp_path}") + print(f"saved modified trace to {trace_lowp_modified_path}") + lowp_times = profiler_output_to_filtered_time_by_kernel_name( p, profile_iters, num_leaf_tensors ) total_time_ms = ( - sum(v for v in float8_times.values()) / 1e3 / profile_iters + sum(v for v in lowp_times.values()) / 1e3 / profile_iters ) - for k, v in float8_times.items(): + for k, v in lowp_times.items(): v_ms = v / 1e3 / profile_iters data.append( [ - "1_float8", + "1_lowp", k, kernel_name_to_category(k), v / 1e3 / profile_iters, @@ -529,18 +554,12 @@ def float8_forw_backward_wrapper(x): ] ) - # get the time spent per user annotation - sync_time_us = profiler_output_to_gpu_time_for_key( - p, "scale_amax_and_scales" - ) - sync_time_ms = sync_time_us / profile_iters / 1e3 - print(f"Sync time ms: {sync_time_ms}") - finally: if f is not None: # print the redirected stdout back to regular stdout print(f.getvalue()) + # TODO(future PR): this seems to no longer work, fix it or delete it if os.environ.get("TORCHINDUCTOR_PROFILE", "") != "": # populate the triton kernel bandwidth for line in f.getvalue().split("\n"): @@ -578,21 +597,13 @@ def float8_forw_backward_wrapper(x): fill_value=0, margins=True, ) - # drop last row, which has totals across ref + float8 which does not make sense + # drop last row, which has totals across ref + lowp which does not make sense df_p = df_p[:-1] df_p = df_p.transpose() - if dtype_filter == "both": - df_p["f8_div_ref"] = df_p["1_float8"] / df_p["0_ref"] - df_p["ref_div_f8"] = df_p["0_ref"] / df_p["1_float8"] - - # calculate sync time as pct of total float time - # note: this time is not useful if TORCHINDUCTOR_PROFILE is on - total_float8_ms = df_p.iloc[3]["1_float8"] - sync_approx_ratio = sync_time_ms / total_float8_ms - print( - f"\nFloat8 amax/scale sync approx ratio of total time: {sync_approx_ratio:.3f}" - ) + if experiment_filter == "both": + df_p["lowp_div_ref"] = df_p["1_lowp"] / df_p["0_ref"] + df_p["ref_div_lowp"] = df_p["0_ref"] / df_p["1_lowp"] print("\nSummary of time (ms) by kernel category\n\n", df_p) diff --git a/benchmarks/float8/utils.py b/benchmarks/float8/utils.py index 60e402e60e..a7faf4757d 100644 --- a/benchmarks/float8/utils.py +++ b/benchmarks/float8/utils.py @@ -73,14 +73,6 @@ def profiler_output_to_filtered_time_by_kernel_name( # forward pass sum assert e.count == num_iter, f"unexpected number of iter for {e.key}" continue - elif e.key == "aten::fill_": - # filling the forward pass sum with 1.0 - assert e.count == num_iter, f"unexpected number of iter for {e.key}" - continue - elif e.key == "aten::copy_": - # copying 1.0 from grad_out of `sum` to grad_out of next op - assert e.count == num_iter, f"unexpected number of iter for {e.key}" - continue elif e.key == "aten::add_": # accumulating gradients into leaf tensors assert e.count == ( @@ -110,25 +102,16 @@ def profiler_output_to_gpu_time_for_key(prof, key): def kernel_name_to_category(k): # number prefix is for easy sorting - if k in ("aten::mm", "aten::addmm", "aten::_scaled_mm"): - return "0_gemm" - elif ( - # max(abs(tensor)) - ("abs" in k and "max" in k) - or - # casting pointwise to float8 - ("clamp" in k) - or - # things related to scaled_mm - ("scaled_mm" in k) - or - # syncing amaxes and scales - ("roll" in k) + if k in ( + "aten::mm", + "aten::addmm", + "aten::_scaled_mm", + "torchao::mx_fp8_bf16", + "torchao::mx_fp4_bf16", ): - # note: the above filter is approximate and will give false - # positives if model code contains other code to abs/max/clamp - return "1_f8_overhead" - return "2_other" + return "0_gemm" + else: + return "1_other" def parse_bw_and_kernel_name(line): diff --git a/benchmarks/microbenchmarks/test/results/results.csv b/benchmarks/microbenchmarks/test/results/results.csv deleted file mode 100644 index 036d6b532c..0000000000 --- a/benchmarks/microbenchmarks/test/results/results.csv +++ /dev/null @@ -1,13 +0,0 @@ -quantization,m,k,n,shape_name,precision,compile,device,model_type,output_dir,name,benchmark_model_inference_in_microseconds -baseline,1024,1024,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_baseline_linear_m1024_k1024_n1024_compile,64510.37060469389 -baseline,2048,4096,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_baseline_linear_m2048_k4096_n1024_compile,53887.79062777758 -baseline,4096,4096,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_baseline_linear_m4096_k4096_n1024_compile,36628.598207607865 -int8wo,1024,1024,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_int8wo_linear_m1024_k1024_n1024_compile,56611.56056448817 -int8wo,2048,4096,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_int8wo_linear_m2048_k4096_n1024_compile,55212.84379065037 -int8wo,4096,4096,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_int8wo_linear_m4096_k4096_n1024_compile,51695.895195007324 -int4wo-128,1024,1024,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_int4wo-128_linear_m1024_k1024_n1024_compile,40540.05299694836 -int4wo-128,2048,4096,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_int4wo-128_linear_m2048_k4096_n1024_compile,39183.96681547165 -int4wo-128,4096,4096,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_int4wo-128_linear_m4096_k4096_n1024_compile,40781.22219070792 -int4wo-128-hqq,1024,1024,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_int4wo-128-hqq_linear_m1024_k1024_n1024_compile,37873.45583550632 -int4wo-128-hqq,2048,4096,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_int4wo-128-hqq_linear_m2048_k4096_n1024_compile,37539.9901997298 -int4wo-128-hqq,4096,4096,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_int4wo-128-hqq_linear_m4096_k4096_n1024_compile,38310.51839515567 diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 112cab8684..6b3a447070 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -20,6 +20,7 @@ quantize_, ) from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain +from torchao.testing.utils import skip_if_rocm from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, @@ -104,6 +105,7 @@ def test_tensor_core_layout_transpose(self): "apply_quant", get_quantization_functions(is_cusparselt_available, True, "cuda", True), ) + @skip_if_rocm("ROCm enablement in progress") def test_weights_only(self, apply_quant): linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") if isinstance(apply_quant, AOBaseConfig): @@ -196,6 +198,7 @@ def apply_uint6_weight_only_quant(linear): "apply_quant", get_quantization_functions(is_cusparselt_available, True) ) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @skip_if_rocm("ROCm enablement in progress") def test_print_quantized_module(self, apply_quant): linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") if isinstance(apply_quant, AOBaseConfig): @@ -213,6 +216,7 @@ class TestAffineQuantizedBasic(TestCase): @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) + @skip_if_rocm("ROCm enablement in progress") def test_flatten_unflatten(self, device, dtype): if device == "cuda" and dtype == torch.bfloat16 and is_fbcode(): raise unittest.SkipTest("TODO: Failing for cuda + bfloat16 in fbcode") diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py index 76b6b74a3d..b60f3251dc 100644 --- a/test/dtypes/test_affine_quantized_tensor_parallel.py +++ b/test/dtypes/test_affine_quantized_tensor_parallel.py @@ -1,5 +1,6 @@ import unittest +import pytest import torch from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard from torch.testing._internal import common_utils @@ -27,6 +28,9 @@ except ModuleNotFoundError: has_gemlite = False +if torch.version.hip is not None: + pytest.skip("Skipping the test in ROCm", allow_module_level=True) + class TestAffineQuantizedTensorParallel(DTensorTestBase): """Basic test case for tensor subclasses""" diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index 8bb39b2cc8..0953e33b0f 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -27,6 +27,7 @@ fpx_weight_only, quantize_, ) +from torchao.testing.utils import skip_if_rocm from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) @@ -109,6 +110,7 @@ def test_to_copy_device(self, ebits, mbits): @parametrize("bias", [False, True]) @parametrize("dtype", [torch.half, torch.bfloat16]) @unittest.skipIf(is_fbcode(), reason="broken in fbcode") + @skip_if_rocm("ROCm enablement in progress") def test_fpx_weight_only(self, ebits, mbits, bias, dtype): N, OC, IC = 4, 256, 64 device = "cuda" diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index caa1a6c7bd..4ed90d06ca 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -33,6 +33,7 @@ nf4_weight_only, to_nf4, ) +from torchao.testing.utils import skip_if_rocm bnb_available = False @@ -111,6 +112,7 @@ def test_backward_dtype_match(self, dtype: torch.dtype): @unittest.skipIf(not bnb_available, "Need bnb availble") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @skip_if_rocm("ROCm enablement in progress") @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype): # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65C1-L81C47 @@ -133,6 +135,7 @@ def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype): @unittest.skipIf(not bnb_available, "Need bnb availble") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @skip_if_rocm("ROCm enablement in progress") @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_nf4_bnb_linear(self, dtype: torch.dtype): """ diff --git a/test/dtypes/test_uint4.py b/test/dtypes/test_uint4.py index e148d68abb..cf4077a78c 100644 --- a/test/dtypes/test_uint4.py +++ b/test/dtypes/test_uint4.py @@ -28,6 +28,7 @@ from torchao.quantization.quant_api import ( _replace_with_custom_fn_if_matches_filter, ) +from torchao.testing.utils import skip_if_rocm from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 @@ -92,6 +93,7 @@ def test_basic_tensor_ops(self): # only test locally # print("x:", x[0]) + @skip_if_rocm("ROCm enablement in progress") def test_gpu_quant(self): for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: x = torch.randn(*x_shape) @@ -104,6 +106,7 @@ def test_gpu_quant(self): # make sure it runs opt(x) + @skip_if_rocm("ROCm enablement in progress") def test_pt2e_quant(self): from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( QuantizationConfig, diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 156c8abe87..463b618fa8 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -14,6 +14,7 @@ import torch import torch.nn as nn +from torchao.testing.utils import skip_if_rocm from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89, @@ -25,7 +26,6 @@ from torchao.float8.config import ( - CastConfig, Float8LinearConfig, Float8LinearRecipeName, ScalingGranularity, @@ -36,8 +36,6 @@ from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( convert_to_float8_training, - linear_requires_sync, - sync_float8_amax_and_scale_history, ) from torchao.float8.float8_python_api import addmm_float8_unwrapped from torchao.float8.float8_scaling_utils import ( @@ -54,11 +52,9 @@ from torchao.float8.float8_utils import ( FP8_TYPES, compute_error, - config_has_stateful_scaling, fp8_tensor_statistics, tensor_to_scale, ) -from torchao.float8.stateful_float8_linear import StatefulFloat8Linear from torchao.testing.float8.test_utils import get_test_float8_linear_config random.seed(0) @@ -284,16 +280,10 @@ def _test_linear_impl( config: Float8LinearConfig, use_ac: bool = False, ): - if config_has_stateful_scaling(config): - m_fp8 = StatefulFloat8Linear.from_float( - copy.deepcopy(m_ref), - config, - ) - else: - m_fp8 = Float8Linear.from_float( - copy.deepcopy(m_ref), - config, - ) + m_fp8 = Float8Linear.from_float( + copy.deepcopy(m_ref), + config, + ) for _ in range(2): if use_ac: @@ -301,8 +291,6 @@ def _test_linear_impl( else: y_fp8 = m_fp8(x) y_fp8.sum().backward() - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(m_fp8) if use_ac: y_ref = torch.utils.checkpoint.checkpoint(m_ref, x, use_reentrant=False) @@ -320,65 +308,21 @@ def _test_linear_impl( if m_ref.bias is not None: torch.testing.assert_close(m_ref.bias.grad, m_fp8.bias.grad) - # verify all of the amax buffers got updated - if linear_requires_sync(config): - # only check buffers that are actually used, based on per-tensor - # scaling settings - amax_buffer_names = [] - amax_history_buffer_names = [] - scale_buffer_names = [] - if config.cast_config_input.scaling_type is ScalingType.DELAYED: - amax_buffer_names.append("fp8_amax_input") - amax_history_buffer_names.append("fp8_amax_history_input") - scale_buffer_names.append("fp8_scale_input") - if config.cast_config_weight.scaling_type is ScalingType.DELAYED: - amax_buffer_names.append("fp8_amax_weight") - amax_history_buffer_names.append("fp8_amax_history_weight") - scale_buffer_names.append("fp8_scale_weight") - if config.cast_config_grad_output.scaling_type is ScalingType.DELAYED: - amax_buffer_names.append("fp8_amax_grad_output") - amax_history_buffer_names.append("fp8_amax_history_grad_output") - scale_buffer_names.append("fp8_scale_grad_output") - - # verify all of the amax buffers got updated - max_float8_pos = {torch.finfo(dtype).max for dtype in FP8_TYPES} - for buffer_name in amax_buffer_names: - buffer_value = getattr(m_fp8, buffer_name) - for init_val in max_float8_pos: - assert torch.ne( - buffer_value, torch.tensor(init_val) - ), f"{buffer_name} not filled, current value {buffer_value}" - - # verify all of the amax history buffers got updated - for buffer_name in amax_history_buffer_names: - buffer_value = getattr(m_fp8, buffer_name) - assert torch.max(buffer_value) > 0.0, f"{buffer_name} not filled" - - # verify all of the scale buffers got updated - for buffer_name in scale_buffer_names: - buffer_value = getattr(m_fp8, buffer_name) - assert torch.ne( - buffer_value, torch.tensor(1.0) - ), f"{buffer_name} not filled, current value {buffer_value}" - - # verify initialization flags got updated - assert m_fp8.is_amax_initialized, "Amax was not properly initialized" - @pytest.mark.parametrize( "emulate", [True, False] if is_sm_at_least_89() else [True] ) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @pytest.mark.parametrize( "scaling_type_input", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize( "scaling_type_weight", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize( "scaling_type_grad_output", - [ScalingType.DELAYED, ScalingType.DYNAMIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32]) @pytest.mark.parametrize("linear_bias", [False, True]) @@ -426,6 +370,7 @@ def test_linear_from_config_params( @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @pytest.mark.parametrize("linear_bias", [True, False]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @skip_if_rocm("ROCm enablement in progress") def test_linear_from_recipe( self, recipe_name, @@ -465,9 +410,6 @@ def test_autocast_outputs( nn.Linear(32, 32, device="cuda", dtype=linear_dtype), ) config = Float8LinearConfig( - cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), emulate=emulate, ) m = convert_to_float8_training(copy.deepcopy(m_ref), config=config) @@ -475,21 +417,15 @@ def test_autocast_outputs( # autocast off x = torch.randn(16, 32, device="cuda", dtype=linear_dtype) y = m(x) - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(m) assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}" # autocast on with torch.autocast("cuda"): y = m(x) - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(m) assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}" with torch.autocast("cuda", dtype=torch.bfloat16): y = m(x) - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(m) assert ( y.dtype == torch.bfloat16 ), f"y.dtype is {y.dtype}, expected {torch.bfloat16}" @@ -508,40 +444,18 @@ def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool): # Cast the module to dtype m = m.to(dtype=linear_dtype) - if linear_requires_sync(config): - # Check amax buffer types - for key in [ - "fp8_amax_input", - "fp8_amax_history_input", - "fp8_scale_input", - "fp8_amax_weight", - "fp8_amax_history_weight", - "fp8_scale_weight", - "fp8_amax_grad_output", - "fp8_amax_history_grad_output", - "fp8_scale_grad_output", - ]: - assert ( - m._buffers[key].dtype == torch.float32 - ), f"{key}.dtype is {m._buffers[key].dtype}, expected torch.float32" # autocast off x = torch.randn(16, 32, device="cuda", dtype=linear_dtype) - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(m) y = m(x) assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}" # autocast on with torch.autocast("cuda"): - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(m) y = m(x) assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}" with torch.autocast("cuda", dtype=torch.bfloat16): - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(m) y = m(x) assert ( y.dtype == torch.bfloat16 @@ -550,7 +464,6 @@ def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool): def test_repr(self): m = nn.Linear(32, 16) config = Float8LinearConfig( - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), emulate=True, ) m = Float8Linear.from_float( @@ -558,7 +471,7 @@ def test_repr(self): config=config, ) s = m.__repr__() - assert "i:dyn_ten_e4m3,w:del_ten_e4m3,go:dyn_ten_e5m2" in s + assert "i:dyn_ten_e4m3,w:dyn_ten_e4m3,go:dyn_ten_e5m2" in s @unittest.skipIf(not is_sm_at_least_89(), "CUDA 8.9 not available") def test_inference_mode(self): diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 0c02db26a6..7c31bf6f08 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -7,7 +7,6 @@ import random import sys import unittest -from dataclasses import replace from io import StringIO import pytest @@ -26,7 +25,6 @@ from torch._dynamo.test_case import TestCase as DynamoTestCase from torch._dynamo.testing import CompileCounterWithBackend -from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes from torchao.float8.config import ( CastConfig, Float8LinearConfig, @@ -35,20 +33,11 @@ e4m3_dtype, ) from torchao.float8.float8_linear import Float8Linear -from torchao.float8.float8_linear_utils import ( - convert_to_float8_training, - get_float8_layers, - sync_float8_amax_and_scale_history, -) from torchao.float8.float8_scaling_utils import ( - hp_tensor_to_float8_delayed, hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig -from torchao.float8.float8_utils import config_has_stateful_scaling -from torchao.float8.stateful_float8_linear import StatefulFloat8Linear from torchao.testing.float8.test_utils import get_test_float8_linear_config -from torchao.utils import is_fbcode def _test_compile_base( @@ -66,16 +55,10 @@ def _test_compile_base( x_ref = copy.deepcopy(x) m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) - if config_has_stateful_scaling(config): - m_fp8 = StatefulFloat8Linear.from_float( - copy.deepcopy(m_ref), - config, - ) - else: - m_fp8 = Float8Linear.from_float( - copy.deepcopy(m_ref), - config, - ) + m_fp8 = Float8Linear.from_float( + copy.deepcopy(m_ref), + config, + ) m_fp8 = torch.compile(m_fp8, backend=backend, fullgraph=fullgraph) m_ref = torch.compile(m_ref, backend=backend, fullgraph=fullgraph) @@ -94,16 +77,14 @@ def _test_compile_base( @pytest.mark.parametrize("fullgraph", [True]) -@pytest.mark.parametrize( - "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] -) +@pytest.mark.parametrize("scaling_type_input", [ScalingType.DYNAMIC]) @pytest.mark.parametrize( "scaling_type_weight", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize( "scaling_type_grad_output", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @@ -133,16 +114,14 @@ def test_eager_only( @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True]) -@pytest.mark.parametrize( - "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] -) +@pytest.mark.parametrize("scaling_type_input", [ScalingType.DYNAMIC]) @pytest.mark.parametrize( "scaling_type_weight", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize( "scaling_type_grad_output", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @@ -171,16 +150,14 @@ def test_aot_eager( @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize("emulate", [False]) -@pytest.mark.parametrize( - "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] -) +@pytest.mark.parametrize("scaling_type_input", [ScalingType.DYNAMIC]) @pytest.mark.parametrize( "scaling_type_weight", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize( "scaling_type_grad_output", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @unittest.skipIf( not torch.cuda.is_available() or not is_sm_at_least_89(), @@ -241,16 +218,12 @@ class TestGraphBreaks(DynamoTestCase): class MockLinear(torch.nn.Module): def __init__(self, graph_break: bool): super().__init__() - self.register_buffer("fp8_amax_x", torch.tensor(1.0)) - self.register_buffer("fp8_scale_x", torch.tensor(1.0)) self.graph_break = graph_break def forward(self, x): - x_fp8 = hp_tensor_to_float8_delayed( + x_fp8 = hp_tensor_to_float8_dynamic( x, - self.fp8_scale_x, e4m3_dtype, - self.fp8_amax_x, LinearMMConfig(), ) if self.graph_break: @@ -330,30 +303,6 @@ def test_float8_graph_output(self): ) -@unittest.skipIf( - not torch.cuda.is_available() or not is_sm_at_least_89(), - "CUDA with float8 support not available", -) -def test_sync_amax_func(): - torch._dynamo.reset() - cnts = CompileCounterWithBackend("inductor") - module = torch.nn.Sequential( - nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True) - ) - config = Float8LinearConfig( - cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), - ) - float8_mod = convert_to_float8_training( - module, - config=config, - ) - compiled_swap_func = torch.compile(sync_float8_amax_and_scale_history, backend=cnts) - compiled_swap_func(float8_mod) - assert cnts.frame_count == 1, "Compiled graph should have 1 frame!" - - class capture_stderr(list): """ Replace sys.stderr with a temporary StringIO @@ -371,38 +320,6 @@ def __exit__(self, *args): sys.stderr = self.sys_stderr -@unittest.skipIf( - not torch.cuda.is_available() or not is_sm_at_least_89(), - "CUDA with float8 support not available", -) -def test_sync_amax_func_cuda_graph_success(): - torch._dynamo.reset() - with capture_stderr() as stderr: - my_module = nn.Sequential( - nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True) - ).to("cuda") - config = Float8LinearConfig( - cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), - ) - convert_to_float8_training( - my_module, - config=config, - ) - inpt = torch.randn( - 16, 16, device="cuda", dtype=torch.float32, requires_grad=True - ) - sync_func = torch.compile( - sync_float8_amax_and_scale_history, mode="reduce-overhead", fullgraph=True - ) - fp8_layers = get_float8_layers(my_module) - my_module(inpt) - sync_func(my_module, fp8_layers) - - assert "skipping cudagraphs due to mutaton on input" not in stderr[0] - - @unittest.skipIf( not is_sm_at_least_89(), "CUDA not available", @@ -475,70 +392,5 @@ def test_dynamic_scale_numeric_parity( assert torch.equal(float8_eager._data, float8_compile._data) -@unittest.skipIf( - not is_sm_at_least_89() or not is_fbcode(), - "CUDA with float8 support not available; or not on fbcode (the test needs be run with the latest pytorch package)", -) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) -def test_delayed_scaling_pattern_replacement(dtype: torch.dtype): - from torch._inductor import config as inductor_config - from torch._inductor import metrics - - inductor_config.loop_ordering_after_fusion = True - - def clear_all(): - metrics.reset() - from torch._inductor.fx_passes.post_grad import ( - pass_patterns as post_grad_patterns_all, - ) - - post_grad_patterns_all[1].clear() - post_grad_patterns_all[1].seen_patterns.clear() - - def compile_and_run_single_layer(): - random.seed(0) - torch.manual_seed(0) - x_shape = (2048, 3072) - linear_dtype = dtype - - x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype).requires_grad_() - m_ref = nn.Linear(3072, 2048, bias=True, device="cuda", dtype=linear_dtype) - - config = get_test_float8_linear_config( - ScalingType.DELAYED, - ScalingType.DELAYED, - ScalingType.DELAYED, - False, - ) - - config = replace(config, enable_amax_init=False) - - m_fp8 = StatefulFloat8Linear.from_float( - copy.deepcopy(m_ref), - config, - ) - - m_fp8 = torch.compile(m_fp8, backend="inductor", fullgraph=True) - m_ref = torch.compile(m_ref, backend="inductor", fullgraph=True) - - y_fp8 = m_fp8(x) - y_fp8.sum().backward() - - return m_fp8.weight.grad - - clear_all() - ref_output = compile_and_run_single_layer() - ref_count_kernel = metrics.generated_kernel_count - - clear_all() - _prototype_register_float8_delayed_scaling_inductor_passes() - new_output = compile_and_run_single_layer() - new_count_kernel = metrics.generated_kernel_count - - torch.equal(ref_output, new_output) - # With the pattern replacement workaround, amax reduction kernels for the 3 tensors (weight, activation, gradient) are fused. - assert ref_count_kernel == new_count_kernel + 3 - - if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/float8/test_float8_utils.py b/test/float8/test_float8_utils.py index ca9f21dde1..1a6a888246 100644 --- a/test/float8/test_float8_utils.py +++ b/test/float8/test_float8_utils.py @@ -4,6 +4,7 @@ import torch from torchao.float8.float8_utils import _round_scale_down_to_power_of_2 +from torchao.testing.utils import skip_if_rocm from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 if not TORCH_VERSION_AT_LEAST_2_5: @@ -30,6 +31,7 @@ # ("largest subnormal number", [2**-126 * (1 - 2**-23), 1.1754943508222875e-38]), ], ) +@skip_if_rocm("ROCm enablement in progress") def test_round_scale_down_to_power_of_2_valid_inputs( test_case: dict, ): diff --git a/test/float8/test_fsdp.py b/test/float8/test_fsdp.py index 863256dc35..3017c8b539 100644 --- a/test/float8/test_fsdp.py +++ b/test/float8/test_fsdp.py @@ -35,11 +35,9 @@ FullyShardedDataParallel as FSDP, ) -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.config import Float8LinearConfig from torchao.float8.float8_linear_utils import ( convert_to_float8_training, - linear_requires_sync, - sync_float8_amax_and_scale_history, ) from torchao.float8.float8_utils import compute_error @@ -77,19 +75,13 @@ def get_model(K, N, base_dtype=torch.float32): def fsdp_main(rank, world_size, args): setup(rank, world_size) torch.cuda.set_device(rank) + print("args", args) - emulate, base_dtype, compile, use_weight_dynamic_scaling = args + emulate, base_dtype, compile = args model = get_model(K, N, base_dtype=base_dtype).to(rank) model_fp8 = copy.deepcopy(model) - scaling_type_weight = ( - ScalingType.DYNAMIC if use_weight_dynamic_scaling else ScalingType.DELAYED - ) - config = Float8LinearConfig( - cast_config_weight=CastConfig(scaling_type=scaling_type_weight), - # TODO(future): delete this arg as it's always False - emulate=False, - ) + config = Float8LinearConfig() # Note: we only iterate over `scaling_type_weight` because FSDP only interacts # with weights. @@ -110,6 +102,7 @@ def fsdp_main(rank, world_size, args): # Note: we need two different inputs to properly measure the impact of # delayed scaling, before the first input uses dynamic scaling to # populate the buffers + # TODO(future PR): delete ^, since we deleted delayed scaling ref_input_global = [ torch.randn(B, M, K).cuda().to(base_dtype), torch.randn(B, M, K).cuda().to(base_dtype), @@ -133,16 +126,10 @@ def fsdp_main(rank, world_size, args): ref_grad_global[idx][bsz_local_start:bsz_local_end].to(rank) ) - sync_float8_func = sync_float8_amax_and_scale_history - if compile: - sync_float8_func = torch.compile(sync_float8_amax_and_scale_history) - def forward_backward(model, optim, is_fp8, i): optim.zero_grad() y_local = model(ref_input_local[i]) y_local.backward(ref_grad_local[i]) - if is_fp8 and linear_requires_sync(config): - sync_float8_func(model) optim.step() return y_local @@ -193,7 +180,7 @@ def forward_backward(model, optim, is_fp8, i): cleanup() -def run(compile_fsdp: bool = False, use_weight_dynamic_scaling: bool = False): +def run(compile_fsdp: bool = False): base_dtype = torch.bfloat16 emulate = False @@ -207,7 +194,7 @@ def run(compile_fsdp: bool = False, use_weight_dynamic_scaling: bool = False): emulate = True WORLD_SIZE = torch.cuda.device_count() - args = (emulate, base_dtype, compile_fsdp, use_weight_dynamic_scaling) + args = (emulate, base_dtype, compile_fsdp) mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) diff --git a/test/float8/test_fsdp.sh b/test/float8/test_fsdp.sh index 3ff19d917d..6f135a2e76 100755 --- a/test/float8/test_fsdp.sh +++ b/test/float8/test_fsdp.sh @@ -4,12 +4,12 @@ set -e launch() { - echo "launching compile_fsdp $COMPILE, use_weight_dynamic_scaling $USE_WEIGHT_DYNAMIC_SCALING" + echo "launching compile_fsdp $COMPILE" # the NCCL_DEBUG setting is to avoid log spew # the CUDA_VISIBLE_DEVICES setting is for easy debugging NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python test/float8/test_fsdp.py \ - --compile_fsdp $COMPILE --use_weight_dynamic_scaling $USE_WEIGHT_DYNAMIC_SCALING + --compile_fsdp $COMPILE echo "✅ All Tests Passed ✅" } @@ -19,10 +19,5 @@ if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; exit fi -# COMPILE, USE_WEIGHT_DYNAMIC_SCALING -for i in False,False False,True True,False True,True -do - IFS=","; set -- $i; - COMPILE=$1; USE_WEIGHT_DYNAMIC_SCALING=$2 - launch -done +COMPILE=False launch +COMPILE=True launch diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index fbe5c9b508..a36fc3e249 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -43,6 +43,9 @@ if not is_sm_at_least_89(): pytest.skip("Unsupported CUDA device capability version", allow_module_level=True) +if torch.version.hip is not None: + pytest.skip("ROCm enablement in progress", allow_module_level=True) + class TestFloat8Common: def broadcast_module(self, module: nn.Module) -> None: @@ -101,7 +104,6 @@ def test_transformer_parity(self): "precompute": [False, True], "scaling_type_weight": [ ScalingType.DYNAMIC, - ScalingType.DELAYED, ], "compile_transformer_block": [False, True], "dtype": [torch.float32, torch.bfloat16], @@ -119,8 +121,6 @@ def _test_transformer_parity( ): if not enable_fsdp_float8_all_gather and precompute: return - elif scaling_type_weight is ScalingType.DELAYED and precompute: - return # NOTE: Weight-tying does not compose with fp8 all-gather because the # embedding weight and output linear weight are tied but only the @@ -462,16 +462,10 @@ def test_fp32_fp8_single_module_parity(self): """ choices = itertools.product( [False, True], - [ScalingType.DYNAMIC, ScalingType.DELAYED, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) for enable_fsdp_float8_all_gather, scaling_type_weight in choices: - if scaling_type_weight is ScalingType.STATIC: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_weight = CastConfig(scaling_type=scaling_type_weight) + cast_config_weight = CastConfig(scaling_type=scaling_type_weight) float8_linear_config1 = Float8LinearConfig( enable_fsdp_float8_all_gather=False, @@ -514,7 +508,7 @@ def test_fp32_fp8_multi_module_parity(self): """ choices = itertools.product( [False, True], - [ScalingType.DYNAMIC, ScalingType.DELAYED], + [ScalingType.DYNAMIC], ) for enable_fsdp_float8_all_gather, scaling_type_weight in choices: float8_linear_config1 = Float8LinearConfig( @@ -584,26 +578,6 @@ def test_bf16_mp_fp8_dynamic_multi_parity(self): self.get_local_inp(torch.bfloat16), ) - @unittest.skipIf(not TEST_CUDA, "no cuda") - def test_delayed_scaling_inplace_update(self): - """ - Verify that `WeightWithDelayedFloat8CastTensor` updates buffers inplace - """ - module = self.init_single_module() - float8_linear_config = Float8LinearConfig( - enable_fsdp_float8_all_gather=True, - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - ) - m_fp8 = convert_to_float8_training( - module, - config=float8_linear_config, - ) - - fp8_amax_weight_old = m_fp8.fp8_amax_weight.clone().detach() - dummy_mesh = None - data, scale = m_fp8.weight.fsdp_pre_all_gather(dummy_mesh) - self.assertNotEqual(fp8_amax_weight_old.item(), m_fp8.fp8_amax_weight.item()) - if __name__ == "__main__": run_tests() diff --git a/test/float8/test_fsdp_compile.py b/test/float8/test_fsdp_compile.py index 1d95801f67..a78a30925c 100644 --- a/test/float8/test_fsdp_compile.py +++ b/test/float8/test_fsdp_compile.py @@ -26,10 +26,8 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torchao.float8 import Float8LinearConfig -from torchao.float8.config import CastConfig, ScalingType from torchao.float8.float8_linear_utils import ( convert_to_float8_training, - sync_float8_amax_and_scale_history, ) torch.manual_seed(0) @@ -63,10 +61,6 @@ def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32): # https://gist.github.com/vkuzo/ed8e168fd9f7463f1fce34301334ab55 # to get around this, we can disable amax init config = Float8LinearConfig( - enable_amax_init=False, - cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), emulate=emulate, ) @@ -102,7 +96,6 @@ def fsdp_main(rank, world_size, args): optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size) input_local = torch.randn(B, M, K, N, device="cuda") - sync_float8_func = torch.compile(sync_float8_amax_and_scale_history) model = torch.compile(model) @@ -111,7 +104,6 @@ def fsdp_main(rank, world_size, args): with torch.autocast("cuda"): y_local = model(input_local) y_local.sum().backward() - sync_float8_func(model) optimizer.step() print("done!") diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index 01e4cbb20d..f25c876189 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -31,8 +31,6 @@ ) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, - linear_requires_sync, - sync_float8_amax_and_scale_history, ) from torchao.float8.float8_utils import IS_ROCM, compute_error from torchao.testing.float8.test_utils import get_test_float8_linear_config @@ -115,7 +113,7 @@ def _test_impl(self, config: Float8LinearConfig) -> None: # Note: you need two different inputs to properly test numerics # of delayed scaling, because the first time around the initialization # logic of delayed scaling behaves as dynamic scaling - # TODO(future): also make unit tests do this properly + # TODO(future PR): delete ^, since we deleted delayed scaling shape = (1, 8192, 4096) data1 = torch.randn(*shape, device="cuda", dtype=data_dtype) data2 = torch.randn(*shape, device="cuda", dtype=data_dtype) @@ -127,36 +125,21 @@ def _test_impl(self, config: Float8LinearConfig) -> None: model_ref_out = model_ref(data2) model_ref_out.sum().backward() - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(model_fp8) model_fp8(data1).sum().backward() # zero out grads without stepping, since we just want to compare grads # of the second datum optim_fp8.zero_grad() - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(model_fp8) model_fp8_out = model_fp8(data2) model_fp8_out.sum().backward() out_sqnr = compute_error(model_ref_out, model_fp8_out) - any_static_scaling = ( - config.cast_config_input.scaling_type is ScalingType.STATIC - or config.cast_config_weight.scaling_type is ScalingType.STATIC - or config.cast_config_grad_output.scaling_type is ScalingType.STATIC - ) - if any_static_scaling: - assert out_sqnr > 10.0 - else: - assert out_sqnr > 20.0 + assert out_sqnr > 20.0 ref_name_to_grad = { name: param.grad for name, param in model_ref.named_parameters() } - if any_static_scaling: - grad_sqnr_threshold = 10.0 - else: - grad_sqnr_threshold = 20.0 + grad_sqnr_threshold = 20.0 for name, param in model_fp8.named_parameters(): ref_grad = ref_name_to_grad[name] @@ -166,15 +149,15 @@ def _test_impl(self, config: Float8LinearConfig) -> None: @pytest.mark.parametrize( "scaling_type_input", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize( "scaling_type_weight", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize( "scaling_type_grad_output", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.skipif( not is_sm_at_least_89(), reason="requires SM89 compatible machine" diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py index d18ff59f99..7bbd52db09 100644 --- a/test/hqq/test_hqq_affine.py +++ b/test/hqq/test_hqq_affine.py @@ -9,6 +9,7 @@ quantize_, uintx_weight_only, ) +from torchao.testing.utils import skip_if_rocm from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, ) @@ -109,6 +110,7 @@ def test_hqq_plain_5bit(self): ref_dot_product_error=0.000704, ) + @skip_if_rocm("ROCm enablement in progress") def test_hqq_plain_4bit(self): self._test_hqq( dtype=torch.uint4, diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 56bcaf17df..4eccdc86e2 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -76,6 +76,7 @@ from torchao.quantization.utils import ( compute_error as SQNR, ) +from torchao.testing.utils import skip_if_rocm from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, @@ -95,6 +96,7 @@ except ModuleNotFoundError: has_gemlite = False + logger = logging.getLogger("INFO") torch.manual_seed(0) @@ -582,6 +584,7 @@ def test_per_token_linear_cpu(self): self._test_per_token_linear_impl("cpu", dtype) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @skip_if_rocm("ROCm enablement in progress") def test_per_token_linear_cuda(self): for dtype in (torch.float32, torch.float16, torch.bfloat16): self._test_per_token_linear_impl("cuda", dtype) @@ -700,6 +703,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") + @skip_if_rocm("ROCm enablement in progress") def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): if device == "cpu": self.skipTest(f"Temporarily skipping for {device}") @@ -719,6 +723,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") + @skip_if_rocm("ROCm enablement in progress") def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype): if device == "cpu": self.skipTest(f"Temporarily skipping for {device}") @@ -912,6 +917,7 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") + @skip_if_rocm("ROCm enablement in progress") def test_int4_weight_only_quant_subclass(self, device, dtype): if device == "cpu": self.skipTest(f"Temporarily skipping for {device}") @@ -931,6 +937,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") + @skip_if_rocm("ROCm enablement in progress") def test_int4_weight_only_quant_subclass_grouped(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") @@ -1102,6 +1109,7 @@ def test_gemlite_layout(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") + @skip_if_rocm("ROCm enablement in progress") def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): if device == "cpu": self.skipTest(f"Temporarily skipping for {device}") @@ -1235,8 +1243,6 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype): y_wo, (code,) = run_and_get_code(m_c, x) sqnr = compute_error(y_ref, y_wo) self.assertGreaterEqual(sqnr, 38) - if device == "cuda": - self.assertTrue("mixed_mm" in code, f"got code: {code}") @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") diff --git a/test/kernel/test_fused_kernels.py b/test/kernel/test_fused_kernels.py index c5bf6e17f0..9c5bc19aaf 100644 --- a/test/kernel/test_fused_kernels.py +++ b/test/kernel/test_fused_kernels.py @@ -11,6 +11,8 @@ import torch from galore_test_utils import get_kernel, make_copy, make_data +from torchao.testing.utils import skip_if_rocm + torch.manual_seed(0) MAX_DIFF_no_tf32 = 1e-5 MAX_DIFF_tf32 = 1e-3 @@ -104,6 +106,7 @@ def run_test(kernel, exp_avg, exp_avg2, grad, proj_matrix, params, allow_tf32): @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") @pytest.mark.parametrize("kernel, dtype, M, N, rank, allow_tf32", TEST_CONFIGS) +@skip_if_rocm("ROCm enablement in progress") def test_galore_fused_kernels(kernel, dtype, M, N, rank, allow_tf32): torch.backends.cuda.matmul.allow_tf32 = allow_tf32 diff --git a/test/kernel/test_galore_downproj.py b/test/kernel/test_galore_downproj.py index bab65fc2fb..fc8b784a9f 100644 --- a/test/kernel/test_galore_downproj.py +++ b/test/kernel/test_galore_downproj.py @@ -11,6 +11,7 @@ from torchao.prototype.galore.kernels.matmul import set_tuner_top_k as matmul_tuner_topk from torchao.prototype.galore.kernels.matmul import triton_mm_launcher +from torchao.testing.utils import skip_if_rocm torch.manual_seed(0) @@ -29,6 +30,7 @@ @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") @pytest.mark.parametrize("M, N, rank, allow_tf32, fp8_fast_accum, dtype", TEST_CONFIGS) +@skip_if_rocm("ROCm enablement in progress") def test_galore_downproj(M, N, rank, allow_tf32, fp8_fast_accum, dtype): torch.backends.cuda.matmul.allow_tf32 = allow_tf32 MAX_DIFF = MAX_DIFF_tf32 if allow_tf32 else MAX_DIFF_no_tf32 diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index 1b91983bc0..1bfdf57aca 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -5,7 +5,11 @@ import torch from torchao.quantization import quantize_ -from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5 +from torchao.testing.utils import skip_if_rocm +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_3, + TORCH_VERSION_AT_LEAST_2_5, +) if TORCH_VERSION_AT_LEAST_2_3: from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_ @@ -113,6 +117,7 @@ def test_awq_loading(device, qdtype): @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@skip_if_rocm("ROCm enablement in progress") def test_save_weights_only(): dataset_size = 100 l1, l2, l3 = 512, 256, 128 diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index d7d6fe7dc8..453210abda 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -26,6 +26,7 @@ from torchao.prototype.low_bit_optim.subclass_4bit import OptimState4bit from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8 +from torchao.testing.utils import skip_if_rocm from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, @@ -42,6 +43,8 @@ except ImportError: lpmm = None +if torch.version.hip is not None: + pytest.skip("Skipping the test in ROCm", allow_module_level=True) _DEVICES = get_available_devices() @@ -112,6 +115,7 @@ class TestOptim(TestCase): ) @parametrize("dtype", [torch.float32, torch.bfloat16]) @parametrize("device", _DEVICES) + @skip_if_rocm("ROCm enablement in progress") def test_optim_smoke(self, optim_name, dtype, device): if optim_name.endswith("Fp8") and device == "cuda": if not TORCH_VERSION_AT_LEAST_2_4: @@ -185,6 +189,7 @@ def test_subclass_slice(self, subclass, shape, device): not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA", ) + @skip_if_rocm("ROCm enablement in progress") @parametrize("optim_name", ["Adam8bit", "AdamW8bit"]) def test_optim_8bit_correctness(self, optim_name): device = "cuda" @@ -413,6 +418,7 @@ def world_size(self) -> int: not TORCH_VERSION_AT_LEAST_2_5, reason="PyTorch>=2.5 is required." ) @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE) + @skip_if_rocm("ROCm enablement in progress") def test_fsdp2(self): optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit] if torch.cuda.get_device_capability() >= (8, 9): @@ -523,6 +529,7 @@ def _test_fsdp2(self, optim_cls): not TORCH_VERSION_AT_LEAST_2_5, reason="PyTorch>=2.5 is required." ) @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE) + @skip_if_rocm("ROCm enablement in progress") def test_uneven_shard(self): in_dim = 512 out_dim = _FSDP_WORLD_SIZE * 16 + 1 diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index 02b41e8e32..d90990143c 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -20,6 +20,9 @@ TORCH_VERSION_AT_LEAST_2_5, ) +if torch.version.hip is not None: + pytest.skip("Skipping the test in ROCm", allow_module_level=True) + class ToyLinearModel(torch.nn.Module): def __init__(self, m=512, n=256, k=128): diff --git a/test/prototype/test_splitk.py b/test/prototype/test_splitk.py index 48793ba907..37aeac1334 100644 --- a/test/prototype/test_splitk.py +++ b/test/prototype/test_splitk.py @@ -13,13 +13,15 @@ except ImportError: triton_available = False -from torchao.utils import skip_if_compute_capability_less_than + +from torchao.testing.utils import skip_if_compute_capability_less_than, skip_if_rocm @unittest.skipIf(not triton_available, "Triton is required but not available") @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") class TestFP8Gemm(TestCase): @skip_if_compute_capability_less_than(9.0) + @skip_if_rocm("ROCm enablement in progress") def test_gemm_split_k(self): dtype = torch.float16 qdtype = torch.float8_e4m3fn diff --git a/test/quantization/test_galore_quant.py b/test/quantization/test_galore_quant.py index 3eb9b0a2c5..6b26b948f5 100644 --- a/test/quantization/test_galore_quant.py +++ b/test/quantization/test_galore_quant.py @@ -18,6 +18,7 @@ triton_dequant_blockwise, triton_quantize_blockwise, ) +from torchao.testing.utils import skip_if_rocm SEED = 0 torch.manual_seed(SEED) @@ -82,6 +83,7 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize): "dim1,dim2,dtype,signed,blocksize", TEST_CONFIGS, ) +@skip_if_rocm("ROCm enablement in progress") def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize): g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01 diff --git a/test/quantization/test_marlin_qqq.py b/test/quantization/test_marlin_qqq.py index 1fd60acb52..f8581b1307 100644 --- a/test/quantization/test_marlin_qqq.py +++ b/test/quantization/test_marlin_qqq.py @@ -18,9 +18,11 @@ MappingType, choose_qparams_and_quantize_affine_qqq, ) +from torchao.testing.utils import skip_if_rocm from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +@skip_if_rocm("ROCm enablement in progress") class TestMarlinQQQ(TestCase): def setUp(self): super().setUp() @@ -40,6 +42,7 @@ def setUp(self): ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") + @skip_if_rocm("ROCm development in progress") def test_marlin_qqq(self): output_ref = self.model(self.input) for group_size in [-1, 128]: @@ -61,6 +64,7 @@ def test_marlin_qqq(self): @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+") @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") + @skip_if_rocm("ROCm development in progress") def test_marlin_qqq_compile(self): model_copy = copy.deepcopy(self.model) model_copy.forward = torch.compile(model_copy.forward, fullgraph=True) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 9aeaa53664..4d685169a1 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -133,6 +133,21 @@ def forward(self, x): return x +class ModelWithLinearBias(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(512, 256, bias=True) + self.linear2 = torch.nn.Linear(256, 512, bias=True) + + def example_inputs(self): + return (torch.randn(1, 512),) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + class TestQAT(unittest.TestCase): SEED = 123 @@ -1366,6 +1381,25 @@ def test_fake_quantizer_repr(self): self.assertTrue("PerGroup" in fake_quantizer_repr) self.assertTrue("MappingType.SYMMETRIC" in fake_quantizer_repr) + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) + def test_qat_linear_bias(self): + """ + Test that QAT supports linear bias. + """ + m = ModelWithLinearBias() + activation_config = FakeQuantizeConfig( + torch.int8, "per_token", is_symmetric=False + ) + weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=32) + quantize_( + m, + intx_quantization_aware_training(activation_config, weight_config), + ) + example_inputs = m.example_inputs() + m(*example_inputs) + if __name__ == "__main__": unittest.main() diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index a53f47ac14..4af429940f 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -48,6 +48,7 @@ Int8WeightOnlyQuantizedLinearWeight, ) from torchao.quantization.utils import compute_error +from torchao.testing.utils import skip_if_rocm from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, @@ -819,6 +820,7 @@ def test_int4wo_cpu(self, dtype, x_dim): uintx_weight_only(dtype=torch.uint4), ], ) + @skip_if_rocm("ROCm enablement in progress") def test_workflow_e2e_numerics(self, config): """ Simple test of e2e int4_weight_only workflow, comparing numerics diff --git a/test/sparsity/test_marlin.py b/test/sparsity/test_marlin.py index 4da7304a24..dc4489f05e 100644 --- a/test/sparsity/test_marlin.py +++ b/test/sparsity/test_marlin.py @@ -15,6 +15,7 @@ ) from torchao.sparsity.marlin import inject_24, pack_to_marlin_24, unpack_from_marlin_24 from torchao.sparsity.sparse_api import apply_fake_sparsity +from torchao.testing.utils import skip_if_rocm from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 @@ -37,6 +38,7 @@ def setUp(self): ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") + @skip_if_rocm("ROCm enablement in progress") def test_quant_sparse_marlin_layout_eager(self): apply_fake_sparsity(self.model) model_copy = copy.deepcopy(self.model) @@ -48,13 +50,13 @@ def test_quant_sparse_marlin_layout_eager(self): # Sparse + quantized quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout())) sparse_result = self.model(self.input) - assert torch.allclose( dense_result, sparse_result, atol=3e-1 ), "Results are not close" @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+") @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") + @skip_if_rocm("ROCm enablement in progress") def test_quant_sparse_marlin_layout_compile(self): apply_fake_sparsity(self.model) model_copy = copy.deepcopy(self.model) diff --git a/test/test_ops.py b/test/test_ops.py index b3b160e85f..076ab9ab16 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -20,6 +20,9 @@ from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff +if torch.version.hip is not None: + pytest.skip("Skipping the test in ROCm", allow_module_level=True) + try: import torchao.ops except RuntimeError: diff --git a/torchao/dtypes/uintx/marlin_qqq_tensor.py b/torchao/dtypes/uintx/marlin_qqq_tensor.py index 95175caacf..abf09cd2f9 100644 --- a/torchao/dtypes/uintx/marlin_qqq_tensor.py +++ b/torchao/dtypes/uintx/marlin_qqq_tensor.py @@ -183,7 +183,7 @@ def __tensor_unflatten__( def get_plain(self): from torchao.quantization.marlin_qqq import ( unpack_from_marlin_qqq, - ) # avoid circular import + ) int_data_expanded, s_group_expanded, s_channel_expanded = ( unpack_from_marlin_qqq( @@ -211,7 +211,7 @@ def from_plain( from torchao.quantization.marlin_qqq import ( const, pack_to_marlin_qqq, - ) # avoid circular import + ) assert isinstance(_layout, MarlinQQQLayout) diff --git a/torchao/dtypes/uintx/marlin_sparse_layout.py b/torchao/dtypes/uintx/marlin_sparse_layout.py index 22763eb0c2..01d4562b7f 100644 --- a/torchao/dtypes/uintx/marlin_sparse_layout.py +++ b/torchao/dtypes/uintx/marlin_sparse_layout.py @@ -206,7 +206,7 @@ def __tensor_unflatten__( def get_plain(self): from torchao.sparsity.marlin import ( unpack_from_marlin_24, - ) # avoid circular import + ) int_data_expanded, scales_expanded = unpack_from_marlin_24( self.int_data, @@ -231,7 +231,7 @@ def from_plain( from torchao.sparsity.marlin import ( const, pack_to_marlin_24, - ) # avoid circular import + ) assert isinstance(_layout, MarlinSparseLayout) diff --git a/torchao/float8/README.md b/torchao/float8/README.md index 4dbc556d83..65105d1f89 100644 --- a/torchao/float8/README.md +++ b/torchao/float8/README.md @@ -15,8 +15,6 @@ throughput speedups of up to 1.5x on 128 GPU LLaMa 3 70B pretraining jobs. # Single GPU User API -We provide three per-tensor scaling strategies: dynamic, delayed and static. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`input`), weights (`weight`) and gradients (`grad_output`). - ## float8 linear with dynamic tensorwise scaling This is the default recipe, with a good balance of performance and accuracy. @@ -114,67 +112,6 @@ for _ in range(10): optimizer.step() ``` -## float8 linear with delayed scaling - -:warning: We plan to deprecate delayed scaling in a future release, see https://github.com/pytorch/ao/issues/1680 for more details. - -This is theoretically the most performant recipe as it minimizes memory reads. - -```python -import torch -import torch.nn as nn -from torchao.float8 import ( - convert_to_float8_training, - sync_float8_amax_and_scale_history, - Float8LinearConfig, - ScalingType, - CastConfig, -) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") - -# Recommended: enable additional torchinductor passes to improve the performance of delayed scaling -torchao.float8._prototype_register_float8_delayed_scaling_inductor_passes() - -# create model and sample input -m = nn.Sequential( - nn.Linear(2048, 4096), - nn.Linear(4096, 128), -).bfloat16().cuda() -x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16) -optimizer = torch.optim.SGD(m.parameters(), lr=0.1) - -# configure delayed scaling -config = Float8LinearConfig( - cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), -) - -# convert all `torch.nn.Linear` modules to `Float8Linear`, specifying custom scaling behavior -convert_to_float8_training(m, config=config) - -# enable torch.compile for competitive performance -m = torch.compile(m) - -# toy training loop -for _ in range(10): - optimizer.zero_grad() - y = m(x) - y.sum().backward() - - # Specific to delayed scaling: separate step to sync scales/amaxes. - # On the first call, this function also sets the `is_amax_initialized` flag to - # mark the amax and scale buffers as initialized. - # Make sure you run this after every model forward+backward pass. - # In the future, this may move to a context manager. - sync_float8_amax_and_scale_history(m) - - optimizer.step() -``` - # Multi GPU User API We compose with the `DTensor` based [distributed APIs](https://pytorch.org/docs/stable/distributed.tensor.parallel.html), @@ -226,10 +163,6 @@ There are three observations we can make about the formula above: For small shapes, a combination of (2) and (3) leads to speedup < 1. For medium shapes, (1) and (3) are of similar magnitude and the speedup depends on M, K, N and framework and compiler behavior. For large shapes, (1) leads to speedup > 1. -## Scaling type vs speedup - -Delayed scaling is theoretically faster than dynamic scaling because of reduced read/write traffic requirements. Today, torch.compile has a couple of limitations (see the performance section of https://github.com/pytorch/ao/issues/556) which prevent us from reaching the optimal behavior for delayed scaling without workarounds. We have a prototype workaround (API subject to change) with the `torchao.float8._prototype_register_float8_delayed_scaling_inductor_passes()` API to improve delayed scaling performance. - ## torch.compile behavior vs speedup There are a couple of limitations in how torch.compile generates float8 scaling and casting kernels (see the performance section of https://github.com/pytorch/ao/issues/556). As the limitations get resolved, we expect to reach improved performance. diff --git a/torchao/float8/__init__.py b/torchao/float8/__init__.py index 258db53be0..18ef82a507 100644 --- a/torchao/float8/__init__.py +++ b/torchao/float8/__init__.py @@ -6,15 +6,12 @@ # Lets define a few top level things here from torchao.float8.config import ( CastConfig, - DelayedScalingConfig, Float8GemmConfig, Float8LinearConfig, ScalingType, ) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, - linear_requires_sync, - sync_float8_amax_and_scale_history, ) from torchao.float8.float8_tensor import ( Float8Tensor, @@ -23,11 +20,7 @@ ScaledMMConfig, ) from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp -from torchao.float8.inductor_utils import ( - _prototype_register_float8_delayed_scaling_inductor_passes, -) from torchao.float8.inference import Float8MMConfig -from torchao.float8.stateful_float8_linear import WeightWithDelayedFloat8CastTensor from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 if TORCH_VERSION_AT_LEAST_2_5: @@ -41,22 +34,17 @@ GemmInputRole, LinearMMConfig, Float8MMConfig, - WeightWithDelayedFloat8CastTensor, ] ) __all__ = [ # configuration - "DelayedScalingConfig", "ScalingType", "Float8GemmConfig", "Float8LinearConfig", "CastConfig", # top level UX "convert_to_float8_training", - "linear_requires_sync", - "sync_float8_amax_and_scale_history", "precompute_float8_dynamic_scale_for_fsdp", - "_prototype_register_float8_delayed_scaling_inductor_passes", # note: Float8Tensor and Float8Linear are not public APIs ] diff --git a/torchao/float8/config.py b/torchao/float8/config.py index fa03d55b11..d2998d890f 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -15,20 +15,14 @@ class ScalingType(enum.Enum): - DELAYED = "delayed" DYNAMIC = "dynamic" - STATIC = "static" # ScalingType.DISABLED means "skip scaling for this tensor, leave it in # its original precision. DISABLED = "disabled" def short_str(self): - if self is ScalingType.DELAYED: - return "del" - elif self is ScalingType.DYNAMIC: + if self is ScalingType.DYNAMIC: return "dyn" - elif self is ScalingType.STATIC: - return "sta" else: assert self is ScalingType.DISABLED return "dis" @@ -90,7 +84,6 @@ class CastConfig: scaling_type: ScalingType = ScalingType.DYNAMIC scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE - static_scale: Optional[torch.Tensor] = None target_dtype: Optional[torch.dtype] = None def short_str(self): @@ -98,10 +91,6 @@ def short_str(self): return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}_{dtype}" def __post_init__(self): - if self.scaling_type is ScalingType.STATIC: - assert ( - self.static_scale is not None - ), "static_scale must be specified for static scaling" if self.scaling_granularity is ScalingGranularity.AXISWISE: assert ( self.scaling_type is ScalingType.DYNAMIC @@ -111,30 +100,6 @@ def __post_init__(self): ), "must specify a 8-bit floating-point dtype" -@dataclass(frozen=True) -class DelayedScalingConfig: - """ - Configuration for delayed scaling. - - Note: for now, `history_len` values must be the same for all layers in the - model using delayed scaling. - - TODO(future): serialization for recipes - """ - - # Controls the history length of amax buffers - history_len: int = 16 - - # Controls the way to calculate current scale from amax history - # TODO(future): add other functions as needed, hardcoded or user defined - scale_fn_name: str = "max" - - def __post_init__(self): - assert ( - self.scale_fn_name == "max" - ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now." - - @dataclass(frozen=True) class Float8GemmConfig: """ @@ -215,14 +180,6 @@ class Float8LinearConfig: # Per-linear configuration # - # This configuration option is deprecated and no longer has an effect. It may - # be removed in a future release. - enable_amax_init: bool = True - - # This configuration option is deprecated and no longer has an effect. It may - # be removed in a future release. - enable_pre_and_post_forward: bool = True - # If True, then uses a tensor subclass for the float8 linear module's weight that # implements pre/post-all-gather methods to do float8 all-gather with FSDP2. enable_fsdp_float8_all_gather: bool = False @@ -236,13 +193,6 @@ class Float8LinearConfig: # If True, emulation is used instead of hardware accelerated gemm emulate: bool = False - # Configuration for delayed scaling - # Note: this is actually applied per-tensor, but only using the same - # configuration for all tensors and layers in the model is currently - # supported. If in the future we add support for a more fine grained - # configuration, this field may move to per-tensor configs. - delayed_scaling_config: DelayedScalingConfig = DelayedScalingConfig() - # If the option is enabled, fp8_weight will always be re-computed in backward. # It's recommended to enable this flag when using FSDP. # Otherwise, the entire fp8_weight, instead of the sharded weight may be saved. @@ -336,16 +286,6 @@ def __post_init__(self): "When using FSDP, it's recommended to enable config.force_recompute_fp8_weight_in_bwd." ) - # Future deprecation warning for delayed scaling - if ( - self.cast_config_input.scaling_type != ScalingType.DYNAMIC - or self.cast_config_weight.scaling_type != ScalingType.DYNAMIC - or self.cast_config_grad_output.scaling_type != ScalingType.DYNAMIC - ): - logger.warning( - "Note: delayed and static scaling will be deprecated in a future release of torchao. Please see https://github.com/pytorch/ao/issues/1680 for more details." - ) - @staticmethod def from_recipe_name( recipe_name: Union[Float8LinearRecipeName, str], diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index d822d33042..9d5cdd3242 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -64,8 +64,6 @@ class matmul_with_hp_or_float8_args(torch.autograd.Function): * if the arguments are in high precision, they are cast to float8 according to the specified config * if the arguments are in float8, we assume the cast honored the config - - Only supports dynamic scaling, does not support delayed/static scaling. """ @staticmethod @@ -259,8 +257,7 @@ class Float8Linear(torch.nn.Linear): inside of this repository. Please file an issue if you would benefit from this being a public API. - A wrapper around a `torch.nn.Linear` module which does fp8 compute, and tracks - scales in way friendly to delayed scaling. + A wrapper around a `torch.nn.Linear` module which does fp8 compute. """ def __init__(self, *args, **kwargs): @@ -411,6 +408,7 @@ def from_float( # 1. weight needs to be on the correct device to create the buffers # 2. buffers need to be already created for the delayed scaling version # of the weight wrapper to be initialized + # TODO(future PR): see if we can simplify ^ now that delayed scaling is deleted if config.enable_fsdp_float8_all_gather: assert config.cast_config_weight.scaling_type is ScalingType.DYNAMIC new_mod.weight = torch.nn.Parameter( diff --git a/torchao/float8/float8_linear_utils.py b/torchao/float8/float8_linear_utils.py index 3649b741cc..db9889567f 100644 --- a/torchao/float8/float8_linear_utils.py +++ b/torchao/float8/float8_linear_utils.py @@ -6,56 +6,15 @@ import logging from typing import Callable, Optional -import torch -import torch.distributed as dist import torch.nn as nn -from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce -from torchao.float8.config import Float8LinearConfig, ScalingType +from torchao.float8.config import Float8LinearConfig from torchao.float8.float8_linear import Float8Linear -from torchao.float8.float8_utils import ( - amax_history_to_scale_stack, - config_has_stateful_scaling, -) -from torchao.float8.stateful_float8_linear import StatefulFloat8Linear log = logging.getLogger(__name__) log.addHandler(logging.NullHandler()) -def linear_requires_sync(config: Float8LinearConfig): - """Returns whether the given linear_type requires sync before forward.""" - return any( - [ - config.cast_config_input.scaling_type is ScalingType.DELAYED, - config.cast_config_weight.scaling_type is ScalingType.DELAYED, - config.cast_config_grad_output.scaling_type is ScalingType.DELAYED, - ] - ) - - -def _update_history_stack( - new_amax: torch.Tensor, amax_history_stack: torch.Tensor -) -> torch.Tensor: - """ - Updates `amax_history` (the last N cur_amax values) inplace with the value - of `new_amax`. - - Args: - new_amax (torch.Tensor): The new amax value to add to the history. (n_amaxes, 1) - amax_history_stack (torch.Tensor): The history of amax values. (n_amaxes, history_length) - """ - assert ( - amax_history_stack.dim() == 2 - ), f"Expected amat_history_stack to be 2D, got {amax_history_stack.shape()}" - assert ( - new_amax.size(0) == amax_history_stack.size(0) - ), f"Expected new_amax to have the same size as the first dimension of amax_history_stack, got {new_amax.size(0)} and {amax_history_stack.size(0)}" - new_amax_history_stack = torch.roll(amax_history_stack, 1, dims=1) - new_amax_history_stack[:, 0] = new_amax.squeeze(-1) - amax_history_stack.copy_(new_amax_history_stack) - - def swap_linear_layers( module: nn.Module, from_float_func: Callable[[nn.Linear], nn.Linear], @@ -144,196 +103,13 @@ def convert_to_float8_training( if config is None: config = Float8LinearConfig() - if config_has_stateful_scaling(config): - from_float = lambda m: StatefulFloat8Linear.from_float( - m, - config=config, - ) - else: - from_float = lambda m: Float8Linear.from_float( - m, - config=config, - ) + from_float = lambda m: Float8Linear.from_float( + m, + config=config, + ) return swap_linear_layers( module, from_float, module_filter_fn=module_filter_fn, ) - - -def get_float8_layers(model: torch.nn.Module): - """Iterates through the model and returns all the Float8Linear layers. - Args: - model (torch.nn.Module): The model to look for Float8Linear layers in. - """ - - # Get all fp8 layers and tensors - fp8_layers = [child for child in model.modules() if isinstance(child, Float8Linear)] - if not torch.compiler.is_compiling(): - for layer in fp8_layers: - for buf in layer.buffers(): - torch._dynamo.mark_static_address(buf, guard=True) - return fp8_layers - - -@torch.no_grad() -def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) -> None: - """ - Manages the float8 amax and scale bookkeeping. In detail, it does the - following: - 1. in distributed contexts, syncs amax values across workers for activations and gradients - 2. adds the `amax` values to history - 3. calculates the scales to be used for next iteration - 4. sets the `amax_and_scale_synced` flag on the Float8Linear modules - to signal that they have been synced - - TODO(future): design the UX for this (context manager, etc) - - PERFORMANCE NOTE: - When you can, it is much more efficient to call get_float8_layers once at - the beginning of the training loop and pass the result to this function. - Because of how this interacts with torch.compile - - Args: - model (torch.nn.Module): The model to track amaxes for - fp8_layers (optional): If fp8_layers are provided, fp8_classes are ignored, - and we loop over all fp8_layers to sync and update amax scale histories. - Users can use get_float8_layers to get all fp8 layers. - """ - # TODO(future): consider adding a flag to control setting the `is_amax_initialized` - # flag only on the first iteration. - - if fp8_layers is None: - fp8_layers = get_float8_layers(model) - - if len(fp8_layers) == 0: - log.warn( - "Calling sync_float8_amax_and_scale_history on a module with no Float8Linear layers" - ) - return - - def inner_func(): - """Why do we have this inner_function? - - There are two portions of the outer sync_function that cause graph_breaks: - 1. The `get_float8_layers` call can cause graph breaks if the user did not pass - in the fp8_layers. - 2. At the end of syncing all the amaxes and scales we set the attr on the module - signaling that we have synced the amaxes and scales and the next forward can be run. - # TODO Maybe we should remove this safety check to remove the graph break? - - By having this inner function, we can ensure that although the outer function may cause graph breaks - the inner function will not. - """ - # Loop over all fp8 layers and grab the needed tensors - fp8_amax_input_tensor_list = [None] * len(fp8_layers) - fp8_amax_weight_tensor_list = [None] * len(fp8_layers) - fp8_amax_grad_output_tensor_list = [None] * len(fp8_layers) - - fp8_input_amax_history_stack = [None] * len(fp8_layers) - fp8_weight_amax_history_stack = [None] * len(fp8_layers) - fp8_grad_output_amax_history_stack = [None] * len(fp8_layers) - - input_dtypes = set() - weight_dtypes = set() - grad_output_dtypes = set() - scale_fn_recipes = set() - - for idx, child in enumerate(fp8_layers): - fp8_amax_input_tensor_list[idx] = child.fp8_amax_input - fp8_amax_weight_tensor_list[idx] = child.fp8_amax_weight - fp8_amax_grad_output_tensor_list[idx] = child.fp8_amax_grad_output - - fp8_input_amax_history_stack[idx] = child.fp8_amax_history_input - fp8_weight_amax_history_stack[idx] = child.fp8_amax_history_weight - fp8_grad_output_amax_history_stack[idx] = child.fp8_amax_history_grad_output - - input_dtypes.add(child.config.cast_config_input.target_dtype) - weight_dtypes.add(child.config.cast_config_weight.target_dtype) - grad_output_dtypes.add(child.config.cast_config_grad_output.target_dtype) - scale_fn_recipes.add(child.config.delayed_scaling_config.scale_fn_name) - - (input_dtype,) = input_dtypes - (weight_dtype,) = weight_dtypes - (grad_output_dtype,) = grad_output_dtypes - - if len(scale_fn_recipes) != 1: - raise ValueError( - f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}" - ) - scale_fn_recipe = next(iter(scale_fn_recipes)) - - assert ( - len(fp8_amax_input_tensor_list) - == len(fp8_amax_weight_tensor_list) - == len(fp8_amax_grad_output_tensor_list) - ), "Mismatched lengths of amax tensors." - - if dist.is_initialized(): - all_amax_tensors = torch.cat( - fp8_amax_input_tensor_list - + fp8_amax_weight_tensor_list - + fp8_amax_grad_output_tensor_list - ) - all_reduced_amax_tensor = all_reduce( - all_amax_tensors, "MAX", list(range(dist.get_world_size())) - ) - if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor): - all_reduced_amax_tensor = all_reduced_amax_tensor.wait() - - ( - reduced_fp8_amax_input_tensor, - reduced_fp8_amax_weight_tensor, - reduced_fp8_amax_grad_output_tensor, - ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_input_tensor_list)) - - for idx, child in enumerate(fp8_layers): - child.fp8_amax_input.copy_(reduced_fp8_amax_input_tensor[idx]) - child.fp8_amax_weight.copy_(reduced_fp8_amax_weight_tensor[idx]) - child.fp8_amax_grad_output.copy_( - reduced_fp8_amax_grad_output_tensor[idx] - ) - - # We create two stacked tensor groups, one for the amax history and one for the current scales - fp8_amax_input_tensors = torch.vstack(fp8_amax_input_tensor_list) - fp8_amax_weight_tensors = torch.vstack(fp8_amax_weight_tensor_list) - fp8_amax_grad_output_tensors = torch.vstack(fp8_amax_grad_output_tensor_list) - - fp8_input_amax_history_stack = torch.vstack(fp8_input_amax_history_stack) - fp8_weight_amax_history_stack = torch.vstack(fp8_weight_amax_history_stack) - fp8_grad_output_amax_history_stack = torch.vstack( - fp8_grad_output_amax_history_stack - ) - - # Update the history stacks with the new amax values - _update_history_stack(fp8_amax_input_tensors, fp8_input_amax_history_stack) - _update_history_stack(fp8_amax_weight_tensors, fp8_weight_amax_history_stack) - _update_history_stack( - fp8_amax_grad_output_tensors, fp8_grad_output_amax_history_stack - ) - - # Calculate the new scales from the updated history stacks - new_input_scales = amax_history_to_scale_stack( - fp8_input_amax_history_stack, input_dtype, scale_fn_recipe - ) - new_weight_scales = amax_history_to_scale_stack( - fp8_weight_amax_history_stack, weight_dtype, scale_fn_recipe - ) - new_grad_output_scales = amax_history_to_scale_stack( - fp8_grad_output_amax_history_stack, grad_output_dtype, scale_fn_recipe - ) - - # Iterate through the layers and update the scales - for idx, child in enumerate(fp8_layers): - child.fp8_scale_input.copy_(new_input_scales[idx]) - child.fp8_scale_weight.copy_(new_weight_scales[idx]) - child.fp8_scale_grad_output.copy_(new_grad_output_scales[idx]) - - # This allows for the compile to succeed on the inner func and fail on the graph breaks - # at the beginning and and of syncing - inner_func() - - for child in fp8_layers: - # Set a flag to signal that initialization is done - child.is_amax_initialized = True diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index b96c7a9b58..31f2db6b4e 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -21,8 +21,6 @@ hp_tensor_and_scale_to_float8, ) from torchao.float8.float8_utils import ( - amax_history_to_scale, - tensor_to_amax, tensor_to_scale, ) @@ -74,72 +72,6 @@ def hp_tensor_to_float8_dynamic( ) -def hp_tensor_to_float8_delayed( - hp_tensor: torch.Tensor, - s: torch.Tensor, - float8_dtype: torch.dtype, - amax_buffer: torch.Tensor, - linear_mm_config: Optional[LinearMMConfig] = None, - gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, -) -> Float8Tensor: - """ - Given a high precision tensor `hp_tensor` and relevant metadata, scales it using - delayed scaling and returns a `Float8Tensor` of the result. Specifically: - 1. calculates max(abs(hp_tensor)) and stores the result in `amax_buffer`, inplace - 2. scales `hp_tensor` by `s` and returns the result wrapped in Float8Tensor - - Args: - hp_tensor: the tensor to convert - s: the scale to use to convert the tensor - float8_dtype: the float8 dtype to use - amax_buffer: the buffer to modify inplace with max(abs(hp_tensor)) - linear_mm_config: Defines the configuration for the scaled_mm for - the 3 fwd/bwd gemms of linear - gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in - the 3 fwd/bwd gemms of linear - """ - amax_buffer.fill_(tensor_to_amax(hp_tensor)) - return hp_tensor_and_scale_to_float8( - hp_tensor, - s, - float8_dtype, - linear_mm_config, - gemm_input_role, - ) - - -def hp_tensor_to_float8_static( - hp_tensor: torch.Tensor, - scale: torch.Tensor, - float8_dtype: torch.dtype, - linear_mm_config: LinearMMConfig, - gemm_input_role: GemmInputRole = GemmInputRole.INPUT, -) -> Float8Tensor: - """ - Given a high precision tensor `hp_tensor` and a scale, - scales `hp_tensor` returns a `Float8Tensor` of the result. - - Args: - hp_tensor: the tensor to convert - scale: the scale to use - float8_dtype: the float8 dtype to use - linear_mm_config: Defines the configuration for the scaled_mm for - the 3 fwd/bwd gemms of linear - gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in - the 3 fwd/bwd gemms of linear - """ - if tensor_already_casted_to_fp8(hp_tensor): - return hp_tensor - - return hp_tensor_and_scale_to_float8( - hp_tensor, - scale, - float8_dtype, - linear_mm_config, - gemm_input_role, - ) - - def get_maybe_axiswise_dim( axiswise_dim: int, scaling_granularity: ScalingGranularity, @@ -155,95 +87,6 @@ def get_maybe_axiswise_dim( return None -def _maybe_initialize_amaxes_scales_for_float8_cast( - x, - cur_amax, - amax_history, - scale, - scale_fn_name, - float8_dtype, - is_initialized, - reduce_amax, -): - """ - If x is about to be cast to `float8` and the amax buffers are not initialized, - initializes them inplace. - """ - if is_initialized: - return - with torch.no_grad(): - # Note: we need to enable distributed reduction here in order - # to match numerics between single GPU and multi GPU code for - # activations and gradients - new_amax = tensor_to_amax(x, reduce_amax=reduce_amax) - cur_amax.fill_(new_amax) - amax_history[0] = new_amax - new_scale = amax_history_to_scale(amax_history, float8_dtype, scale_fn_name) - scale.copy_(new_scale) - - -@torch._dynamo.allow_in_graph -class NoopFwToFloat8BwDelayed(torch.autograd.Function): - """ - Forward: no-op - Backward: convert to float8_e5m2 with delayed scaling, initialize if needed - """ - - @staticmethod - def forward( - ctx, - tensor, - fp8_amax_grad_output, - fp8_amax_history_grad_output, - fp8_scale_grad_output, - scale_fn_name, - is_amax_initialized, - linear_mm_config: LinearMMConfig, - target_dtype: torch.dtype, - ): - ctx.save_for_backward( - fp8_amax_grad_output, fp8_amax_history_grad_output, fp8_scale_grad_output - ) - ctx.scale_fn_name = scale_fn_name - ctx.is_amax_initialized = is_amax_initialized - ctx.linear_mm_config = linear_mm_config - ctx.target_dtype = target_dtype - return tensor - - @staticmethod - def backward(ctx, go): - ( - fp8_amax_grad_output, - fp8_amax_history_grad_output, - fp8_scale_grad_output, - ) = ctx.saved_tensors - scale_fn_name = ctx.scale_fn_name - is_amax_initialized = ctx.is_amax_initialized - - _maybe_initialize_amaxes_scales_for_float8_cast( - go, - fp8_amax_grad_output, - fp8_amax_history_grad_output, - fp8_scale_grad_output, - scale_fn_name, - ctx.target_dtype, - is_amax_initialized, - reduce_amax=True, - ) - - fp8_amax_grad_output.fill_(tensor_to_amax(go)) - - res = hp_tensor_and_scale_to_float8( - go, - fp8_scale_grad_output, - ctx.target_dtype, - ctx.linear_mm_config, - GemmInputRole.GRAD_OUTPUT, - ) - empty_grads = None, None, None, None, None, None, None - return res, *empty_grads - - @torch._dynamo.allow_in_graph class NoopFwToFloat8BwDynamic(torch.autograd.Function): """ @@ -275,38 +118,3 @@ def backward(ctx, gradY): GemmInputRole.GRAD_OUTPUT, ) return fp8_tensor, None, None - - -@torch._dynamo.allow_in_graph -class NoopFwToFloat8BwStatic(torch.autograd.Function): - """ - Forward: no-op - Backward: convert to float8_e5m2 with static scaling - """ - - @staticmethod - def forward( - ctx, - tensor, - scale, - linear_mm_config: LinearMMConfig, - target_dtype: torch.dtype, - ): - ctx.save_for_backward(scale) - ctx.linear_mm_config = linear_mm_config - ctx.target_dtype = target_dtype - return tensor - - @staticmethod - def backward(ctx, gradY): - if tensor_already_casted_to_fp8(gradY): - return gradY, None, None, None - (gradY_scale,) = ctx.saved_tensors - fp8_tensor = hp_tensor_and_scale_to_float8( - gradY, - gradY_scale, - ctx.target_dtype, - ctx.linear_mm_config, - GemmInputRole.GRAD_OUTPUT, - ) - return fp8_tensor, None, None, None diff --git a/torchao/float8/float8_tensor_parallel.py b/torchao/float8/float8_tensor_parallel.py index a52b38b6bf..abc74e3ff6 100644 --- a/torchao/float8/float8_tensor_parallel.py +++ b/torchao/float8/float8_tensor_parallel.py @@ -27,8 +27,7 @@ def _float8_linear_supports_float8_allgather(m): - # TODO(future): add support for delayed scaling for activations - # and gradients + # TODO(future PR): also gate this by granularity return ( m.scaling_type_input == ScalingType.DYNAMIC and m.scaling_type_grad_output == ScalingType.DYNAMIC diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 926b97edb8..625fb29235 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -4,13 +4,13 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from typing import Iterable, Literal, Optional, Tuple, Union +from typing import Iterable, Optional, Tuple, Union import torch import torch.distributed as dist from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce -from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType +from torchao.float8.config import ScalingGranularity # Helpful visualizer for debugging (only supports fp32): # https://www.h-schmidt.net/FloatConverter/IEEE754.html @@ -53,44 +53,6 @@ def amax_to_scale( return res -@torch.no_grad() -def amax_history_to_scale( - amax_history: torch.Tensor, - float8_dtype: torch.Tensor, - history_to_scale_fn_type: Literal["max"], -): - """Takes in a history of amax values and returns a scale tensor. - Args: - amax_history: A tensor containing the history of amax values. - float8_dtype: The float8 dtype. - history_to_scale_fn_type: The type of function to use to convert the history to a scale. - """ - if history_to_scale_fn_type == "max": - amax = torch.max(amax_history) - return amax_to_scale(amax, float8_dtype) - raise NotImplementedError() - - -@torch.no_grad() -def amax_history_to_scale_stack( - amax_history: torch.Tensor, - float8_dtype: torch.dtype, - history_to_scale_fn_type: Literal["max"], -) -> torch.Tensor: - """Takes in a stack of amax_history tensors and returns a scale tensor. - Args: - amax_history: A 2D tensor containing a stack of amax histories. - float8_dtype: The float8 dtype. - history_to_scale_fn_type: The type of function to use to convert the history to a scale. - """ - if history_to_scale_fn_type == "max": - amax_stack = torch.max(amax_history, dim=1).values - return amax_to_scale(amax_stack, float8_dtype) - raise NotImplementedError( - f"Invalid history_to_scale_fn_type, only 'max' is supported. Got: {history_to_scale_fn_type}" - ) - - @torch.no_grad() def tensor_to_amax( x: torch.Tensor, @@ -274,17 +236,6 @@ def pad_tensor_for_matmul( return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1)) -def config_has_stateful_scaling(config: Float8LinearConfig) -> bool: - """ - Returns True if `config` has any delayed or static scaling, and False otherwise. - """ - return ( - config.cast_config_input.scaling_type != ScalingType.DYNAMIC - or config.cast_config_weight.scaling_type != ScalingType.DYNAMIC - or config.cast_config_grad_output.scaling_type != ScalingType.DYNAMIC - ) - - def _round_scale_down_to_power_of_2(scale: torch.Tensor): assert scale.dtype == torch.float32, "scale must be float32 tensor" return torch.exp2(torch.floor(torch.log2(scale))) diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py index f246879a7c..7b24dc2b53 100644 --- a/torchao/float8/fsdp_utils.py +++ b/torchao/float8/fsdp_utils.py @@ -13,8 +13,6 @@ from torch._prims_common import suggest_memory_format from torchao.float8.float8_scaling_utils import ( - _maybe_initialize_amaxes_scales_for_float8_cast, - hp_tensor_to_float8_delayed, hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import ( @@ -39,14 +37,8 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: """ from torch.distributed._tensor import DTensor - from torchao.float8.config import ScalingType from torchao.float8.float8_linear import Float8Linear - if any( - isinstance(m, Float8Linear) and m.scaling_type_weight is ScalingType.DELAYED - for m in module.modules() - ): - raise NotImplementedError("Only supports dynamic scaling") float8_linears: List[Float8Linear] = [ m for m in module.modules() @@ -274,331 +266,3 @@ def fsdp_post_all_gather( self._linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, ), (data,) - - -class WeightWithDelayedFloat8CastTensor(torch.Tensor): - @staticmethod - def __new__( - cls, - tensor: torch.Tensor, - amax_buffer: torch.Tensor, - amax_history_buffer: torch.Tensor, - scale_buffer: torch.Tensor, - linear_mm_config: LinearMMConfig, - dtype: torch.dtype, - is_amax_initialized: bool, - ): - return torch.Tensor._make_wrapper_subclass( - cls, - tensor.size(), - strides=tensor.stride(), - storage_offset=tensor.storage_offset(), - memory_format=suggest_memory_format(tensor), - dtype=tensor.dtype, - layout=tensor.layout, - device=tensor.device, - pin_memory=tensor.is_pinned(), - requires_grad=tensor.requires_grad, - ) - - def __init__( - self, - tensor: torch.Tensor, - amax_buffer: torch.Tensor, - amax_history_buffer: torch.Tensor, - scale_buffer: torch.Tensor, - linear_mm_config: LinearMMConfig, - dtype: torch.dtype, - is_amax_initialized: bool, - ): - self._tensor = tensor - self._amax_buffer = amax_buffer - self._amax_history_buffer = amax_history_buffer - self._scale_buffer = scale_buffer - self._linear_mm_config = linear_mm_config - self._dtype = dtype - - # Note: is_amax_initialized is not a buffer to avoid data dependent - # control flow visible to dynamo - # TODO(future PR): add serialization for this flag - self.is_amax_initialized = is_amax_initialized - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): - if func == torch.ops.aten.detach.default: - return WeightWithDelayedFloat8CastTensor( - args[0]._tensor, - args[0]._amax_buffer, - args[0]._amax_history_buffer, - args[0]._scale_buffer, - args[0]._linear_mm_config, - args[0]._dtype, - args[0].is_amax_initialized, - ) - mm_config: Optional[LinearMMConfig] = None - dtype: Optional[torch.dtype] = None - amax_buffer: Optional[torch.Tensor] = None - amax_history_buffer: Optional[torch.Tensor] = None - scale_buffer: Optional[torch.Tensor] = None - is_amax_initialized: Optional[bool] = None - - def unwrap(t): - nonlocal mm_config - if mm_config is None: - mm_config = t._linear_mm_config - else: - assert t._linear_mm_config == mm_config - nonlocal dtype - if dtype is None: - dtype = t._dtype - else: - assert t._dtype == dtype - nonlocal amax_buffer - if amax_buffer is None: - amax_buffer = t._amax_buffer - nonlocal amax_history_buffer - if amax_history_buffer is None: - amax_history_buffer = t._amax_history_buffer - nonlocal scale_buffer - if scale_buffer is None: - scale_buffer = t._scale_buffer - nonlocal is_amax_initialized - if is_amax_initialized is None: - is_amax_initialized = t.is_amax_initialized - return t._tensor - - args, kwargs = pytree.tree_map_only( - WeightWithDelayedFloat8CastTensor, unwrap, (args, kwargs or {}) - ) - out = func(*args, **kwargs) - if func not in _ops_to_preserve_subclass: - return out - return pytree.tree_map_only( - torch.Tensor, - lambda x: WeightWithDelayedFloat8CastTensor( - x, - amax_buffer, - amax_history_buffer, - scale_buffer, - mm_config, - dtype, - is_amax_initialized, - ), - out, - ) - - def __tensor_flatten__(self): - return ( - [ - "_tensor", - "_amax_buffer", - "_amax_history_buffer", - "_scale_buffer", - ], - { - "mm_config": self._linear_mm_config, - "dtype": self._dtype, - "is_amax_initialized": self.is_amax_initialized, - }, - ) - - @staticmethod - def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): - return WeightWithDelayedFloat8CastTensor( - inner_tensors["_tensor"], - inner_tensors["_amax_buffer"], - inner_tensors["_amax_history_buffer"], - inner_tensors["_scale_buffer"], - metadata["mm_config"], - metadata["dtype"], - metadata["is_amax_initialized"], - ) - - def __repr__(self): - return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._linear_mm_config}, dtype={self._dtype})" - - def fsdp_pre_all_gather(self, mesh): - # initialize if needed - # TODO(before land): ensure settings are consistent between Float8Linear and here - if not self.is_amax_initialized: - _maybe_initialize_amaxes_scales_for_float8_cast( - self._tensor, - self._amax_buffer, - self._amax_history_buffer, - self._scale_buffer, - "max", # TODO(before land): read this from parent - self._dtype, - self.is_amax_initialized, - reduce_amax=True, - ) - self.is_amax_initialized = True - - float8_tensor = hp_tensor_to_float8_delayed( - self._tensor, - self._scale_buffer, - self._dtype, - self._amax_buffer, - self._linear_mm_config, - GemmInputRole.WEIGHT, - ) - return (float8_tensor._data,), (float8_tensor._scale,) - - def fsdp_post_all_gather( - self, - all_gather_outputs: Tuple[torch.Tensor, ...], - metadata: Any, - param_dtype: torch.dtype, - *, - out: Optional[torch.Tensor] = None, - ): - (data,) = all_gather_outputs - (scale,) = metadata - if out is not None: - assert isinstance(out, Float8Tensor), f"{type(out)}" - out._scale = scale - return - return Float8Tensor( - data, - scale, - param_dtype, - self._linear_mm_config, - gemm_input_role=GemmInputRole.WEIGHT, - ), (data,) - - -class WeightWithStaticFloat8CastTensor(torch.Tensor): - @staticmethod - def __new__( - cls, - tensor: torch.Tensor, - static_scale: torch.Tensor, - linear_mm_config: LinearMMConfig, - dtype: torch.dtype, - ): - return torch.Tensor._make_wrapper_subclass( - cls, - tensor.size(), - strides=tensor.stride(), - storage_offset=tensor.storage_offset(), - memory_format=suggest_memory_format(tensor), - dtype=tensor.dtype, - layout=tensor.layout, - device=tensor.device, - pin_memory=tensor.is_pinned(), - requires_grad=tensor.requires_grad, - ) - - def __init__( - self, - tensor: torch.Tensor, - static_scale: torch.Tensor, - linear_mm_config: LinearMMConfig, - dtype: torch.dtype, - ): - self._tensor = tensor - self._static_scale = static_scale - self._linear_mm_config = linear_mm_config - self._dtype = dtype - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): - if func == torch.ops.aten.detach.default: - return WeightWithStaticFloat8CastTensor( - args[0]._tensor, - args[0]._static_scale, - args[0]._linear_mm_config, - args[0]._dtype, - ) - static_scale: Optional[torch.Tensor] = None - mm_config: Optional[LinearMMConfig] = None - dtype: Optional[torch.dtype] = None - - def unwrap(t): - nonlocal static_scale - if static_scale is None: - static_scale = t._static_scale - nonlocal mm_config - if mm_config is None: - mm_config = t._linear_mm_config - else: - assert t._linear_mm_config == mm_config - nonlocal dtype - if dtype is None: - dtype = t._dtype - else: - assert t._dtype == dtype - return t._tensor - - args, kwargs = pytree.tree_map_only( - WeightWithStaticFloat8CastTensor, unwrap, (args, kwargs or {}) - ) - out = func(*args, **kwargs) - if func not in _ops_to_preserve_subclass: - return out - return pytree.tree_map_only( - torch.Tensor, - lambda x: WeightWithStaticFloat8CastTensor( - x, static_scale, mm_config, dtype - ), - out, - ) - - def __tensor_flatten__(self): - return ["_tensor", "_static_scale"], { - "mm_config": self._linear_mm_config, - "dtype": self._dtype, - } - - @staticmethod - def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): - return WeightWithStaticFloat8CastTensor( - inner_tensors["_tensor"], - inner_tensors["_static_scale"], - flatten_spec["mm_config"], - flatten_spec["dtype"], - ) - - def __repr__(self): - return f"WeightWithStaticFloat8CastTensor(tensor={self._tensor}, static_scale={self._static_scale}, linear_mm_config={self._linear_mm_config}, dtype={self.dtype})" - - def fsdp_pre_all_gather(self, mesh): - float8_tensor = hp_tensor_and_scale_to_float8( - self._tensor, - self._static_scale, - self._dtype, - self._linear_mm_config, - GemmInputRole.WEIGHT, - ) - return (float8_tensor._data,), (float8_tensor._scale,) - - def fsdp_post_all_gather( - self, - all_gather_outputs: Tuple[torch.Tensor, ...], - metadata: Any, - param_dtype: torch.dtype, - *, - out: Optional[torch.Tensor] = None, - ): - (data,) = all_gather_outputs - (scale,) = metadata - if out is not None: - from torch.distributed._tensor import DTensor - - if isinstance(out, Float8Tensor): - out._scale = scale - elif isinstance(out, DTensor) and isinstance( - out._local_tensor, Float8Tensor - ): - out._local_tensor._scale = scale - else: - raise RuntimeError( - f"out must be a Float8Tensor or DTensor(_local_tensor=Float8Tensor), but got {out}" - ) - return - return Float8Tensor( - data, - scale, - param_dtype, - self._linear_mm_config, - gemm_input_role=GemmInputRole.WEIGHT, - ), (data,) diff --git a/torchao/float8/inductor_utils.py b/torchao/float8/inductor_utils.py deleted file mode 100644 index 3e86202536..0000000000 --- a/torchao/float8/inductor_utils.py +++ /dev/null @@ -1,126 +0,0 @@ -import functools -import inspect -import traceback -from collections import deque - -import torch - - -def amax_with_scaling_pattern(tensor_x_inp, scale_x, fp8_dtype, fp8_max): - tensor_x = tensor_x_inp.to(torch.float32) * scale_x - tensor_x = tensor_x.clamp(min=-1 * fp8_max, max=fp8_max) - tensor_x = tensor_x.to(fp8_dtype) - amax = torch.max(torch.abs(tensor_x_inp)) - return (tensor_x, amax) - - -def amax_with_scaling_tiled_replacement(tensor_x_inp, scale_x, fp8_dtype, fp8_max): - tensor_x = tensor_x_inp.to(torch.float32) * scale_x - tensor_x = tensor_x.clamp(min=-1 * fp8_max, max=fp8_max) - tensor_x = tensor_x.to(fp8_dtype) - amax_1 = torch.max(torch.abs(tensor_x_inp), dim=-1).values - amax = torch.max(amax_1) - return (tensor_x, amax) - - -# The amax_with_scaling_pattern will also match dynamic scaling cases, we want to avoid that. -# `scale_x` of delayed scaling comes from the previous iteration, instead of from `tensor_x_inp`. -# We check that `scale_x` is not a dependency of `tensor_x_inp` -def fp8_delayed_scaling_extra_check(match): - scale_x_inputs = deque([match.kwargs["scale_x"]]) - max_num_node_to_check = 20 # Don't traverse too many nodes - current_num_node = 0 - while len(scale_x_inputs) > 0 and current_num_node < max_num_node_to_check: - current_node = scale_x_inputs.popleft() - for n in current_node.all_input_nodes: - if n == match.kwargs["tensor_x_inp"]: - return False - scale_x_inputs.append(n) - current_num_node += 1 - return True - - -def partialize_and_update_signature(func, **kwargs): - """ - Equivalent to functools.partial but also updates the signature on returned function - """ - original_sig = inspect.signature(func) - parameters = original_sig.parameters - - new_parameters = { - key: value for key, value in parameters.items() if key not in kwargs - } - new_sig = inspect.Signature(parameters=list(new_parameters.values())) - - partial_func = functools.partial(func, **kwargs) - - def wrapper(*args, **kwargs): - return partial_func(*args, **kwargs) - - wrapper.__signature__ = new_sig # type: ignore[attr-defined] - wrapper.__name__ = func.__name__ - - return wrapper - - -def register_fp8_delayed_scaling_patterns_inner(): - from torch._inductor.fx_passes.post_grad import ( - pass_patterns as post_grad_patterns_all, - ) - from torch._inductor.pattern_matcher import fwd_only, register_replacement - - post_grad_patterns = post_grad_patterns_all[1] # medium priority - - if torch.cuda.is_available(): - for fp8_dtype in [ - torch.float8_e4m3fn, - torch.float8_e5m2, - torch.float8_e4m3fnuz, - torch.float8_e5m2fnuz, - ]: - # torch.float16 has the same pattern as torch.bfloat16, because they both needs `tensor_x_inp.to(torch.float32)` - for dtype in [torch.float32, torch.bfloat16]: - device = "cuda" - register_replacement( - partialize_and_update_signature( - amax_with_scaling_pattern, - fp8_dtype=fp8_dtype, - fp8_max=torch.finfo(fp8_dtype).max, - ), - partialize_and_update_signature( - amax_with_scaling_tiled_replacement, - fp8_dtype=fp8_dtype, - fp8_max=torch.finfo(fp8_dtype).max, - ), - [ - torch.tensor((16, 16), device=device, dtype=dtype), - torch.tensor(2.0, device=device, dtype=torch.float32), - ], - fwd_only, - post_grad_patterns, - extra_check=fp8_delayed_scaling_extra_check, - ) - - -""" -This a short-term workaround of the delayed scaling performance issue. -It explicitly replaces `max(x)` with `max(max(x, dim=-1))`, enabling the fusion of amax scaling factor calculation and fp8 casting. - -Usage: - To use this solution, add the following line at the beginning of your user code: - torchao.float8._prototype_register_float8_delayed_scaling_inductor_passes() -""" - - -def _prototype_register_float8_delayed_scaling_inductor_passes() -> None: - # To make the fp8 delayed scaling pattern work, we need a fix pr from inductor, https://github.com/pytorch/pytorch/pull/139321 - # Will throw the error if the pattern registration did not work, up to user to decide what to do with it - try: - register_fp8_delayed_scaling_patterns_inner() - except AssertionError as e: - if "assert pattern_repr not in _seen_patterns" in traceback.format_exc(): - print( - f"Caught duplicated patterns in register_fp8_delayed_scaling_patterns: {traceback.format_exc()}", - "\nPlease update your pytorch dependency to the latest main branch to fix it.\n", - ) - raise e diff --git a/torchao/float8/roofline_utils.py b/torchao/float8/roofline_utils.py index 16cf847fe2..58c84c5fa6 100644 --- a/torchao/float8/roofline_utils.py +++ b/torchao/float8/roofline_utils.py @@ -38,78 +38,30 @@ def get_tensor_memory_traffic_bytes( # assumes input bf16, output f8 numel = dim0 * dim1 - if scaling_type == "dynamic": - # x_bf16 = ... - # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp - # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs - # kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8 - - if fuse_with_prev: - kernel_1_rw = 0 - else: - # kernel 1: read numel, write 0 (assume size(tmp) ~ 0) - kernel_1_rw = BYTES_PER_EL_BF16 * numel - - # kernel 3: read in bf16, write twice in float8 (row-major and col-major) - kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel - - if model_torch_compile_limitations: - # today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...) - # has an extra memory read of the input in fp8 - # context: https://github.com/pytorch/pytorch/issues/130015 - tc_adjustment = numel * BYTES_PER_EL_FLOAT8 - else: - tc_adjustment = 0 - - return kernel_1_rw + kernel_3_rw + tc_adjustment + assert scaling_type == "dynamic", "unsupported" + # x_bf16 = ... + # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp + # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs + # kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8 + + if fuse_with_prev: + kernel_1_rw = 0 + else: + # kernel 1: read numel, write 0 (assume size(tmp) ~ 0) + kernel_1_rw = BYTES_PER_EL_BF16 * numel + + # kernel 3: read in bf16, write twice in float8 (row-major and col-major) + kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel + if model_torch_compile_limitations: + # today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...) + # has an extra memory read of the input in fp8 + # context: https://github.com/pytorch/pytorch/issues/130015 + tc_adjustment = numel * BYTES_PER_EL_FLOAT8 else: - assert scaling_type == "delayed", "unsupported" - # x_bf16 = ... - # kernel 1: x_bf16 -> max_abs_stage_1_and_to_float8 -> x_float8, tmp - # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs - # kernel 3 (not modeled): scale -> reciprocal -> inv_scale - - if fuse_with_prev: - kernel_1_r = 0 - else: - kernel_1_r = numel * BYTES_PER_EL_BF16 - # write twice: once in row major, once in col-major - kernel_1_w = numel * BYTES_PER_EL_FLOAT8 * 2 - - if model_torch_compile_limitations: - # today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...) - # has an extra memory read of the input in fp8 - # context: https://github.com/pytorch/pytorch/issues/130015 - tc_adjustment = numel * BYTES_PER_EL_FLOAT8 - - # https://github.com/pytorch/pytorch/issues/128063 - # instead of - # kernel 1: x_bf16 -> max(abs(x)), x_fp8 - # kernel 2: not modeled - # kernel 3: not modeled - # we get - # kernel 1: x_bf16 -> max(abs(x)) - # reads: same as before - # writes: 0 - # ... - # kernel 4: x_bf16, scale -> x_fp8 - # reads: numel * BYTES_PER_EL_BF16 - # writes: 2 * numel * BYTES_PER_EL_FLOAT8 - # Note that assuming worst case, this issue brings the memory - # traffic for delayed scaling to be equal to that of dynamic scaling. - tc_adjustment += ( - # subtract writes from kernel 1 - -1 * 2 * numel * BYTES_PER_EL_FLOAT8 - # add reads for kernel 4 - + numel * BYTES_PER_EL_BF16 - # add writes for kernel 4 - + 2 * numel * BYTES_PER_EL_FLOAT8 - ) - else: - tc_adjustment = 0 - - return kernel_1_r + kernel_1_w + tc_adjustment + tc_adjustment = 0 + + return kernel_1_rw + kernel_3_rw + tc_adjustment def get_gemm_time_sympy(M, K, N, dtype): @@ -131,9 +83,9 @@ def get_float8_mem_sympy( scaling_type_weight: str = "dynamic", scaling_type_grad_output: str = "dynamic", ): - assert scaling_type_input in ("dynamic", "delayed"), "unsupported" - assert scaling_type_weight in ("dynamic", "delayed"), "unsupported" - assert scaling_type_grad_output in ("dynamic", "delayed"), "unsupported" + assert scaling_type_input in ("dynamic",), "unsupported" + assert scaling_type_weight in ("dynamic",), "unsupported" + assert scaling_type_grad_output in ("dynamic",), "unsupported" # there are three gemms in the fwd/bwd of a linear: # @@ -207,27 +159,12 @@ def get_float8_mem_sympy( if scaling_type_input == "dynamic": # second stage of max-abs reduction num_extra_kernels += 1 - elif scaling_type_input == "delayed": - # second stage of max-abs reduction - num_extra_kernels += 1 - # reciprocal of scale - num_extra_kernels += 1 if scaling_type_weight == "dynamic": # second stage of max-abs reduction num_extra_kernels += 1 - elif scaling_type_weight == "delayed": - # second stage of max-abs reduction - num_extra_kernels += 1 - # reciprocal of scale - num_extra_kernels += 1 if scaling_type_grad_output == "dynamic": # second stage of max-abs reduction num_extra_kernels += 1 - elif scaling_type_grad_output == "delayed": - # second stage of max-abs reduction - num_extra_kernels += 1 - # reciprocal of scale - num_extra_kernels += 1 extra_kernel_overhead_s = num_extra_kernels * TRITON_KERNEL_1_ELEMENT_TIME_SEC diff --git a/torchao/float8/stateful_float8_linear.py b/torchao/float8/stateful_float8_linear.py deleted file mode 100644 index ac01803e0b..0000000000 --- a/torchao/float8/stateful_float8_linear.py +++ /dev/null @@ -1,439 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. - -""" -Stateful version of Float8Linear, created to keep Float8Linear simple and -only require code readers to read the stateful code if they care about delayed -or static scaling. -""" - -from typing import Optional - -import torch -import torch.utils.checkpoint as checkpoint - -from torchao.float8.config import Float8LinearConfig, ScalingType -from torchao.float8.distributed_utils import tensor_already_casted_to_fp8 -from torchao.float8.float8_linear import ( - Float8Linear, -) -from torchao.float8.float8_scaling_utils import ( - NoopFwToFloat8BwDelayed, - NoopFwToFloat8BwDynamic, - NoopFwToFloat8BwStatic, - _maybe_initialize_amaxes_scales_for_float8_cast, - hp_tensor_to_float8_delayed, - hp_tensor_to_float8_dynamic, - hp_tensor_to_float8_static, -) -from torchao.float8.float8_tensor import ( - GemmInputRole, - hp_tensor_and_scale_to_float8, -) -from torchao.float8.float8_utils import ( - tensor_to_amax, - tensor_to_scale, -) -from torchao.float8.fsdp_utils import ( - WeightWithDelayedFloat8CastTensor, - WeightWithDynamicFloat8CastTensor, - WeightWithStaticFloat8CastTensor, -) - - -@torch._dynamo.allow_in_graph -class manual_float8_matmul_with_args_in_float8(torch.autograd.Function): - """ - Like torch.matmul, but with the arguments in float8 - - Note: this function requires all arguments to already be Float8Tensor objects, - which only supports tensorwise scaling granularity. The reason we didn't just make this - function support axiswise scaling granularity is because that would need very - careful testing of delayed scaling, as delayed scaling modifies buffers inplace. - - In the future we'll probably have to unify, just postponing that until a future PR. - """ - - @staticmethod - def forward( - ctx, - input_fp8, - weight_fp8_t, - ): - ctx.save_for_backward(input_fp8, weight_fp8_t) - # the reshapes are needed in order to make the shapes compatible with - # torch.mm - orig_shape = input_fp8.shape - input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1]) - res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t) - res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1]) - return res_bits - - @staticmethod - def backward(ctx, grad_output_fp8): - input_fp8, weight_fp8_t = ctx.saved_tensors - - # the reshapes are needed in order to make the shapes compatible with - # torch.mm - grad_output_fp8_orig_shape = grad_output_fp8.shape - grad_output_fp8_reshaped = grad_output_fp8.reshape( - -1, grad_output_fp8_orig_shape[-1] - ) - - # calculate grad_input - grad_input = torch.mm( - grad_output_fp8_reshaped, - weight_fp8_t.t(), - ) - grad_input = grad_input.reshape( - *grad_output_fp8_orig_shape[:-1], grad_input.shape[-1] - ) - - input_fp8_orig_shape = input_fp8.shape - input_fp8_reshaped = input_fp8.reshape(-1, input_fp8_orig_shape[-1]) - - # calculate grad_weight - # Note: the variant below is slightly faster on LLaMa 3 8B pretraining - # compared to than calculating `grad_weight_t = input_fp8_t @ grad_output_fp8_reshaped` - grad_weight = torch.mm( - grad_output_fp8_reshaped.t(), - input_fp8_reshaped, - ) - - return grad_input, grad_weight.t() - - -class StatefulFloat8Linear(Float8Linear): - def __init__(self, *args, **kwargs): - # Amax scales should always be kept as float32. - self.always_float32_buffers = set() - - super().__init__(*args, **kwargs) - - # Convenience flag to skip code related to delayed scaling - self.has_any_delayed_scaling = ( - self.scaling_type_input is ScalingType.DELAYED - or self.scaling_type_weight is ScalingType.DELAYED - or self.scaling_type_grad_output is ScalingType.DELAYED - ) - - self.create_buffers() - - # Note: is_amax_initialized is not a buffer to avoid data dependent - # control flow visible to dynamo - # TODO(future PR): add serialization for this flag - self.is_amax_initialized = not self.config.enable_amax_init - - # pre_forward and post_forward are currently broken with FSDP - # and torch.compile, this option can disable them - # Note that when using `self.config.enable_pre_and_post_forward = False`, - # it's recommended to also set `self.config.enable_amax_init = False`. - # Otherwise, the amax buffer would never be marked as initialized and - # would be initialized in every iteration. - self.enable_pre_and_post_forward = self.config.enable_pre_and_post_forward - - def create_buffers(self): - # Default values for history buffers, see above TODO - history_len = self.config.delayed_scaling_config.history_len - device = self.weight.device - default_input = torch.finfo(self.config.cast_config_input.target_dtype).max - default_weight = torch.finfo(self.config.cast_config_weight.target_dtype).max - default_grad_output = torch.finfo( - self.config.cast_config_grad_output.target_dtype - ).max - - # Note: for now, create all the buffers if any are needed, to postpone - # the work to make the scale and amax syncing and history calculation - # handle a heterogeneous setup. We can do that work later if benchmarks - # show it is worth doing. - if self.has_any_delayed_scaling: - self.register_always_float32_buffer( - "fp8_amax_input", torch.tensor([default_input], device=device) - ) - self.register_always_float32_buffer( - "fp8_amax_history_input", torch.zeros(history_len, device=device) - ) - self.register_always_float32_buffer( - "fp8_scale_input", torch.tensor([1.0], device=device) - ) - self.register_always_float32_buffer( - "fp8_amax_weight", torch.tensor([default_weight], device=device) - ) - self.register_always_float32_buffer( - "fp8_amax_history_weight", torch.zeros(history_len, device=device) - ) - self.register_always_float32_buffer( - "fp8_scale_weight", torch.tensor([1.0], device=device) - ) - self.register_always_float32_buffer( - "fp8_amax_grad_output", - torch.tensor([default_grad_output], device=device), - ) - self.register_always_float32_buffer( - "fp8_amax_history_grad_output", torch.zeros(history_len, device=device) - ) - self.register_always_float32_buffer( - "fp8_scale_grad_output", torch.tensor([1.0], device=device) - ) - - if self.config.cast_config_input.static_scale is not None: - self.register_always_float32_buffer( - "fp8_static_scale_input", - self.config.cast_config_input.static_scale.to(device), - ) - if self.config.cast_config_weight.static_scale is not None: - self.register_always_float32_buffer( - "fp8_static_scale_weight", - self.config.cast_config_weight.static_scale.to(device), - ) - if self.config.cast_config_grad_output.static_scale is not None: - self.register_always_float32_buffer( - "fp8_static_scale_grad_output", - self.config.cast_config_grad_output.static_scale.to(device), - ) - - def register_always_float32_buffer( - self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True - ) -> None: - self.register_buffer(name=name, tensor=tensor, persistent=persistent) - self.always_float32_buffers.add(name) - - def _apply(self, fn, recurse=True): - ret = super()._apply(fn, recurse) - self.convert_amax_buffer_to_float32() - return ret - - def convert_amax_buffer_to_float32(self): - for key in self.always_float32_buffers: - if self._buffers[key] is not None: - self._buffers[key] = self._buffers[key].to(torch.float32) - - def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor: - is_amax_initialized = self.is_amax_initialized - # Duplicate the autocast logic for F.linear, so that the output - # of our module has the right original precision - if torch.is_autocast_enabled(): - # For now, hardcode to GPU's autocast dtype - # if we need CPU support in the future, we can add it - autocast_dtype = torch.get_autocast_gpu_dtype() - input = input.to(autocast_dtype) - - if tensor_already_casted_to_fp8(input): - input_fp8 = input - elif self.scaling_type_input is ScalingType.DELAYED: - scale_fn_name = self.config.delayed_scaling_config.scale_fn_name - _maybe_initialize_amaxes_scales_for_float8_cast( - input, - self.fp8_amax_input, - self.fp8_amax_history_input, - self.fp8_scale_input, - scale_fn_name, - self.config.cast_config_input.target_dtype, - is_amax_initialized, - reduce_amax=True, - ) - input_fp8 = hp_tensor_to_float8_delayed( - input, - self.fp8_scale_input, - self.config.cast_config_input.target_dtype, - self.fp8_amax_input, - linear_mm_config=self.linear_mm_config, - gemm_input_role=GemmInputRole.INPUT, - ) - elif self.scaling_type_input is ScalingType.DYNAMIC: - input_fp8 = hp_tensor_to_float8_dynamic( - input, - self.config.cast_config_input.target_dtype, - self.linear_mm_config, - gemm_input_role=GemmInputRole.INPUT, - ) - else: - assert self.scaling_type_input is ScalingType.STATIC - input_fp8 = hp_tensor_to_float8_static( - input, - self.fp8_static_scale_input, - self.config.cast_config_input.target_dtype, - self.linear_mm_config, - ) - - return input_fp8 - - def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]: - if tensor_already_casted_to_fp8(weight): - return None - if self.scaling_type_weight is ScalingType.DELAYED: - scale_fn_name = self.config.delayed_scaling_config.scale_fn_name - _maybe_initialize_amaxes_scales_for_float8_cast( - weight, - self.fp8_amax_weight, - self.fp8_amax_history_weight, - self.fp8_scale_weight, - scale_fn_name, - self.config.cast_config_weight.target_dtype, - self.is_amax_initialized, - reduce_amax=True, - ) - self.fp8_amax_weight.fill_(tensor_to_amax(weight)) - return self.fp8_scale_weight - elif self.scaling_type_weight is ScalingType.DYNAMIC: - return tensor_to_scale(weight, self.config.cast_config_weight.target_dtype) - else: - assert self.scaling_type_weight is ScalingType.STATIC - return self.fp8_static_scale_weight - - def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor: - if self.scaling_type_grad_output is ScalingType.DELAYED: - scale_fn_name = self.config.delayed_scaling_config.scale_fn_name - output = NoopFwToFloat8BwDelayed.apply( - output, - self.fp8_amax_grad_output, - self.fp8_amax_history_grad_output, - self.fp8_scale_grad_output, - scale_fn_name, - self.is_amax_initialized, - self.linear_mm_config, - self.config.cast_config_grad_output.target_dtype, - ) - elif self.scaling_type_grad_output is ScalingType.DYNAMIC: - output = NoopFwToFloat8BwDynamic.apply( - output, - self.linear_mm_config, - self.config.cast_config_grad_output.target_dtype, - ) - else: - assert self.scaling_type_grad_output is ScalingType.STATIC - output = NoopFwToFloat8BwStatic.apply( - output, - self.fp8_static_scale_grad_output, - self.linear_mm_config, - self.config.cast_config_grad_output.target_dtype, - ) - return output - - def cast_weight_to_float8_t( - self, - weight: torch.Tensor, - weight_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if tensor_already_casted_to_fp8(weight): - return weight.t() - weight_fp8 = hp_tensor_and_scale_to_float8( - weight, - weight_scale, - self.config.cast_config_weight.target_dtype, - self.linear_mm_config, - gemm_input_role=GemmInputRole.WEIGHT, - ) - return weight_fp8.t() - - def forward(self, input: torch.Tensor) -> torch.Tensor: - if self.has_any_delayed_scaling: - self.float8_pre_forward(input) - - input_fp8 = self.cast_input_to_float8(input) - # If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight, - # weight_scale should be saved. - weight_scale = self.get_weight_scale(self.weight) - - if self.config.force_recompute_fp8_weight_in_bwd: - weight_fp8_t = checkpoint.checkpoint( - self.cast_weight_to_float8_t, - self.weight, - weight_scale, - ) - else: - weight_fp8_t = self.cast_weight_to_float8_t(self.weight, weight_scale) - - output = manual_float8_matmul_with_args_in_float8.apply(input_fp8, weight_fp8_t) - - # Cast grad_output to float8_e5m2 during backward - output = self.cast_output_to_float8_in_bw(output) - - if self.bias is not None: - output = output + self.bias.to(output.dtype) - - if self.has_any_delayed_scaling: - self.float8_post_forward() - return output - - def float8_pre_forward(self, input): - # TODO(future PR): deprecate these functions and the corresponding - # config setting - if not self.enable_pre_and_post_forward: - return - - def float8_post_forward(self): - # TODO(future PR): deprecate these functions and the corresponding - # config setting - if not self.enable_pre_and_post_forward: - return - - @classmethod - def from_float( - cls, - mod, - config: Optional[Float8LinearConfig] = None, - ): - """ - Create an nn.Linear with fp8 compute from a regular nn.Linear - - Args: - mod (torch.nn.Linear): nn.Linear to convert - config (Optional[Float8LinearConfig]): configuration for conversion to float8 - """ - if config is None: - config = Float8LinearConfig() - with torch.device("meta"): - new_mod = cls( - mod.in_features, - mod.out_features, - bias=False, - config=config, - ) - new_mod.weight = mod.weight - new_mod.bias = mod.bias - # need to create buffers again when moving from meta device to - # real device - new_mod.create_buffers() - - # If FSDP float8 all-gather is on, wrap the weight in a float8-aware - # tensor subclass. This must happen last because: - # 1. weight needs to be on the correct device to create the buffers - # 2. buffers need to be already created for the delayed scaling version - # of the weight wrapper to be initialized - if config.enable_fsdp_float8_all_gather: - if config.cast_config_weight.scaling_type is ScalingType.DYNAMIC: - new_mod.weight = torch.nn.Parameter( - WeightWithDynamicFloat8CastTensor( - new_mod.weight, - new_mod.linear_mm_config, - new_mod.config.cast_config_weight.target_dtype, - ) - ) - elif config.cast_config_weight.scaling_type is ScalingType.DELAYED: - new_mod.weight = torch.nn.Parameter( - WeightWithDelayedFloat8CastTensor( - new_mod.weight, - new_mod.fp8_amax_weight, - new_mod.fp8_amax_history_weight, - new_mod.fp8_scale_weight, - new_mod.linear_mm_config, - new_mod.config.cast_config_weight.target_dtype, - new_mod.is_amax_initialized, - ) - ) - else: - assert config.cast_config_weight.scaling_type is ScalingType.STATIC - new_mod.weight = torch.nn.Parameter( - WeightWithStaticFloat8CastTensor( - new_mod.weight, - new_mod.fp8_static_scale_weight, - new_mod.linear_mm_config, - new_mod.config.cast_config_weight.target_dtype, - ) - ) - - return new_mod diff --git a/torchao/ops.py b/torchao/ops.py index bba2a054fc..a3aee761b9 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -1,3 +1,5 @@ +import functools + import torch from torch import Tensor @@ -606,6 +608,27 @@ def _( return input_scale.new_empty(*input.shape[:-1], weight.shape[0]) +@functools.lru_cache() +def _get_dtypes(): + """TODO: when e8m0 is hardened and major release lets remove uint8 support""" + if hasattr(torch, "float8_e8m0fnu"): + return (torch.uint8, torch.float8_e8m0fnu) + return (torch.uint8,) + + +def _check_scale_dtypes(A_scale, B_scale): + allowed_dtypes = _get_dtypes() + + torch._check( + A_scale.dtype in allowed_dtypes, + lambda: f"A_scale tensor must be uint8 or float8_e8m0fnu, got {A_scale.dtype}", + ) + torch._check( + B_scale.dtype in allowed_dtypes, + lambda: f"B_scale tensor must be uint8 or float8_e8m0fnu, got {B_scale.dtype}", + ) + + def mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor): """Defines a matmul between two fp8 tensors w/ MX scales in E8MO and returns a bf16 tensor. @@ -625,25 +648,7 @@ def mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor): MXN bf16 Tensor """ - torch._check( - A.dtype == torch.float8_e4m3fn, - lambda: f"Input tensor A must be float8_e4m3fn, got {A.dtype}", - ) - torch._check( - B.dtype == torch.float8_e4m3fn, - lambda: f"Input tensor B must be float8_e4m3fn, got {B.dtype}", - ) - - # TODO - Once e8m0 dtype is added to core udpate - # Check scale tensors are uint8 - torch._check( - A_scale.dtype == torch.uint8, - lambda: f"A_scale tensor must be uint8, got {A_scale.dtype}", - ) - torch._check( - B_scale.dtype == torch.uint8, - lambda: f"B_scale tensor must be uint8, got {B_scale.dtype}", - ) + _check_scale_dtypes(A_scale, B_scale) return torch.ops.torchao.mx_fp8_bf16.default(A, B, A_scale, B_scale) @@ -674,6 +679,7 @@ def mx_fp4_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor): MXN bf16 Tensor """ + _check_scale_dtypes(A_scale, B_scale) return torch.ops.torchao.mx_fp4_bf16.default(A, B, A_scale, B_scale) diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index d511d2614d..de7369c1cf 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from enum import Enum -from typing import Any, Optional +from typing import Any, Optional, Union import torch @@ -27,6 +27,14 @@ class MXGemmKernelChoice(Enum): # TODO(future PR): add cuBLAS here once we land pytorch/pytorch support +# Pre-made recipes for common configurations +class MXLinearRecipeName(Enum): + MXFP8_EMULATED = "mxfp8_emulated" + MXFP8_CUTLASS = "mxfp8_cutlass" + MXFP4_EMULATED = "mxfp4_emulated" + MXFP4_CUTLASS = "mxfp4_cutlass" + + @dataclass class MXLinearConfig: # block size for scaling, default is 32 to match @@ -78,3 +86,31 @@ def __post_init__(self): assert ( self.elem_dtype_grad_output_override is None ), "elem_dtype_grad_output_override not supported for CUTLASS MX gemm kernels" + + @staticmethod + def from_recipe_name( + recipe_name: Union[MXLinearRecipeName, str], + ) -> "MXLinearConfig": + """ + Input: `MXLinearRecipeName` value, or a string representing a `MXLinearRecipeName` value + Output: a `MXLinearConfig` configured to implement the specified recipe + """ + if type(recipe_name) == str: + valid_names = [n.value for n in MXLinearRecipeName] + assert ( + recipe_name in valid_names + ), f"recipe_name {recipe_name} not in valid names {valid_names}" + recipe_name = MXLinearRecipeName(recipe_name) + + if recipe_name is MXLinearRecipeName.MXFP8_EMULATED: + return MXLinearConfig() + elif recipe_name is MXLinearRecipeName.MXFP8_CUTLASS: + return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUTLASS) + elif recipe_name is MXLinearRecipeName.MXFP4_EMULATED: + return MXLinearConfig(elem_dtype=DTYPE_FP4) + elif recipe_name is MXLinearRecipeName.MXFP8_CUTLASS: + return MXLinearConfig( + elem_dtype=DTYPE_FP4, gemm_kernel_choice=MXGemmKernelChoice.CUTLASS + ) + else: + raise AssertionError(f"unknown recipe_name {recipe_name}") diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index 16e61e0653..ddc2bcd665 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -74,6 +74,7 @@ def mx_mm(aten_op, args, kwargs=None): # real MX gemm backed by torchao's CUTLASS kernels M, K, N = a.shape[0], a.shape[1], b.shape[1] assert b._data.t().is_contiguous() + # TODO(future PR): use block_size instead of hardcoding 32 a_scale = a._scale_e8m0.view(M, K // 32) b_scale = b._scale_e8m0.view(N, K // 32) a_scale_block = to_blocked(a_scale) diff --git a/torchao/prototype/mx_formats/utils.py b/torchao/prototype/mx_formats/utils.py index 4cdc26109d..8b186f82d6 100644 --- a/torchao/prototype/mx_formats/utils.py +++ b/torchao/prototype/mx_formats/utils.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import torch -import torch.nn.functional as F Tensor = torch.Tensor @@ -31,14 +30,23 @@ def to_blocked(input_matrix) -> Tensor: n_row_blocks = ceil_div(rows, 128) n_col_blocks = ceil_div(cols, 4) - # Pad out and view as tiles of (128, 4) - padded = F.pad(input_matrix, (0, -cols % 4, 0, -rows % 128)) - blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + # Calculate the padded shape + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + padded = input_matrix + if (rows, cols) != (padded_rows, padded_cols): + padded = torch.zeros( + (padded_rows, padded_cols), + device=input_matrix.device, + dtype=input_matrix.dtype, + ) + padded[:rows, :cols] = input_matrix - # rearrange all tiles + # Rearrange the blocks + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) - # Layout rearranged tiles according to second pic return rearranged.flatten() diff --git a/torchao/prototype/quantized_training/int8_mm.py b/torchao/prototype/quantized_training/int8_mm.py index 7de6620d65..faaa6e463e 100644 --- a/torchao/prototype/quantized_training/int8_mm.py +++ b/torchao/prototype/quantized_training/int8_mm.py @@ -54,6 +54,7 @@ @triton.autotune(configs=configs, key=["M", "N", "K", "stride_ak", "stride_bk"]) +@triton.heuristics({"EVEN_K": lambda args: args["K"] % args["BLOCK_K"] == 0}) @triton.jit def _scaled_int8_mm_kernel( A_ptr, @@ -176,7 +177,6 @@ def scaled_int8_mm_cuda(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tens *A.stride(), *B.stride(), *C.stride(), - EVEN_K=K % 2 == 0, COL_SCALE_SCALAR=col_scale.numel() == 1, ) return C diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index cb7c8d0481..b278e22b3b 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -759,7 +759,7 @@ def _create_quantized_state_dict( if self.padding_allowed: import torch.nn.functional as F - logging.warn( + logging.warning( f"warning: {fqn} is padded to satisfy in_features % 1024 == 0" ) padded_in_features = find_multiple(in_features, 1024) @@ -767,7 +767,7 @@ def _create_quantized_state_dict( weight, pad=(0, padded_in_features - in_features) ) else: - logging.warn( + logging.warning( f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + "and that groupsize and inner_k_tiles*16 evenly divide into it" ) @@ -1147,7 +1147,7 @@ def _create_quantized_state_dict( if self.padding_allowed: import torch.nn.functional as F - logging.warn( + logging.warning( f"warning: {fqn} is padded to satisfy in_features % 1024 == 0" ) padded_in_features = find_multiple(in_features, 1024) @@ -1155,7 +1155,7 @@ def _create_quantized_state_dict( weight, pad=(0, padded_in_features - in_features) ) else: - logging.warn( + logging.warning( f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + "and that groupsize and inner_k_tiles*16 evenly divide into it" ) diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index a0e2ea2cc4..d2b6e0c016 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -348,6 +348,8 @@ Marlin QQQ is an optimized GPU kernel that supports W4A8 mixed precision GEMM. F ### Gemlite Triton Int4 and Int8 quantization using the [Gemlite Triton](https://github.com/mobiusml/gemlite) kernels. You can try it out with the `quantize_` api as above alongside the constructor `gemlite_uintx_weight_only`. An example can be found in `torchao/_models/llama/generate.py`. +Note: we test on gemlite 0.4.1, but should be able to use any version after that, we'd recommend to use the latest release to get the most recent performance improvements. + ### UINTx Quantization We're trying to develop kernels for low bit quantization for intx quantization formats. While the current performance is not ideal, we're hoping to continue to iterate on these kernels to improve their performance. diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py index fafda68d58..716634fe9d 100644 --- a/torchao/quantization/qat/linear.py +++ b/torchao/quantization/qat/linear.py @@ -75,9 +75,6 @@ def __init__( *args, **kwargs, ) - if bias: - raise NotImplementedError("bias not supported yet") - # initialize activation fake quantizer if activation_config is not None: self.activation_fake_quantizer = FakeQuantizer(activation_config) @@ -103,17 +100,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: w = self.weight_fake_quantizer(self.weight) else: w = self.weight - return F.linear(x, w) + return F.linear(x, w, self.bias) def to_linear(self) -> torch.nn.Linear: new_linear = torch.nn.Linear( - self.in_features, self.out_features, self.bias, device=self.weight.device + self.in_features, + self.out_features, + self.bias is not None, + device=self.weight.device, ) # In distributed training, the model may be instantiated # on the meta device, in which case there is no need to # copy the weights, and doing so will result in an error if self.weight.device != torch.device("meta"): new_linear.weight = self.weight + new_linear.bias = self.bias return new_linear @classmethod @@ -126,7 +127,7 @@ def from_linear( new_linear = FakeQuantizedLinear( mod.in_features, mod.out_features, - mod.bias, + mod.bias is not None, activation_config=activation_config, weight_config=weight_config, device=mod.weight.device, @@ -136,6 +137,7 @@ def from_linear( # copy the weights, and doing so will result in an error if mod.weight.device != torch.device("meta"): new_linear.weight = mod.weight + new_linear.bias = mod.bias return new_linear diff --git a/torchao/testing/float8/fsdp2_utils.py b/torchao/testing/float8/fsdp2_utils.py index a059b4d2a9..31a5cf8db0 100644 --- a/torchao/testing/float8/fsdp2_utils.py +++ b/torchao/testing/float8/fsdp2_utils.py @@ -8,10 +8,6 @@ Float8LinearConfig, ScalingType, ) -from torchao.float8.float8_linear_utils import ( - linear_requires_sync, - sync_float8_amax_and_scale_history, -) from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp @@ -38,9 +34,6 @@ def check_parity_no_mp( dist.all_reduce(param.grad) param.grad.div_(dist.get_world_size()) - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(model) - optim.step() if ( model is fsdp_model @@ -82,7 +75,6 @@ def check_parity_bf16_mp( param_bf16.grad.div_(dist.get_world_size()) param_fp32.grad = param_bf16.grad.float() param_bf16.grad = None - # TODO(future): add amax syncing once delayed scaling is supported optim.step() for param_fp32, param_bf16 in zip( ref_model.parameters(), ref_model_bf16.parameters() diff --git a/torchao/testing/float8/test_utils.py b/torchao/testing/float8/test_utils.py index 7b8ac121b6..2da34f53ed 100644 --- a/torchao/testing/float8/test_utils.py +++ b/torchao/testing/float8/test_utils.py @@ -1,9 +1,6 @@ -import torch - from torchao.float8.config import ( CastConfig, Float8LinearConfig, - ScalingType, ) @@ -13,32 +10,14 @@ def get_test_float8_linear_config( scaling_type_grad_output, emulate: bool, ): - static_scale_one = torch.tensor([1.0], device="cuda") - - if scaling_type_input is ScalingType.STATIC: - static_scale_input = static_scale_one - else: - static_scale_input = None - if scaling_type_weight is ScalingType.STATIC: - static_scale_weight = static_scale_one - else: - static_scale_weight = None - if scaling_type_grad_output is ScalingType.STATIC: - static_scale_grad_output = static_scale_one - else: - static_scale_grad_output = None - cast_config_input = CastConfig( scaling_type=scaling_type_input, - static_scale=static_scale_input, ) cast_config_weight = CastConfig( scaling_type=scaling_type_weight, - static_scale=static_scale_weight, ) cast_config_grad_output = CastConfig( scaling_type=scaling_type_grad_output, - static_scale=static_scale_grad_output, ) config = Float8LinearConfig( diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index d88241783f..02d151cdb4 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -14,7 +14,7 @@ from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx from torchao.quantization import int8_weight_only, quantize_ from torchao.quantization.quant_primitives import MappingType -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_6, get_compute_capability """ How to use: @@ -41,6 +41,50 @@ class MyTestCase(TorchAOBasicTestCase): """ +def skip_if_compute_capability_less_than(min_capability): + import unittest + + def decorator(test_func): + def wrapper(*args, **kwargs): + if get_compute_capability() < min_capability: + raise unittest.SkipTest( + f"Compute capability is less than {min_capability}" + ) + return test_func(*args, **kwargs) + + return wrapper + + return decorator + + +def skip_if_rocm(message=None): + """Decorator to skip tests on ROCm platform with custom message. + + Args: + message (str, optional): Additional information about why the test is skipped. + """ + import pytest + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if torch.version.hip is not None: + skip_message = "Skipping the test in ROCm" + if message: + skip_message += f": {message}" + pytest.skip(skip_message) + return func(*args, **kwargs) + + return wrapper + + # Handle both @skip_if_rocm and @skip_if_rocm() syntax + if callable(message): + func = message + message = None + return decorator(func) + return decorator + + # copied from https://github.com/pytorch/pytorch/blob/941d094dd1b507dacf06ddc6ed3485a9537e09b7/test/inductor/test_torchinductor.py#L11389 def copy_tests(my_cls, other_cls, suffix, test_failures=None, xfail_prop=None): # noqa: B902 for name, value in my_cls.__dict__.items(): diff --git a/torchao/utils.py b/torchao/utils.py index 13b59c2e81..2a67f8a9c9 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -15,7 +15,6 @@ "profiler_runner", "get_available_devices", "get_compute_capability", - "skip_if_compute_capability_less_than", "benchmark_torch_function_in_microseconds", "find_multiple", "_register_custom_op", @@ -145,22 +144,6 @@ def get_compute_capability(): return 0.0 -def skip_if_compute_capability_less_than(min_capability): - import unittest - - def decorator(test_func): - def wrapper(*args, **kwargs): - if get_compute_capability() < min_capability: - raise unittest.SkipTest( - f"Compute capability is less than {min_capability}" - ) - return test_func(*args, **kwargs) - - return wrapper - - return decorator - - def compute_max_diff(output: torch.Tensor, output_ref: torch.Tensor) -> torch.Tensor: return torch.mean(torch.abs(output - output_ref)) / torch.mean( torch.abs(output_ref) @@ -626,7 +609,7 @@ def _torch_version_at_least(min_version): def is_MI300(): if torch.cuda.is_available() and torch.version.hip: mxArchName = ["gfx940", "gfx941", "gfx942"] - archName = torch.cuda.get_device_properties().gcnArchName + archName = torch.cuda.get_device_properties(0).gcnArchName for arch in mxArchName: if arch in archName: return True