
Perform padding for multi-device tensors, to allow for homogenous TensorSpec across all devices #17476

Open
omilyutin-tt opened this issue Feb 1, 2025 · 7 comments

@omilyutin-tt
Contributor

Existing TTNN infra allows for uneven multi-device tensor sharding. For example:

>>> torch_tensor = torch.ones(1, 13, 32, 32)
>>> mesh_tensor1 = ttnn.from_torch(torch_tensor, mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=1), layout=ttnn.TILE_LAYOUT, device=mesh_device)
>>> mesh_tensor2 = ttnn.from_torch(torch_tensor, mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=1), layout=ttnn.TILE_LAYOUT, device=mesh_device)
>>> output = ttnn.add(mesh_tensor1, mesh_tensor2)
>>> output
(truncated)
device_id:1
ttnn.Tensor([[[[ 2.00000,  2.00000,  ...,  2.00000,  2.00000],
               [ 2.00000,  2.00000,  ...,  2.00000,  2.00000],
               ...,
               [ 2.00000,  2.00000,  ...,  2.00000,  2.00000],
               [ 2.00000,  2.00000,  ...,  2.00000,  2.00000]],

              [[ 2.00000,  2.00000,  ...,  2.00000,  2.00000],
               [ 2.00000,  2.00000,  ...,  2.00000,  2.00000],
               ...,
               [ 2.00000,  2.00000,  ...,  2.00000,  2.00000],
               [ 2.00000,  2.00000,  ...,  2.00000,  2.00000]]]], shape=Shape([1, 2, 32, 32]), dtype=DataType::FLOAT32, layout=Layout::TILE)
device_id:2
ttnn.Tensor([[[[ 2.00000,  2.00000,  ...,  2.00000,  2.00000],
               [ 2.00000,  2.00000,  ...,  2.00000,  2.00000],
               ...,
               [ 2.00000,  2.00000,  ...,  2.00000,  2.00000],
               [ 2.00000,  2.00000,  ...,  2.00000,  2.00000]]]], shape=Shape([1, 1, 32, 32]), dtype=DataType::FLOAT32, layout=Layout::TILE)

That is, sharding a tensor of shape (1, 13, 32, 32) across 8 devices results in the last shard being 2x smaller, with shape (1, 1, 32, 32) instead of (1, 2, 32, 32). The addition operation ttnn.add, and the rest of our op dispatch infra, works just fine.

This is a valid use case, but it won't be efficiently supported by the new distributed infrastructure. To enable homogeneous workloads with the same runtime args across devices, we should pad the tensors so that the TensorSpec remains the same across all devices.
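As a rough host-side illustration of the idea (plain torch only, not the actual ttnn sharding path; the ceil-to-multiple math is just for exposition), padding the shard dim up to a multiple of the device count makes every shard, and hence every per-device TensorSpec, identical:

import torch
import torch.nn.functional as F

num_devices = 8
torch_tensor = torch.ones(1, 13, 32, 32)

# Pad the shard dim (dim 1) up to the next multiple of the device count
# (13 -> 16), so every device would receive an identical (1, 2, 32, 32) shard.
size = torch_tensor.shape[1]
padded_size = -(-size // num_devices) * num_devices  # ceil to a multiple of 8
padded = F.pad(torch_tensor, (0, 0, 0, 0, 0, padded_size - size))

shards = torch.chunk(padded, num_devices, dim=1)
print([tuple(s.shape) for s in shards])  # eight uniform (1, 2, 32, 32) shards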

@omilyutin-tt omilyutin-tt self-assigned this Feb 1, 2025
@omilyutin-tt
Contributor Author

As a first step, I am running a check to see how many of our existing tests and models rely on uneven multi-device sharding.

@omilyutin-tt omilyutin-tt added the P0 label Feb 1, 2025
@omilyutin-tt
Contributor Author

Setting as P0 for now, during the investigation phase. It is likely this won't be a blocker, and we will find a workaround that trades off perf for generality.

@omilyutin-tt omilyutin-tt changed the title TTNN should pad multi-device tensors Perform padding for multi-device tensors, to allow for homogenous TensorSpec across all devices Feb 1, 2025
@omilyutin-tt
Contributor Author

I ran the T3K and TG/TGG test suites with the assertion on. There is some noise, but it seems that only the falcon 7b demo on T3K runs into it.

FYI @cfjchu

@cfjchu
Collaborator

cfjchu commented Feb 3, 2025

Falcon7b is known to have uneven shapes, so that makes sense. Nice to know this was the only model that's affected. Any notable performance regressions?

@omilyutin-tt
Contributor Author

Any notable performance regressions?

I just added a single TT_FATAL check, so this shouldn't be affecting perf.

@omilyutin-tt
Contributor Author

We had an offline conversation with @tt-asaigal @ayerofieiev-tt @jvegaTT @TT-BrianLiu @cfjchu. Padding is challenging, as it requires special handling, e.g. in reduction operations. Instead, we would like to invest in supporting heterogeneous runtime args - this approach is similar to how single-device sharding is handled. Until that is fully supported by TT distributed, we can assume the same tensor spec across devices.
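To make the reduction concern concrete, here is a minimal plain-torch illustration (not ttnn code, and zero-padding is an assumed padding scheme): sums over a zero-padded shard dim come out unchanged, but means silently divide by the padded extent, which is why padded ops need special handling or masking.

import torch
import torch.nn.functional as F

x = torch.rand(1, 13, 32, 32)
x_padded = F.pad(x, (0, 0, 0, 0, 0, 3))  # pad dim 1 from 13 to 16 with zeros

# Sum over the padded dim is unaffected (the padding only adds zeros)...
print(torch.allclose(x.sum(dim=1), x_padded.sum(dim=1)))   # True
# ...but mean divides by 16 instead of 13, so the result is silently wrong
# unless the op masks out or compensates for the padded region.
print(torch.allclose(x.mean(dim=1), x_padded.mean(dim=1)))  # False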

Please add any missing context!

@omilyutin-tt omilyutin-tt added P1 and removed P0 labels Feb 3, 2025
@tt-asaigal
Contributor

Thanks for flagging this @omilyutin-tt, this is important. I discussed with @cfjchu, and I think we should consider the data-movement and compute logic separately.

Here are the options we discussed:

1. Pad the input + Unpad the output (a rough sketch of this flow follows below)
   - from_torch works without any changes: per-device buffers are allocated based on a uniform shard size, and the amount of data written to the last device matches every other device.
   - We are guaranteed to have reductions + matmuls in pretty much every model, so MeshWorkload needs to be aware of the uneven shard size on the last device. Long term we want this to be done through heterogeneous runtime args (please see below for how something like this can be integrated with TTNN). Until this feature is fully supported, we can use a MeshWorkload with 2 programs (one that targets devices with uniform shards and another that targets the last device with an uneven shard).
   - to_torch works without any changes, and we essentially cut the invalid data from the last shard by unpadding at the model level.
2. Have TTNN Data-Movement Primitives support Uneven Shards
   - from_torch and the functions it calls detect uneven sharding: a MeshBuffer is allocated to accommodate the largest shard. write_shards is used for uniform shards, and the last shard goes through a single write_shard on a smaller buffer view.
   - Ops still need to be aware of uneven shards (the approach mentioned above can be used).
   - to_torch is essentially the reverse of from_torch in this case.

For either option, op infra changes are identical. I personally prefer option 2, since it can make models work out of the box as is.
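For reference, a minimal model-level sketch of what option 1 could look like, assuming the padding/unpadding wraps the existing from_torch/to_torch calls. The helper names (padded_shard, gather_and_unpad) are hypothetical, and the ConcatMeshToTensor composer usage is an assumption about how the gather side would be wired up:

import torch
import torch.nn.functional as F
import ttnn

# Hypothetical helpers for option 1: pad before sharding, slice the padding
# off after gathering. The ttnn calls mirror the snippet at the top of the
# issue; the mesh composer usage is assumed, not confirmed.
def padded_shard(torch_tensor, mesh_device, dim, num_devices):
    size = torch_tensor.shape[dim]
    padded_size = -(-size // num_devices) * num_devices  # ceil to a multiple
    pad = [0] * (2 * torch_tensor.dim())
    pad[2 * (torch_tensor.dim() - 1 - dim) + 1] = padded_size - size
    padded = F.pad(torch_tensor, tuple(pad))
    mesh_tensor = ttnn.from_torch(
        padded,
        mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=dim),
        layout=ttnn.TILE_LAYOUT,
        device=mesh_device,
    )
    return mesh_tensor, size

def gather_and_unpad(mesh_tensor, mesh_device, dim, original_size):
    gathered = ttnn.to_torch(
        mesh_tensor, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=dim)
    )
    return gathered.narrow(dim, 0, original_size)  # drop the padded region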

Long term, we would like all ops to be multi-device aware, i.e. replace CreateProgram with CreateMeshWorkload. In this case, the tensor spec can include multi-device sharding information, which can be used to program device-specific runtime args in the MeshWorkload created for the op. This is a fairly large and invasive change, since it touches every single op. I think it makes sense to defer this for now, and support this case through the infra itself.

cc. @davorchap
