Adding CCL Async test cases to TG nightly and bug fix (#16700)
### Overview

- [x] Adding CCL async test cases to TG nightly
- [x] Bug fixes for all-gather (core assignment, semaphore reset logic)

### Known Issues

Several issues in the CCLs are exposed by these test cases. The failing cases are commented out for now:

- [ ] #16699

### Checklist

- [x] All post commit: https://github.com/tenstorrent/tt-metal/actions/runs/12771284956
- [x] TG nightly: https://github.com/tenstorrent/tt-metal/actions/runs/12771306345
- [x] TG post commit: https://github.com/tenstorrent/tt-metal/actions/runs/12771296063
- [x] T3K post commit and nightly: https://github.com/tenstorrent/tt-metal/actions/runs/12756885713
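For reference, a minimal sketch of running the new TG nightly all-gather tests locally. This invocation is illustrative and not part of the commit; it assumes a 32-chip TG system and that the repository's `conftest.py` provides the `mesh_device` fixture the tests request.

```python
# Hypothetical local run of the test file added in this commit.
import pytest

pytest.main([
    "tests/ttnn/unit_tests/operations/ccl/test_all_gather_async_TG_nightly.py",
    "-x",                 # stop on first failure
    "-k", "all_gather",   # all three tests added here match this expression
])
```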
1 parent a0cf894 · commit 095d101
Showing 8 changed files with 502 additions and 20 deletions.
@@ -0,0 +1 @@
../../../ttnn/unit_tests/operations/ccl/test_all_gather_async_TG_nightly.py
@@ -0,0 +1 @@
../../../ttnn/unit_tests/operations/ccl/test_reduce_scatter_async_TG_nightly.py
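The two one-line additions above appear to be relative symlinks: git records a symlink as a file whose only content is the link target, so the TG nightly directory can reuse the unit-test implementations directly. A minimal sketch of how such a link could be created (the destination path below is an assumption, not taken from this commit):

```python
# Illustrative only: point a nightly test entry at the unit-test implementation.
import os

os.symlink(
    "../../../ttnn/unit_tests/operations/ccl/test_all_gather_async_TG_nightly.py",  # link target from the diff
    "tests/nightly/tg/ccl/test_all_gather_async_TG_nightly.py",  # hypothetical link location
)
```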
330 changes: 330 additions & 0 deletions
tests/ttnn/unit_tests/operations/ccl/test_all_gather_async_TG_nightly.py
@@ -0,0 +1,330 @@
# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
from loguru import logger
import ttnn
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc
from models.utility_functions import skip_for_grayskull
from tests.ttnn.unit_tests.operations.ccl.test_ccl_common import (
    create_and_load_sub_device_manager_with_fabric_interface,
    teardown_fabric_interface,
    create_global_semaphore_with_same_address,
)

from tests.ttnn.unit_tests.operations.ccl.test_all_gather_TG_post_commit import (
    run_line_all_gather_on_TG_with_mesh_tensor_along_rows,
)

from tests.ttnn.unit_tests.operations.ccl.test_new_all_gather import (
    run_all_gather_impl,
)

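# This file adds three test groups for the async all-gather on a TG (8x4) mesh:
#   1. sharded line all-gather along mesh rows (cluster_axis=1),
#   2. sharded line all-gather along mesh columns (cluster_axis=0),
#   3. interleaved line all-gather along mesh columns (DRAM and L1 buffers).
# Multi-link configurations are disabled pending issue #16699.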
# Enumerate the post-commit cases explicitly
@skip_for_grayskull("Requires eth connected devices to run")
@pytest.mark.parametrize(
    "num_devices, num_links",
    [(4, 1)],
    # [(4, 3)], Multi-links fails https://github.com/tenstorrent/tt-metal/issues/16699
)
@pytest.mark.parametrize(
    "input_dtype",
    [
        ttnn.bfloat16,
        ttnn.bfloat8_b,
    ],
)
@pytest.mark.parametrize("shard_grid_orientation", [ttnn.ShardOrientation.ROW_MAJOR])
@pytest.mark.parametrize(
    "tensor_mem_layout, per_chip_output_shape, dim, input_shard_shape, shard_grid, layout",
    (
        # Llama
        (
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            (1, 1, 32, 1024 * 4),
            3,
            (32, 32),
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 3))}),
            ttnn.TILE_LAYOUT,
        ),
        (
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            (4, 1, 32, 1280),
            0,
            (32, 32),
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 4))}),
            ttnn.TILE_LAYOUT,
        ),
    ),
)
@pytest.mark.parametrize("replication_factor", [8])
@pytest.mark.parametrize("enable_async", [True])
@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True)
def test_line_all_gather_sharded_on_TG_rows_post_commit(
    mesh_device,
    num_devices,
    per_chip_output_shape,
    input_shard_shape,
    shard_grid,
    shard_grid_orientation,
    tensor_mem_layout,
    dim,
    num_links,
    input_dtype,
    layout,
    use_program_cache,
    function_level_defaults,
    enable_async,
    replication_factor,
    num_iters=1,
):
    if len(mesh_device.get_devices()) != 32:
        pytest.skip("Not TG!")
    if input_dtype == ttnn.bfloat16 and per_chip_output_shape == (1, 1, 32, 1024 * 4):
        pytest.skip("Skipped due to hang Issue #16699")
    input_shard_spec = ttnn.ShardSpec(
        shard_grid,
        input_shard_shape,
        shard_grid_orientation,
    )
    run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
        mesh_device,
        num_devices,
        per_chip_output_shape,
        tensor_mem_layout,
        dim,
        num_links,
        input_dtype,
        layout,
        ttnn.BufferType.L1,
        use_program_cache,
        function_level_defaults,
        enable_async=enable_async,
        input_shard_spec=input_shard_spec,
        num_iters=num_iters,
        num_all_gather_instances=replication_factor,
        cluster_axis=1,
        use_all_gather_async=True,
        enable_persistent_fabric=True,
        create_persistent_fabric=True,
        teardown_persistent_fabric=True,
    )

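# Column-axis variant: sweeps width- and height-sharded inputs across gather
# dims 0-3 with eight devices per line (replication_factor=4 lines).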
# Enumerate the post-commit cases explicitly
@skip_for_grayskull("Requires eth connected devices to run")
@pytest.mark.parametrize(
    "num_devices, num_links",
    [
        (8, 1),
    ],
    # [(8, 4), (8, 3), (8, 2)], Multi-links fails https://github.com/tenstorrent/tt-metal/issues/16699
)
@pytest.mark.parametrize(
    "input_dtype",
    [
        ttnn.bfloat16,
        ttnn.bfloat8_b,
    ],
)
@pytest.mark.parametrize("shard_grid_orientation", [ttnn.ShardOrientation.ROW_MAJOR])
@pytest.mark.parametrize(
    "tensor_mem_layout, input_shape, dim, input_shard_shape, shard_grid, layout",
    (
        (
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            (8, 1, 32, 2048),
            0,
            (32, 64),
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 3))}),
            ttnn.TILE_LAYOUT,
        ),
        (
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            (1, 8, 32, 2048),
            1,
            (32, 64),
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 3))}),
            ttnn.TILE_LAYOUT,
        ),
        (
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            (1, 1, 256, 2048),
            2,
            (32, 64),
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 3))}),
            ttnn.TILE_LAYOUT,
        ),
        (
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            (1, 1, 32, 16384),
            3,
            (32, 64),
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 3))}),
            ttnn.TILE_LAYOUT,
        ),
        (
            ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
            (8, 1, 2048, 32),
            0,
            (64, 32),
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 3))}),
            ttnn.TILE_LAYOUT,
        ),
        (
            ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
            (1, 8, 2048, 32),
            1,
            (64, 32),
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 3))}),
            ttnn.TILE_LAYOUT,
        ),
        (
            ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
            (1, 1, 16384, 32),
            2,
            (64, 32),
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 3))}),
            ttnn.TILE_LAYOUT,
        ),
        (
            ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
            (1, 1, 2048, 256),
            3,
            (64, 32),
            ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 3))}),
            ttnn.TILE_LAYOUT,
        ),
    ),
)
@pytest.mark.parametrize("replication_factor", [4])
@pytest.mark.parametrize("enable_async", [True])
@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True)
def test_line_all_gather_sharded_on_TG_cols_post_commit(
    mesh_device,
    num_devices,
    input_shape,
    input_shard_shape,
    shard_grid,
    shard_grid_orientation,
    tensor_mem_layout,
    dim,
    num_links,
    input_dtype,
    layout,
    use_program_cache,
    function_level_defaults,
    enable_async,
    replication_factor,
    num_iters=1,
):
    if len(mesh_device.get_devices()) != 32:
        pytest.skip("Not TG!")
    if input_dtype == ttnn.bfloat16 and input_shape == (1, 1, 256, 2048):
        pytest.skip("Skipped due to hang Issue #16699")
    input_shard_spec = ttnn.ShardSpec(
        shard_grid,
        input_shard_shape,
        shard_grid_orientation,
    )

    run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
        mesh_device,
        num_devices,
        input_shape,
        tensor_mem_layout,
        dim,
        num_links,
        input_dtype,
        layout,
        ttnn.BufferType.L1,
        use_program_cache,
        function_level_defaults,
        enable_async=enable_async,
        num_iters=num_iters,
        input_shard_spec=input_shard_spec,
        num_all_gather_instances=replication_factor,
        cluster_axis=0,
        use_all_gather_async=True,
        enable_persistent_fabric=True,
        create_persistent_fabric=True,
        teardown_persistent_fabric=True,
    )

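# Interleaved nightly variant: the same column-axis line all-gather, but with
# interleaved tensors in DRAM or L1 rather than sharded L1 inputs.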
# Enumerate the post-commit cases explicitly
@skip_for_grayskull("Requires eth connected devices to run")
@pytest.mark.parametrize(
    "num_devices, num_links, per_chip_output_shape, dim, layout",
    [
        (8, 1, [1, 8, 32, 1280], 1, ttnn.TILE_LAYOUT),
        (8, 1, [8, 1, 32, 1280], 0, ttnn.TILE_LAYOUT),
        (8, 1, [1, 8, 32, 2048], 1, ttnn.TILE_LAYOUT),
        (8, 1, [1, 8, 32, 2304], 1, ttnn.TILE_LAYOUT),
        (8, 1, [1, 8, 32, 4096], 1, ttnn.TILE_LAYOUT),
        # multi-links fails: https://github.com/tenstorrent/tt-metal/issues/16699
        # (8, 4, [1, 8, 32, 1280], 1, ttnn.TILE_LAYOUT),
        # (8, 4, [8, 1, 32, 1280], 0, ttnn.TILE_LAYOUT),
        # (8, 4, [1, 8, 32, 2048], 1, ttnn.TILE_LAYOUT),
        # (8, 4, [1, 8, 32, 2304], 1, ttnn.TILE_LAYOUT),
        # (8, 4, [1, 8, 32, 4096], 1, ttnn.TILE_LAYOUT),
    ],
)
@pytest.mark.parametrize(
    "input_dtype",
    [
        ttnn.bfloat16,
        ttnn.bfloat8_b,
    ],
)
@pytest.mark.parametrize(
    "buffer_type",
    [
        ttnn.BufferType.DRAM,
        ttnn.BufferType.L1,
    ],
)
@pytest.mark.parametrize("replication_factor", [4])
@pytest.mark.parametrize("enable_async", [True])
@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True)
def test_line_all_gather_on_TG_cols_nightly(
    mesh_device,
    num_devices,
    per_chip_output_shape,
    dim,
    num_links,
    input_dtype,
    layout,
    buffer_type,
    use_program_cache,
    function_level_defaults,
    enable_async,
    replication_factor,
    num_iters=1,
):
    if len(mesh_device.get_devices()) != 32:
        pytest.skip("Not TG!")
    run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
        mesh_device,
        num_devices,
        per_chip_output_shape,
        ttnn.TensorMemoryLayout.INTERLEAVED,
        dim,
        num_links,
        input_dtype,
        layout,
        buffer_type,
        use_program_cache,
        function_level_defaults,
        enable_async=enable_async,
        num_iters=num_iters,
        num_all_gather_instances=replication_factor,
        cluster_axis=0,
        use_all_gather_async=True,
        enable_persistent_fabric=True,
        create_persistent_fabric=True,
        teardown_persistent_fabric=True,
    )