Enable Domain Parallelism with ShardTensor #784
base: main
Conversation
…simple DDP sharding
…ieces are WIP but this has basic functionality supported for creation and forward usage.
…t of the ops have been validated, all that remains is to wrap the na2d function call to ensure it will dispatch properly.
…ng in unbind op rules.
….ops.aten.convolution.default.
…s also a minor bug in the backward pass that got more pronounced with smaller data: grad inputs were failing to properly collect haloed gradients and add them on the edges. Now fixed.
…gnificant overhead. I'm implementing an option here to switch to peer-to-peer message passing, since it might benefit from stream utilization in layers like natten.na2d. It's currently a developer choice, not a user choice.
…gnificant functionality changes in this commit.
Add `scatter_tensor` function to enable an easier transition to shard tensor (see the sketch after this commit list). This function allows users to maintain data pipelines (on one rank) and easily scatter that data to a domain mesh.
But also, this adjusts the shard tensor mechanism for tracking shard info to use a dict instead of a list of tuples.
No real code changes applied here.
There appears to be one corner case in redistribute to fix. TBD. Tests for grad propagation are coming.
FSDP and modulus ShardTensor
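A rough sketch of how the `scatter_tensor` workflow above might look in practice; the import locations, the argument names (`global_src`, `mesh`, `placements`), and the use of the `Shard` placement are assumptions for illustration, not the confirmed interface added in this PR.

```python
# Rough sketch of the `scatter_tensor` workflow described in the commit above.
# Import paths and the scatter_tensor signature are assumptions, not confirmed API.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard  # lives under torch.distributed._tensor in older torch releases

from modulus.distributed import DistributedManager
from modulus.distributed import scatter_tensor  # hypothetical import location

DistributedManager.initialize()
dm = DistributedManager()

# One flat "domain" mesh spanning all ranks.
mesh = init_device_mesh(dm.device.type, (dm.world_size,), mesh_dim_names=("domain",))

# The data pipeline runs only on rank 0; the other ranks receive shards.
full_batch = torch.randn(1, 3, 4096, 4096, device=dm.device) if dm.rank == 0 else None

# Scatter from rank 0 across the mesh, sharding along the height dimension (dim 2).
sharded_batch = scatter_tensor(full_batch, global_src=0, mesh=mesh, placements=(Shard(2),))
```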
/blossom-ci
Overall, I think this is looking good, nice work! I started with the documentation and then focused on unit tests to see overall functionality, as well as the code changes. Aside from the minor comments added, my main flag is to make the unit testing more complete, but I don't think that should necessarily block merging. In particular, for ops that we support (conv or nat, currently), I think we should add unit tests for correctness compared to a non-sharded baseline for forward and backward passes (within some numerical tolerance, especially in the context of the neighborhood attention numerics we discovered).
Thanks for the review @pzharrington! I agree with you on the testing. Here are my thoughts:
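As a rough illustration of the kind of baseline-comparison test suggested above, here is a sketch that checks a sharded forward and backward pass against a single-device run within a tolerance. The `full_tensor()` gather (inherited from DTensor) and the way the sharded input is constructed are assumptions, not the helpers actually used in this PR's test suite.

```python
# Sketch of a forward/backward correctness check against a non-sharded baseline.
import torch

def assert_sharded_matches_baseline(module, x_single, x_sharded, rtol=1e-4, atol=1e-4):
    # Single-device reference forward/backward.
    x_ref = x_single.detach().clone().requires_grad_(True)
    y_ref = module(x_ref)
    y_ref.sum().backward()

    # Same module applied to the sharded input.
    x_dist = x_sharded.detach().clone().requires_grad_(True)
    y_dist = module(x_dist)
    y_dist.sum().backward()

    # Gather the distributed results to full tensors before comparing.
    assert torch.allclose(y_dist.full_tensor(), y_ref, rtol=rtol, atol=atol)
    assert torch.allclose(x_dist.grad.full_tensor(), x_ref.grad, rtol=rtol, atol=atol)
```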
/multi-gpu-ci
Modulus Pull Request
Description
This PR adds new capabilities to Modulus:
- `ShardTensor` is an extension to pytorch `DTensor` that enables uneven sharding of tensors across `DeviceMesh` objects. While some logical sharding constraints remain, this allows more dynamic and flexible operation on distributed input data, especially in cases where the input data shape and output data shape differ.
- `ShardTensor` also enables an ecosystem of operation extensions. Two major ones are included in this PR: convolutions (1D/2D/3D) and neighborhood attention. When the right components of modulus are imported, these operations (when performed on sharded tensors) will automatically compute halo regions and perform data transfers to enable results consistent with single device outputs (see the sketch below).
- Documentation for `ShardTensor`, as well as an example of integrating multiple levels of parallelism by combining shard tensor and pytorch `FSDP`.
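To make the intended usage concrete, here is a minimal sketch, reusing the hypothetical `scatter_tensor` helper and import paths from the sketch earlier on this page (none of which are confirmed public API): a standard `nn.Conv2d` applied to a sharded input, with the halo exchange handled by the registered extensions rather than by user code.

```python
# Minimal end-to-end sketch of the usage described above. Import locations and
# the scatter_tensor signature are assumptions, not the confirmed public API.
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard

from modulus.distributed import DistributedManager, scatter_tensor  # hypothetical exports

DistributedManager.initialize()
dm = DistributedManager()
mesh = init_device_mesh(dm.device.type, (dm.world_size,), mesh_dim_names=("spatial",))

# Rank 0 owns the data pipeline; the image is sharded along its height (dim 2).
x = torch.randn(1, 8, 1024, 1024, device=dm.device) if dm.rank == 0 else None
x_sharded = scatter_tensor(x, global_src=0, mesh=mesh, placements=(Shard(2),))

conv = nn.Conv2d(8, 16, kernel_size=3, padding=1).to(dm.device)

# With the sharded-op extensions imported, halo regions at shard boundaries are
# exchanged automatically, so this should match a single-device forward pass
# up to numerical tolerance.
y_sharded = conv(x_sharded)
y_sharded.sum().backward()
```

Per the description above, the same sharded model can additionally be wrapped with pytorch FSDP to combine multiple levels of parallelism, as in the example the PR adds.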
Checklist
Dependencies
Adds a dependency on `wrapt` for monkey-patching operations on sharded inputs.
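For context on the `wrapt` dependency, here is a simplified sketch of the monkey-patching pattern it enables: wrapping a torch op so that sharded inputs can be routed to a halo-aware implementation. The dispatch logic and the `sharded_conv2d` helper are illustrative stand-ins, not the code in this PR; only the `wrapt.wrap_function_wrapper` call itself is standard `wrapt` usage.

```python
# Simplified illustration of the monkey-patching pattern `wrapt` enables.
import wrapt
import torch

from modulus.distributed import ShardTensor           # hypothetical import location
from modulus.distributed.conv import sharded_conv2d   # hypothetical halo-aware op


def _conv2d_wrapper(wrapped, instance, args, kwargs):
    # Route sharded inputs to the halo-aware implementation; otherwise fall
    # through to the original torch op untouched.
    if any(isinstance(a, ShardTensor) for a in args):
        return sharded_conv2d(*args, **kwargs)
    return wrapped(*args, **kwargs)


# Patch torch.nn.functional.conv2d in place at import time.
wrapt.wrap_function_wrapper(torch.nn.functional, "conv2d", _conv2d_wrapper)
```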