Skip to content

Commit

Permalink
Re-enable UNet Shallow trace+2CQ test case
Browse files Browse the repository at this point in the history
  • Loading branch information
esmalTT committed Jan 18, 2025
1 parent bc0575f commit a48a5de
Showing 1 changed file with 12 additions and 19 deletions.
31 changes: 12 additions & 19 deletions models/experimental/functional_unet/tests/test_unet_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,15 +353,14 @@ def test_unet_trace_2cq_multi_device(
ttnn.release_trace(mesh_device, tid)


@pytest.mark.skip("L1-to-DRAM resharding is currently broken (#16741)")
@skip_for_grayskull("UNet not currently supported on GS")
@pytest.mark.models_performance_bare_metal
@pytest.mark.parametrize(
"device_params", [{"l1_small_size": 68864, "trace_region_size": 423936, "num_command_queues": 2}], indirect=True
)
@pytest.mark.parametrize(
"batch, groups, iterations",
((1, 2, 32),),
((1, 2, 128),),
)
def test_unet_trace_2cq_same_io(
batch: int,
Expand Down Expand Up @@ -395,12 +394,12 @@ def test_unet_trace_2cq_same_io(
],
ttnn.ShardOrientation.ROW_MAJOR,
)
dram_memory_config = ttnn.MemoryConfig(
input_dram_memory_config = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.DRAM, dram_shard_spec
)

input_tensor = ttnn.allocate_tensor_on_device(
ttnn_input.shape, ttnn.bfloat16, ttnn.ROW_MAJOR_LAYOUT, device, dram_memory_config
ttnn_input.shape, ttnn.bfloat16, ttnn.ROW_MAJOR_LAYOUT, device, input_dram_memory_config
)
ttnn.record_event(0, op_event)
ttnn.record_event(1, read_event)
Expand All @@ -414,10 +413,11 @@ def test_unet_trace_2cq_same_io(
l1_input_tensor = ttnn.reshard(input_tensor, ttnn_model.input_sharded_memory_config)
ttnn.record_event(0, op_event)
output_tensor = ttnn_model(l1_input_tensor, move_input_tensor_to_device=False)

output_dram_shard_spec = ttnn.ShardSpec(
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 0))}),
[output_tensor.shape[-2], output_tensor.shape[-1] // 8],
ttnn.CoreRangeSet(
{ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(dram_grid_size.x - 1, dram_grid_size.y - 1))}
),
[output_tensor.shape[-2], output_tensor.shape[-1] // dram_grid_size.x],
ttnn.ShardOrientation.ROW_MAJOR,
)
output_dram_memory_config = ttnn.MemoryConfig(
Expand All @@ -426,11 +426,6 @@ def test_unet_trace_2cq_same_io(
dram_output_tensor = ttnn.reshard(output_tensor, output_dram_memory_config)
logger.info(f"Done compile run")

logger.info(f"Running sanity check on compile-run output against reference model output")
B, C, H, W = torch_output_tensor.shape
ttnn_output_tensor = ttnn.to_torch(dram_output_tensor).reshape(B, C, H, W)
verify_with_pcc(torch_output_tensor, ttnn_output_tensor, UNET_FULL_MODEL_PCC)

logger.info(f"Capturing trace")
ttnn.wait_for_event(1, op_event)
ttnn.copy_host_to_device_tensor(ttnn_input, input_tensor, cq_id=1)
Expand All @@ -454,7 +449,7 @@ def test_unet_trace_2cq_same_io(
)
assert input_trace_addr == l1_input_tensor.buffer_address()
ttnn.end_trace_capture(device, tid, cq_id=0)
dram_output_tensor = ttnn.reshard(output_tensor, dram_memory_config, dram_output_tensor)
dram_output_tensor = ttnn.reshard(output_tensor, output_dram_memory_config, dram_output_tensor)
ttnn.synchronize_device(device)

outputs = []
Expand All @@ -468,7 +463,7 @@ def test_unet_trace_2cq_same_io(
ttnn.record_event(0, op_event)
ttnn.execute_trace(device, tid, cq_id=0, blocking=False)
ttnn.wait_for_event(0, read_event)
dram_output_tensor = ttnn.reshard(output_tensor, dram_memory_config, dram_output_tensor)
dram_output_tensor = ttnn.reshard(output_tensor, output_dram_memory_config, dram_output_tensor)
ttnn.record_event(0, model_event)
ttnn.wait_for_event(1, op_event)
ttnn.copy_host_to_device_tensor(ttnn_input, input_tensor, cq_id=1)
Expand All @@ -481,18 +476,16 @@ def test_unet_trace_2cq_same_io(
ttnn.record_event(0, op_event)
ttnn.execute_trace(device, tid, cq_id=0, blocking=False)
ttnn.wait_for_event(0, read_event)
dram_output_tensor = ttnn.reshard(output_tensor, dram_memory_config, dram_output_tensor)
dram_output_tensor = ttnn.reshard(output_tensor, output_dram_memory_config, dram_output_tensor)
ttnn.record_event(0, model_event)
ttnn.wait_for_event(1, model_event)
outputs.append(dram_output_tensor.cpu(blocking=False, cq_id=1))
ttnn.synchronize_device(device)
end = time.time()
logger.info(f"Average model time={1000.0 * (end-start) / iterations : .2f} ms")
logger.info(f"Average model performance={iterations * batch / (end-start) : .2f} fps")
logger.info(f"Average model performance={iterations * groups * batch / (end-start) : .2f} fps")

logger.info(f"Running sanity check against reference model output")
B, C, H, W = torch_output_tensor.shape
ttnn_output_tensor = ttnn.to_torch(outputs[-1]).reshape(B, C, H, W)
verify_with_pcc(torch_output_tensor, ttnn_output_tensor, UNET_FULL_MODEL_PCC)

verify_with_pcc(torch_output_tensor, ttnn.to_torch(outputs[-1]).reshape(B, C, H, W), pcc=UNET_FULL_MODEL_PCC)
ttnn.release_trace(device, tid)

0 comments on commit a48a5de

Please sign in to comment.