Enable Domain Parallelism with ShardTensor #784

Open
wants to merge 43 commits into base: main

Conversation

coreyjadams (Collaborator) commented Feb 6, 2025

Modulus Pull Request

Description

This PR adds new capabilities to Modulus:

  • ShardTensor is an extension of PyTorch's DTensor that enables uneven sharding of tensors across DeviceMesh objects. Some logical sharding constraints remain, but this allows more dynamic and flexible operation on distributed input data, especially in cases where the input and output data shapes differ.
  • ShardTensor also enables an ecosystem of operation extensions. Two major ones are included in this PR: convolutions (1D/2D/3D) and neighborhood attention. When the right components of Modulus are imported, these operations, when performed on sharded tensors, will automatically compute halo regions and perform the data transfers needed to produce results consistent with single-device outputs (a minimal usage sketch follows this list).
    • For small inputs this provides no benefit, but for extremely large inputs it is a powerful way to scale training.
  • The Modulus documentation now includes an API reference for ShardTensor, as well as an example of combining multiple levels of parallelism by using ShardTensor together with PyTorch FSDP.
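
To make the intended workflow concrete, here is a minimal sketch of how sharded operation could look. The Modulus import path and the `ShardTensor.from_local` call are assumptions for illustration only and are left commented out; the rest is standard PyTorch distributed setup. Consult the API reference added in this PR for the actual names.

```python
# Minimal sketch: shard a large input along one spatial dimension and run a
# convolution whose halo exchange is handled by the patched ops described above.
# NOTE: the Modulus import and the ShardTensor.from_local signature are assumed
# for illustration, not verified against this PR.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard  # DTensor placement type (torch >= 2.4)

# Hypothetical import from this PR:
# from modulus.distributed import ShardTensor

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("spatial",))

# Each rank holds its own (possibly uneven) slice of the global image along dim=2.
local_chunk = torch.randn(1, 3, 512, 4096, device="cuda")

# Assumed constructor: wrap the local slice as one shard of the global tensor.
# sharded = ShardTensor.from_local(local_chunk, device_mesh=mesh, placements=[Shard(2)])

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1).cuda()
# With the patched convolution imported, this call would exchange halo rows with
# neighboring ranks before convolving, so the result matches a single-device run.
# out = conv(sharded)
```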

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

Adds a dependency on `wrapt` for monkey-patching operations on sharded inputs.

coreyjadams and others added 30 commits December 17, 2024 08:52
…ieces are WIP but this has basic functionality supported for creation and forward usage.
…t of the ops have been validated, all that remains is to wrap the na2d function call to ensure it will dispatch properly.
…s also a minor bug in the backward pass that got more pronounced with smaller data: grad inputs were failing to properly collect haloed gradients and add them on the edges. Now fixed.
…gnificant overhead.

I'm implementing here an option to switch to peer-to-peer message passing, since it might benefit from stream utilization in layers like natten.na2d.

It's a developer choice currently, not a user choice.
…gnificant functionality changes in this commit.
Add `scatter_tensor` function to enable an easier transition to ShardTensor. This function allows users to maintain data pipelines (on one rank) and easily scatter that data to a domain mesh. This also adjusts the ShardTensor mechanism for tracking shard info to use a dict instead of a list of tuples.
No real code changes applied here.
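
As a rough illustration of the `scatter_tensor` pattern described in the commit above: the function name comes from the commit message, but its signature here is an assumption and is left commented out; it is not the verified API from this PR.

```python
# Sketch of the pipeline pattern from the scatter_tensor commit: rank 0 keeps the
# existing data pipeline and scatters each batch across a domain mesh.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard

# Hypothetical import from this PR:
# from modulus.distributed import scatter_tensor

dist.init_process_group(backend="nccl")
mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("spatial",))

if dist.get_rank() == 0:
    # Only rank 0 runs the (pre-existing, unmodified) data pipeline.
    global_batch = torch.randn(1, 3, 4096, 4096, device="cuda")
else:
    global_batch = None

# Assumed call: send the rank-0 tensor to the mesh, returning a ShardTensor that
# is split along dim=2 on every participating rank.
# sharded_batch = scatter_tensor(global_batch, global_src=0, mesh=mesh, placements=[Shard(2)])
```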
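
An earlier commit above mentions switching the halo exchange to peer-to-peer message passing. A generic sketch of that pattern, using batched point-to-point ops from `torch.distributed`, is shown below; it is illustrative only and is not the implementation added by this PR.

```python
# Generic sketch of a peer-to-peer halo exchange along one sharded dimension.
import torch
import torch.distributed as dist

def exchange_halos(local_chunk: torch.Tensor, halo_width: int, dim: int = 2) -> torch.Tensor:
    """Return local_chunk padded with halo slices received from neighboring ranks."""
    rank, world = dist.get_rank(), dist.get_world_size()
    left, right = rank - 1, rank + 1

    ops, recv_left, recv_right = [], None, None
    if left >= 0:
        send = local_chunk.narrow(dim, 0, halo_width).contiguous()
        recv_left = torch.empty_like(send)
        ops += [dist.P2POp(dist.isend, send, left), dist.P2POp(dist.irecv, recv_left, left)]
    if right < world:
        send = local_chunk.narrow(dim, local_chunk.size(dim) - halo_width, halo_width).contiguous()
        recv_right = torch.empty_like(send)
        ops += [dist.P2POp(dist.isend, send, right), dist.P2POp(dist.irecv, recv_right, right)]

    # Batching the point-to-point ops lets the backend schedule them together,
    # which is where the potential stream-utilization benefit comes from.
    if ops:
        for req in dist.batch_isend_irecv(ops):
            req.wait()

    pieces = [t for t in (recv_left, local_chunk, recv_right) if t is not None]
    return torch.cat(pieces, dim=dim)
```
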
pzharrington (Collaborator) commented:

/blossom-ci

pzharrington (Collaborator) commented:

Overall, I think this is looking good, nice work! I started with the documentation and then focused on unit tests to see overall functionality, as well as changes to the DistributedManager to see device mesh functionality and the main changes to what existed previously. Also looked at the ShardTensor definition, halo collectives and conv/natten patches, but didn’t spend much time on the other backend stuff.

Aside from the minor comments added, my main flag is to make the unit testing more complete, but I don't think that should necessarily block merging. In particular, for the ops we support (conv or neighborhood attention, currently), I think we should add unit tests for correctness of the forward and backward passes compared to a non-sharded baseline (within some numerical tolerance, especially in the context of the neighborhood attention numerics we discovered).

Update tutorial based on feedback from @pzharrington
Remove wildcard import.
coreyjadams (Collaborator, Author) commented:

Thanks for the review @pzharrington! I agree with you on the testing. Here are my thoughts:

  • Modulus should support unit tests of the basic ShardTensor functionality, "baked in". Most of those are there, but I have some tests still in development locally regarding gradient propagation through sharded tensors. I would like to get them in, but didn't want to hold up the review.
  • Tests on numerical accuracy are probably too much for CI/CD and unit testing. I am working on exactly the tooling you highlighted, but it is currently run and analyzed manually. It aims to support numerical checking (and performance benchmarking!) of all the operations we patch like this in Modulus, and to extend to more comprehensive layers and even full models. I'll get the repo up on GitLab to iron out the kinks and keep the numerical checking untied from the Modulus release. FYI, apart from the issues we found in na2d for long sequence lengths, all patched operations are passing numerical checks (a sketch of such a check follows below).
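
For reference, the kind of forward/backward correctness check being discussed could look roughly like the sketch below. The `gather_to_full` helper is a placeholder for whatever utility reassembles a sharded tensor on one rank; it and the calling convention are assumptions, not code from this PR.

```python
# Sketch of checking a patched op against a single-device baseline, within a
# numerical tolerance, for both the forward and backward pass.
# gather_to_full is a placeholder supplied by the caller (an assumption, not an
# API from this PR); sharded_input is assumed to already require gradients.
import torch

def check_against_baseline(op, full_input, sharded_input, gather_to_full,
                           rtol=1e-4, atol=1e-5):
    full_input = full_input.detach().clone().requires_grad_(True)

    # Single-device baseline: forward, then a simple scalar backward.
    baseline_out = op(full_input)
    baseline_out.sum().backward()

    # Sharded run of the same op.
    sharded_out = op(sharded_input)
    torch.testing.assert_close(gather_to_full(sharded_out), baseline_out,
                               rtol=rtol, atol=atol)

    sharded_out.sum().backward()
    torch.testing.assert_close(gather_to_full(sharded_input.grad), full_input.grad,
                               rtol=rtol, atol=atol)
```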

ktangsali (Collaborator) commented:

/multi-gpu-ci
