Description
This is an initial implementation of Distributed Low-Communication (DiLoCo) training as described in https://arxiv.org/abs/2311.08105. DiLoCo is an inner-outer, bi-level optimization strategy that significantly reduces bandwidth relative to data-parallel training by synchronizing the replicas less frequently.
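To make the inner-outer structure concrete, here is a minimal numpy sketch of the algorithm (not the implementation in this PR): each replica runs several local SGD steps on its own shard, and once per round the averaged parameter delta is applied by an outer momentum-SGD step. The model, hyperparameters, and function name are illustrative; the paper uses Nesterov momentum for the outer optimizer, while plain momentum is used here for brevity.

```python
import numpy as np

def diloco_fit(x, y, num_rounds=150, num_replicas=4, inner_steps=20,
               inner_lr=0.05, outer_lr=0.7, outer_momentum=0.9):
    """Toy DiLoCo loop fitting y = w*x + b with mean-squared error.

    Each replica takes `inner_steps` of local full-batch SGD on its own
    data shard; once per round, the averaged parameter delta is applied
    by an outer momentum-SGD step (the paper uses Nesterov momentum).
    """
    theta = np.zeros(2)                    # global parameters [w, b]
    velocity = np.zeros(2)                 # outer momentum buffer
    shards = list(zip(np.array_split(x, num_replicas),
                      np.array_split(y, num_replicas)))
    for _ in range(num_rounds):
        deltas = []
        for xs, ys in shards:
            local = theta.copy()
            for _ in range(inner_steps):   # inner phase: no communication
                err = local[0] * xs + local[1] - ys
                grad = np.array([(err * xs).mean(), err.mean()])
                local -= inner_lr * grad
            deltas.append(theta - local)   # this replica's "outer gradient"
        outer_grad = np.mean(deltas, axis=0)  # the only all-reduce per round
        velocity = outer_momentum * velocity + outer_grad
        theta -= outer_lr * velocity
    return theta
```

With bandwidth-heavy communication happening only once per round (one all-reduce of the averaged delta) rather than once per step, the communication volume drops by roughly a factor of `inner_steps`.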
This implementation adds the `drjax` package to the pip requirements; it handles bookkeeping and the subtle configuration of `jax.vmap`'s `spmd_axis_name` argument.

Going forward, one can set `ici_diloco_parallelism` or `dcn_diloco_parallelism` to a value greater than 1 (the default of 1 disables DiLoCo) to enable DiLoCo training.

Next steps would include implementing the streaming DiLoCo variant (https://arxiv.org/abs/2501.18512).
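For illustration, enabling DiLoCo might look like the fragment below. This is a sketch: only the two flag names come from this change; the replica count and the comments interpreting ICI (intra-slice interconnect) vs. DCN (data-center network) are assumptions.

```yaml
# Hypothetical config fragment: four DiLoCo replicas across DCN slices.
ici_diloco_parallelism: 1  # replicas over the intra-slice interconnect (ICI); 1 disables
dcn_diloco_parallelism: 4  # replicas over the data-center network (DCN)
```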
Tests
This PR introduces a new `tests/diloco_test.py` with a numerical-correctness test for a simple two-parameter model.
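The test itself is not reproduced here, but one property such a test can check is that DiLoCo degenerates to plain SGD in the single-replica case. A minimal sketch of that idea (helper names and the two-parameter linear model are hypothetical, not the PR's code): with one replica, no outer momentum, and an outer learning rate of 1.0, a DiLoCo round must match ordinary SGD exactly.

```python
import numpy as np

def sgd_steps(theta, xs, ys, steps, lr):
    # Plain full-batch SGD on mean-squared error for y = w*x + b.
    theta = theta.copy()
    for _ in range(steps):
        err = theta[0] * xs + theta[1] - ys
        grad = np.array([(err * xs).mean(), err.mean()])
        theta -= lr * grad
    return theta

def diloco_round(theta, shards, steps, lr, outer_lr):
    # One DiLoCo round without outer momentum: average the replicas'
    # parameter deltas and apply them with the outer learning rate.
    deltas = [theta - sgd_steps(theta, xs, ys, steps, lr)
              for xs, ys in shards]
    return theta - outer_lr * np.mean(deltas, axis=0)
```

With `shards=[(x, y)]` (one replica) and `outer_lr=1.0`, `diloco_round` returns `theta - (theta - sgd_result)`, i.e. exactly the plain-SGD parameters, which gives a tight equality check rather than a loose tolerance.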
Checklist
Before submitting this PR, please make sure (put X in square brackets):