Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support DiLoCo training. #1353

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Support DiLoCo training. #1353

wants to merge 1 commit into from

Conversation

ZacharyGarrett
Copy link
Collaborator

@ZacharyGarrett ZacharyGarrett commented Mar 6, 2025

Description

This is an initial implementation of Distributed Low-Communication training (DiLoCo) as described in https://arxiv.org/abs/2311.08105. DiLoCo is an inner-outer bi-level optimization training strategy that significantly reduces the amount of bandwidth used compared to data-parallel training by syncing between the replicas less frequently.

This implementation adds the drjax package to the pip requirements for bookkeeping and subtle configuration of the jax.vmap's spmd_axis_name argument.

Going forward, one can specify ici_diloco_parallelism or dcn_dilooco_parallelism greater than 1 (the default, which disables) to enable DiLoCo training.

Next steps would include implementing the streaming DiLoCo variant (https://arxiv.org/abs/2501.18512)

Tests

This PR introduces a new tests/diloco_test.py that has a test for numerical correctness of a simple two parameter model.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

This is an initial implementation of Distributed Low-Communication
trianing (DiLoCo) as described in https://arxiv.org/abs/2311.08105.

This implementation adds the `drjax` package to the pip requirements for
bookkeeping and subtle configuraiton of the `jax.vmap`'s
`spmd_axis_name` argument.

Going forward, one can specify `ici_diloco_parallelism` or
`dcn_dilooco_parallelism` greater than 1 (the default, which disables)
to enable DiLoCo training.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant