Commit

Merge remote-tracking branch 'origin/main' into bench_structure
jainapurva committed Feb 26, 2025
2 parents 8b7291c + 98c4e2e commit a750f7c
Showing 60 changed files with 527 additions and 2,478 deletions.
49 changes: 49 additions & 0 deletions .github/workflows/regression_test_rocm.yml
@@ -0,0 +1,49 @@
name: Run Regression Tests on ROCm

on:
  push:
    branches:
      - main
    tags:
      - ciflow/rocm/*

concurrency:
  group: regression_test-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
  cancel-in-progress: true

env:
  HF_TOKEN: ${{ secrets.HF_TOKEN }}

jobs:
  test-nightly:
    strategy:
      fail-fast: false
      matrix:
        include:
          - name: ROCM Nightly
            runs-on: linux.rocm.gpu.torchao
            torch-spec: '--pre torch==2.7.0.dev20250122 --index-url https://download.pytorch.org/whl/nightly/rocm6.3'
            gpu-arch-type: "rocm"
            gpu-arch-version: "6.3"

    permissions:
      id-token: write
      contents: read
    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
    with:
      timeout: 120
      no-sudo: ${{ matrix.gpu-arch-type == 'rocm' }}
      runner: ${{ matrix.runs-on }}
      gpu-arch-type: ${{ matrix.gpu-arch-type }}
      gpu-arch-version: ${{ matrix.gpu-arch-version }}
      submodules: recursive
      script: |
        conda create -n venv python=3.9 -y
        conda activate venv
        python -m pip install --upgrade pip
        pip install ${{ matrix.torch-spec }}
        pip install -r dev-requirements.txt
        pip install .
        export CONDA=$(dirname $(dirname $(which conda)))
        export LD_LIBRARY_PATH=$CONDA/lib/:$LD_LIBRARY_PATH
        pytest test --verbose -s
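For context (not part of this commit): the ROCm nightly wheel this workflow installs exposes GPUs through the regular torch.cuda API surface, so a test can detect which backend it is running on. A minimal sketch, with describe_accelerator as a hypothetical helper name:

# Hypothetical helper (not in this commit): on the ROCm runner above,
# the ROCm build of PyTorch reports HIP via torch.version.hip while
# still exposing devices through the torch.cuda API.
import torch

def describe_accelerator() -> str:
    if torch.version.hip is not None:
        # ROCm build: torch.cuda calls are backed by HIP devices
        return f"ROCm/HIP {torch.version.hip}, {torch.cuda.device_count()} GPU(s)"
    if torch.cuda.is_available():
        return f"CUDA {torch.version.cuda}, {torch.cuda.device_count()} GPU(s)"
    return "CPU only"

print(describe_accelerator())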
53 changes: 13 additions & 40 deletions benchmarks/float8/bench_linear_float8.py
@@ -23,10 +23,6 @@
    ScalingType,
)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import (
    linear_requires_sync,
    sync_float8_amax_and_scale_history,
)
from torchao.float8.float8_tensor import ScaledMMConfig

# estimating TOPs for matmuls in fp32, fp16, fp8
@@ -122,39 +118,18 @@ def main(
    scaling_type_grad_output = ScalingType(scaling_type_grad_output)
    scaling_granularity = ScalingGranularity(scaling_granularity)

    if scaling_type_input is ScalingType.STATIC:
        cast_config_input = CastConfig(
            scaling_type=scaling_type_input,
            static_scale=torch.tensor([1.0], device="cuda"),
            scaling_granularity=scaling_granularity,
        )
    else:
        cast_config_input = CastConfig(
            scaling_type=scaling_type_input,
            scaling_granularity=scaling_granularity,
        )
    if scaling_type_weight is ScalingType.STATIC:
        cast_config_weight = CastConfig(
            scaling_type=scaling_type_weight,
            static_scale=torch.tensor([1.0], device="cuda"),
            scaling_granularity=scaling_granularity,
        )
    else:
        cast_config_weight = CastConfig(
            scaling_type=scaling_type_weight,
            scaling_granularity=scaling_granularity,
        )
    if scaling_type_grad_output is ScalingType.STATIC:
        cast_config_grad_output = CastConfig(
            scaling_type=scaling_type_grad_output,
            static_scale=torch.tensor([1.0], device="cuda"),
            scaling_granularity=scaling_granularity,
        )
    else:
        cast_config_grad_output = CastConfig(
            scaling_type=scaling_type_grad_output,
            scaling_granularity=scaling_granularity,
        )
    cast_config_input = CastConfig(
        scaling_type=scaling_type_input,
        scaling_granularity=scaling_granularity,
    )
    cast_config_weight = CastConfig(
        scaling_type=scaling_type_weight,
        scaling_granularity=scaling_granularity,
    )
    cast_config_grad_output = CastConfig(
        scaling_type=scaling_type_grad_output,
        scaling_granularity=scaling_granularity,
    )

    config = Float8LinearConfig(
        cast_config_input=cast_config_input,
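For context (not part of this diff): with the STATIC branches removed, all three cast configs are built the same way and fed to Float8LinearConfig. A minimal sketch of that flow end to end, assuming torchao's convert_to_float8_training entry point and tensorwise granularity:

# Sketch only: builds the same three cast configs as above with dynamic
# scaling and applies them to a toy model (requires a CUDA/ROCm device).
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training
from torchao.float8.config import (
    CastConfig,
    Float8LinearConfig,
    ScalingGranularity,
    ScalingType,
)

granularity = ScalingGranularity.TENSORWISE  # assumption for this sketch
cast_config = CastConfig(
    scaling_type=ScalingType.DYNAMIC,
    scaling_granularity=granularity,
)
config = Float8LinearConfig(
    cast_config_input=cast_config,
    cast_config_weight=cast_config,
    cast_config_grad_output=cast_config,
)
model = nn.Sequential(nn.Linear(2048, 2048, bias=False)).cuda().to(torch.bfloat16)
convert_to_float8_training(model, config=config)  # swaps nn.Linear -> Float8Linear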
@@ -185,7 +160,7 @@ def main(
        copy.deepcopy(linear_ref),
        config=config,
    )
    scaling_repr = f"{linear_float8.scaling_type_repr()},{linear_float8.scaling_granularity_repr()}"
    scaling_repr = linear_float8.extra_repr()

    if fast_accum:
        linear_float8.forward_config = ScaledMMConfig(False, True, False)
@@ -196,8 +171,6 @@
    ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()

    def float8_forw_backward():
        if linear_requires_sync(config):
            sync_float8_amax_and_scale_history(linear_float8)
        linear_float8(input_tensor).sum().backward()

    def n_times(n, fn, *args, **kwargs):
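For context (not part of this diff): float8_forw_backward above, now free of any amax/scale-history sync, is the closure the benchmark times. A minimal sketch of the usual CUDA-event timing pattern such closures feed into, with benchmark_ms as a hypothetical helper name:

# Hypothetical helper (not from this file): times a closure with CUDA
# events after warmup iterations; works on both CUDA and ROCm builds.
import torch

def benchmark_ms(fn, warmup: int = 3, iters: int = 10) -> float:
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()  # wait for the recorded events to complete
    return start.elapsed_time(end) / iters  # milliseconds per iteration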
180 changes: 0 additions & 180 deletions benchmarks/float8/bench_multi_gpu.py

This file was deleted.
