Update
[ghstack-poisoned]
vkuzo committed Feb 26, 2025
2 parents 051bbcb + 7d87946 commit faaddfe
Showing 9 changed files with 187 additions and 116 deletions.
@@ -4,6 +4,9 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

# This is a convenience script to profile fwd+bwd of individual layers with
# float8 training or mx training on a single GPU.
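# An illustrative invocation, assuming main() below is exposed as a CLI (e.g. via
# fire); the script path and exact flag spellings here are assumptions, not part
# of this diff:
#   python profile_script.py ~/profiles/prefix \
#       --mx_recipe_name <recipe> --experiment_filter both --mode_filter fwd_bwd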

import copy
import functools
import io
@@ -38,12 +41,14 @@

from torchao.float8.config import (
Float8LinearConfig,
ScalingType,
)
from torchao.float8.float8_linear_utils import (
convert_to_float8_training,
)
from torchao.testing.float8.test_utils import get_test_float8_linear_config
from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked

# don't truncate long kernel names
pd.options.display.max_colwidth = 100
@@ -257,7 +262,6 @@ def profile_function(
# set up selective activation checkpointing (AC) for max(abs(tensor))
# context: https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.create_selective_checkpoint_contexts
ops_to_save = [
torch.ops.aten.abs.default,
torch.ops.aten.max.default,
]
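# A sketch of the documented torch.utils.checkpoint wiring for a save-list like
# the one above (the script's real policy_fn / context_fn sit just below this
# hunk and may differ):
#   context_fn = functools.partial(
#       torch.utils.checkpoint.create_selective_checkpoint_contexts, ops_to_save
#   )
#   out = torch.utils.checkpoint.checkpoint(
#       module, x, use_reentrant=False, context_fn=context_fn
#   )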

@@ -275,50 +279,56 @@ def policy_fn(ctx, op, *args, **kwargs):
def main(
profile_path_prefix: pathlib.Path,
compile: bool = True,
scaling_type_input: str = "dynamic",
scaling_type_weight: str = "dynamic",
scaling_type_grad_output: str = "dynamic",
recipe_name: Optional[str] = None,
float8_recipe_name: Optional[str] = None,
mx_recipe_name: Optional[str] = None,
model_type: str = "linear",
dtype_filter: str = "both",
add_inductor_metadata_to_trace: bool = True,
experiment_filter: str = "both",
add_inductor_metadata_to_trace: bool = False,
enable_activation_checkpointing: bool = False,
mode_filter: str = "fwd_bwd",
forward_only: bool = False,
):
assert model_type in (
"linear",
"ln_linear",
"norm_ffn_norm",
"norm_ffn_norm_small",
), "unsupported"
assert dtype_filter in ("both", "float8", "bfloat16")

scaling_type_input = ScalingType(scaling_type_input)
scaling_type_weight = ScalingType(scaling_type_weight)
scaling_type_grad_output = ScalingType(scaling_type_grad_output)

if recipe_name is None:
config = get_test_float8_linear_config(
scaling_type_input,
scaling_type_weight,
scaling_type_grad_output,
emulate=False,
assert experiment_filter in (
"both",
"lowp",
"ref",
), "experiment_filter must be one of `both`, `lowp`, `ref`"
assert (
mode_filter
in (
"fwd_bwd",
"fwd",
"cast_only",
"cast_with_to_blocked",
)
elif recipe_name is not None:
config = Float8LinearConfig.from_recipe_name(recipe_name)

scaling_repr = "_".join(
[
s.short_str()
for s in (scaling_type_input, scaling_type_weight, scaling_type_grad_output)
]
)
), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`, `cast_with_to_blocked`"
if mode_filter == "cast_only":
assert experiment_filter == "lowp", "unsupported"

assert not (
float8_recipe_name is not None and mx_recipe_name is not None
), "either float8_recipe_name or mx_recipe_name can be specified, but not both"

if float8_recipe_name is None and mx_recipe_name is None:
config = Float8LinearConfig()
elif float8_recipe_name is not None:
config = Float8LinearConfig.from_recipe_name(float8_recipe_name)
elif mx_recipe_name is not None:
config = MXLinearConfig.from_recipe_name(mx_recipe_name)

print(f"Compile is set to | {compile}")
print(f"model_type is set to | {model_type}")
print(f"scaling_repr is set to | {scaling_repr}")
print(
f"enable_activation_checkpointing is set to {enable_activation_checkpointing}"
)
print(f"mode_filter is set to {mode_filter}")
print(f"config: {config}")

device = "cuda"
ref_dtype = torch.bfloat16
@@ -359,36 +369,74 @@ def main(

m_ref = m_ref.to(device).to(ref_dtype)

m_float8 = copy.deepcopy(m_ref)
convert_to_float8_training(m_float8, config=config)
# get gradient shape
with torch.no_grad():
_ = m_ref(input_tensor)
grad_output = torch.ones_like(_)
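# note: reusing a fixed grad_output with out.backward(grad_output) (instead of
# out.sum().backward()) keeps sum/fill kernels out of the profiled backward,
# which is presumably why the aten::fill_ / aten::copy_ special cases are
# removed from utils.py further down in this commit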

m_lowp = copy.deepcopy(m_ref)
if mx_recipe_name is None:
convert_to_float8_training(m_lowp, config=config)
else:
swap_linear_with_mx_linear(m_lowp, config=config)

# this function is only used for cast_only
to_mx_func = MXTensor.to_mx

# this function is used for cast_with_to_blocked
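# (i.e. cast to MX and additionally rearrange the e8m0 scales into the blocked
# layout that the mx gemm kernels consume, via to_blocked)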
def cast_with_to_blocked(x_hp):
x_mx = MXTensor.to_mx(
x_hp,
config.elem_dtype,
config.block_size,
gemm_kernel_choice=config.gemm_kernel_choice,
)
m, k = x_hp.shape
scale_blocked = to_blocked(x_mx._scale_e8m0.reshape(m, k // config.block_size))
return x_mx._data, scale_blocked

print("m_ref", m_ref)
print("m_lowp", m_lowp)
print("input_tensor.shape", input_tensor.shape)
print("grad_output.shape", grad_output.shape)
print()

def ref_forw_backward(x):
assert mode_filter not in ("cast_only", "cast_with_to_blocked"), "unsupported"
if enable_activation_checkpointing:
out = checkpoint(m_ref, x, use_reentrant=False, context_fn=context_fn)
else:
out = m_ref(x)
out.sum().backward()
if mode_filter == "fwd_bwd":
out.backward(grad_output)

def lowp_forw_backward_wrapper(x):
if mode_filter == "cast_only":
# just cast and return early
_input_tensor_mx = to_mx_func(
input_tensor,
config.elem_dtype,
config.block_size,
gemm_kernel_choice=config.gemm_kernel_choice,
)
return
elif mode_filter == "cast_with_to_blocked":
_input_tensor_mx, scale = cast_with_to_blocked(input_tensor)
return

def float8_forw(x):
if enable_activation_checkpointing:
out = checkpoint(m_float8, x, use_reentrant=False, context_fn=context_fn)
out = checkpoint(m_lowp, x, use_reentrant=False, context_fn=context_fn)
else:
out = m_float8(x)
return out

def float8_forw_backward_wrapper(x):
# TODO(future PR): this wrapper is for delayed scaling; we can clean it
# up now that delayed scaling is deprecated.
out = float8_forw(x)

# out.sum().backward() is also not torch.compile fullgraph
# friendly
with record_function("backward"):
out.sum().backward()
out = m_lowp(x)
if mode_filter == "fwd_bwd":
with record_function("backward"):
out.backward(grad_output)

if compile:
m_ref = torch.compile(m_ref, fullgraph=True)
float8_forw = torch.compile(float8_forw, fullgraph=True)
m_lowp = torch.compile(m_lowp, fullgraph=True)
to_mx_func = torch.compile(to_mx_func, fullgraph=True)
cast_with_to_blocked = torch.compile(cast_with_to_blocked, fullgraph=True)

# if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
# to populate triton kernel bandwidth further down in the script
@@ -398,15 +446,21 @@ def float8_forw_backward_wrapper(x):
else:
f = io.StringIO()
context = redirect_stdout(f)

# if we are skipping the backward pass, run under torch.no_grad()
maybe_no_grad_context = (
torch.no_grad() if mode_filter != "fwd_bwd" else nullcontext()
)

try:
with context:
with context, maybe_no_grad_context:
profile_iters = 5
ref_times, float8_times = None, None
ref_times, lowp_times = None, None
data = []

num_leaf_tensors = 1 + len(list(m_ref.parameters()))
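# leaf tensors that receive gradients: presumably the input plus every parameter;
# used by profiler_output_to_filtered_time_by_kernel_name to sanity-check the
# aten::add_ (gradient accumulation) counts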

if dtype_filter != "float8":
if experiment_filter != "lowp":
# Profile Reference Model
print("profiling ref")
ref_trace_suffix = f"_{model_type}_ref_compile_{compile}.json"
@@ -452,50 +506,46 @@ def float8_forw_backward_wrapper(x):
]
)

if dtype_filter != "bfloat16":
# Profile Float8 Model
print("profiling float8")
float8_trace_suffix = (
f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json"
)
float8_log_suffix = (
f"_{model_type}_float8_compile_{compile}_{scaling_repr}.txt"
)
trace_float8_path = profile_path_prefix + float8_trace_suffix
log_float8_path = profile_path_prefix + float8_log_suffix
trace_float8_modified_path = trace_float8_path.replace(
if experiment_filter != "ref":
# Profile lowp Model
print("profiling lowp")
lowp_trace_suffix = f"_{model_type}_lowp_compile_{compile}.json"
lowp_log_suffix = f"_{model_type}_lowp_compile_{compile}.txt"
trace_lowp_path = profile_path_prefix + lowp_trace_suffix
log_lowp_path = profile_path_prefix + lowp_log_suffix
trace_lowp_modified_path = trace_lowp_path.replace(
".json", "_modified.json"
)
profile_config = ProfileConfig(
trace_float8_path,
log_float8_path,
trace_float8_modified_path,
float8_trace_suffix,
trace_lowp_path,
log_lowp_path,
trace_lowp_modified_path,
lowp_trace_suffix,
iters=profile_iters,
warmup_iters=2,
sync=True,
)
p = profile_function(
profile_config,
float8_forw_backward_wrapper,
lowp_forw_backward_wrapper,
add_inductor_metadata_to_trace,
input_tensor,
)
print(f"saved profiling trace to {trace_float8_path}")
print(f"saved profiling trace to {trace_lowp_path}")
if add_inductor_metadata_to_trace:
print(f"saved torch logs to {log_float8_path}")
print(f"saved modified trace to {trace_float8_modified_path}")
float8_times = profiler_output_to_filtered_time_by_kernel_name(
print(f"saved torch logs to {log_lowp_path}")
print(f"saved modified trace to {trace_lowp_modified_path}")
lowp_times = profiler_output_to_filtered_time_by_kernel_name(
p, profile_iters, num_leaf_tensors
)
total_time_ms = (
sum(v for v in float8_times.values()) / 1e3 / profile_iters
sum(v for v in lowp_times.values()) / 1e3 / profile_iters
)
for k, v in float8_times.items():
for k, v in lowp_times.items():
v_ms = v / 1e3 / profile_iters
data.append(
[
"1_float8",
"1_lowp",
k,
kernel_name_to_category(k),
v / 1e3 / profile_iters,
@@ -509,6 +559,7 @@ def float8_forw_backward_wrapper(x):
# print the redirected stdout back to regular stdout
print(f.getvalue())

# TODO(future PR): this seems to no longer work; fix it or delete it
if os.environ.get("TORCHINDUCTOR_PROFILE", "") != "":
# populate the triton kernel bandwidth
for line in f.getvalue().split("\n"):
@@ -546,13 +597,13 @@ def float8_forw_backward_wrapper(x):
fill_value=0,
margins=True,
)
# drop the last row, which has totals across ref + float8 and does not make sense
# drop the last row, which has totals across ref + lowp and does not make sense
df_p = df_p[:-1]
df_p = df_p.transpose()

if dtype_filter == "both":
df_p["f8_div_ref"] = df_p["1_float8"] / df_p["0_ref"]
df_p["ref_div_f8"] = df_p["0_ref"] / df_p["1_float8"]
if experiment_filter == "both":
df_p["lowp_div_ref"] = df_p["1_lowp"] / df_p["0_ref"]
df_p["ref_div_lowp"] = df_p["0_ref"] / df_p["1_lowp"]

print("\nSummary of time (ms) by kernel category\n\n", df_p)

35 changes: 9 additions & 26 deletions benchmarks/float8/utils.py
@@ -73,14 +73,6 @@ def profiler_output_to_filtered_time_by_kernel_name(
# forward pass sum
assert e.count == num_iter, f"unexpected number of iter for {e.key}"
continue
elif e.key == "aten::fill_":
# filling the forward pass sum with 1.0
assert e.count == num_iter, f"unexpected number of iter for {e.key}"
continue
elif e.key == "aten::copy_":
# copying 1.0 from grad_out of `sum` to grad_out of next op
assert e.count == num_iter, f"unexpected number of iter for {e.key}"
continue
elif e.key == "aten::add_":
# accumulating gradients into leaf tensors
assert e.count == (
@@ -110,25 +102,16 @@ def profiler_output_to_gpu_time_for_key(prof, key):

def kernel_name_to_category(k):
# number prefix is for easy sorting
if k in ("aten::mm", "aten::addmm", "aten::_scaled_mm"):
return "0_gemm"
elif (
# max(abs(tensor))
("abs" in k and "max" in k)
or
# casting pointwise to float8
("clamp" in k)
or
# things related to scaled_mm
("scaled_mm" in k)
or
# syncing amaxes and scales
("roll" in k)
if k in (
"aten::mm",
"aten::addmm",
"aten::_scaled_mm",
"torchao::mx_fp8_bf16",
"torchao::mx_fp4_bf16",
):
# note: the above filter is approximate and will give false
# positives if model code contains other calls to abs/max/clamp
return "1_f8_overhead"
return "2_other"
return "0_gemm"
else:
return "1_other"


def parse_bw_and_kernel_name(line):
2 changes: 0 additions & 2 deletions test/integration/test_integration.py
@@ -1243,8 +1243,6 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
y_wo, (code,) = run_and_get_code(m_c, x)
sqnr = compute_error(y_ref, y_wo)
self.assertGreaterEqual(sqnr, 38)
if device == "cuda":
self.assertTrue("mixed_mm" in code, f"got code: {code}")

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")