diff --git a/models/experimental/functional_unet/tests/test_unet_trace.py b/models/experimental/functional_unet/tests/test_unet_trace.py index dd87aa66063..e308f350e83 100644 --- a/models/experimental/functional_unet/tests/test_unet_trace.py +++ b/models/experimental/functional_unet/tests/test_unet_trace.py @@ -353,7 +353,6 @@ def test_unet_trace_2cq_multi_device( ttnn.release_trace(mesh_device, tid) -@pytest.mark.skip("L1-to-DRAM resharding is currently broken (#16741)") @skip_for_grayskull("UNet not currently supported on GS") @pytest.mark.models_performance_bare_metal @pytest.mark.parametrize( @@ -361,7 +360,7 @@ def test_unet_trace_2cq_multi_device( ) @pytest.mark.parametrize( "batch, groups, iterations", - ((1, 2, 32),), + ((1, 2, 128),), ) def test_unet_trace_2cq_same_io( batch: int, @@ -395,12 +394,12 @@ def test_unet_trace_2cq_same_io( ], ttnn.ShardOrientation.ROW_MAJOR, ) - dram_memory_config = ttnn.MemoryConfig( + input_dram_memory_config = ttnn.MemoryConfig( ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.DRAM, dram_shard_spec ) input_tensor = ttnn.allocate_tensor_on_device( - ttnn_input.shape, ttnn.bfloat16, ttnn.ROW_MAJOR_LAYOUT, device, dram_memory_config + ttnn_input.shape, ttnn.bfloat16, ttnn.ROW_MAJOR_LAYOUT, device, input_dram_memory_config ) ttnn.record_event(0, op_event) ttnn.record_event(1, read_event) @@ -414,10 +413,11 @@ def test_unet_trace_2cq_same_io( l1_input_tensor = ttnn.reshard(input_tensor, ttnn_model.input_sharded_memory_config) ttnn.record_event(0, op_event) output_tensor = ttnn_model(l1_input_tensor, move_input_tensor_to_device=False) - output_dram_shard_spec = ttnn.ShardSpec( - ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 0))}), - [output_tensor.shape[-2], output_tensor.shape[-1] // 8], + ttnn.CoreRangeSet( + {ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(dram_grid_size.x - 1, dram_grid_size.y - 1))} + ), + [output_tensor.shape[-2], output_tensor.shape[-1] // dram_grid_size.x], ttnn.ShardOrientation.ROW_MAJOR, ) output_dram_memory_config = ttnn.MemoryConfig( @@ -426,11 +426,6 @@ def test_unet_trace_2cq_same_io( dram_output_tensor = ttnn.reshard(output_tensor, output_dram_memory_config) logger.info(f"Done compile run") - logger.info(f"Running sanity check on compile-run output against reference model output") - B, C, H, W = torch_output_tensor.shape - ttnn_output_tensor = ttnn.to_torch(dram_output_tensor).reshape(B, C, H, W) - verify_with_pcc(torch_output_tensor, ttnn_output_tensor, UNET_FULL_MODEL_PCC) - logger.info(f"Capturing trace") ttnn.wait_for_event(1, op_event) ttnn.copy_host_to_device_tensor(ttnn_input, input_tensor, cq_id=1) @@ -454,7 +449,7 @@ def test_unet_trace_2cq_same_io( ) assert input_trace_addr == l1_input_tensor.buffer_address() ttnn.end_trace_capture(device, tid, cq_id=0) - dram_output_tensor = ttnn.reshard(output_tensor, dram_memory_config, dram_output_tensor) + dram_output_tensor = ttnn.reshard(output_tensor, output_dram_memory_config, dram_output_tensor) ttnn.synchronize_device(device) outputs = [] @@ -468,7 +463,7 @@ def test_unet_trace_2cq_same_io( ttnn.record_event(0, op_event) ttnn.execute_trace(device, tid, cq_id=0, blocking=False) ttnn.wait_for_event(0, read_event) - dram_output_tensor = ttnn.reshard(output_tensor, dram_memory_config, dram_output_tensor) + dram_output_tensor = ttnn.reshard(output_tensor, output_dram_memory_config, dram_output_tensor) ttnn.record_event(0, model_event) ttnn.wait_for_event(1, op_event) ttnn.copy_host_to_device_tensor(ttnn_input, input_tensor, cq_id=1) @@ -481,18 +476,16 @@ def test_unet_trace_2cq_same_io( ttnn.record_event(0, op_event) ttnn.execute_trace(device, tid, cq_id=0, blocking=False) ttnn.wait_for_event(0, read_event) - dram_output_tensor = ttnn.reshard(output_tensor, dram_memory_config, dram_output_tensor) + dram_output_tensor = ttnn.reshard(output_tensor, output_dram_memory_config, dram_output_tensor) ttnn.record_event(0, model_event) ttnn.wait_for_event(1, model_event) outputs.append(dram_output_tensor.cpu(blocking=False, cq_id=1)) ttnn.synchronize_device(device) end = time.time() logger.info(f"Average model time={1000.0 * (end-start) / iterations : .2f} ms") - logger.info(f"Average model performance={iterations * batch / (end-start) : .2f} fps") + logger.info(f"Average model performance={iterations * groups * batch / (end-start) : .2f} fps") logger.info(f"Running sanity check against reference model output") B, C, H, W = torch_output_tensor.shape - ttnn_output_tensor = ttnn.to_torch(outputs[-1]).reshape(B, C, H, W) - verify_with_pcc(torch_output_tensor, ttnn_output_tensor, UNET_FULL_MODEL_PCC) - + verify_with_pcc(torch_output_tensor, ttnn.to_torch(outputs[-1]).reshape(B, C, H, W), pcc=UNET_FULL_MODEL_PCC) ttnn.release_trace(device, tid)