Support DiLoCo training.
This is an initial implementation of Distributed Low-Communication
training (DiLoCo) as described in https://arxiv.org/abs/2311.08105.

This implementation adds the `drjax` package to the pip requirements;
it handles the bookkeeping and the subtle configuration of `jax.vmap`'s
`spmd_axis_name` argument.

Going forward, DiLoCo training can be enabled by setting
`ici_diloco_parallelism` or `dcn_diloco_parallelism` to a value greater
than 1 (the default of 1 leaves it disabled).
ZacharyGarrett committed Mar 6, 2025
1 parent e2b8c53 commit 1299433
Showing 8 changed files with 479 additions and 6 deletions.
11 changes: 9 additions & 2 deletions MaxText/configs/base.yml
@@ -257,7 +257,7 @@ jax_cache_dir: "~/jax_cache"
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu'

# Parallelism
mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
@@ -308,7 +308,7 @@ logical_axis_rules: [
['paged_kv_head_dim_size', []],
]
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
data_sharding: [['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]

# sharding tolerance: float between 0.0 and 1.0 representing the allowed percentage of non-sharded parameters.
sharding_tolerance: 0.02
@@ -317,6 +317,7 @@ sharding_tolerance: 0.02
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
dcn_diloco_parallelism: 1
dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: 1
dcn_fsdp_transpose_parallelism: 1
@@ -327,6 +328,7 @@ dcn_tensor_sequence_parallelism: 1 # never recommended
dcn_pipeline_parallelism: 1
dcn_expert_parallelism: 1
dcn_autoregressive_parallelism: 1 # never recommended
ici_diloco_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_fsdp_transpose_parallelism: 1
@@ -450,6 +452,11 @@ enable_data_shuffling: True
data_shuffle_seed: 0
init_weights_seed: 0

# DiLoCo params https://arxiv.org/pdf/2311.08105
diloco_sync_period: 20
diloco_outer_lr: 0.7
diloco_outer_momentum: 0.9

# You may disable clipping by setting gradient_clipping_threshold to zero.
gradient_clipping_threshold: 1.0

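For orientation, the three new keys parameterize DiLoCo's outer optimization: every `diloco_sync_period` inner steps, the per-replica parameter deltas are averaged into a pseudo gradient and applied to the global parameters with SGD using Nesterov momentum. Below is an illustrative sketch of that outer update, not part of this commit; `outer_params` and `replica_params` are hypothetical stand-ins (a single global PyTree and a plain list of per-replica PyTrees), and the real wiring lives in MaxText/diloco.py further down.

import jax
import optax

# Hypothetical example values: diloco_outer_lr=0.7, diloco_outer_momentum=0.9.
outer_opt = optax.sgd(0.7, momentum=0.9, nesterov=True)

def outer_update(outer_params, replica_params, opt_state):
  # Average the replica parameters, then form the pseudo gradient
  # (global params minus the replica average).
  avg_inner = jax.tree.map(lambda *xs: sum(xs) / len(xs), *replica_params)
  pseudo_grad = jax.tree.map(lambda o, a: o - a, outer_params, avg_inner)
  updates, opt_state = outer_opt.update(pseudo_grad, opt_state, outer_params)
  return optax.apply_updates(outer_params, updates), opt_state

The caller would run `opt_state = outer_opt.init(outer_params)` once before the first synchronization.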
164 changes: 164 additions & 0 deletions MaxText/diloco.py
@@ -0,0 +1,164 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An implementation of Distributed Low-Communication (DiLoCo) training.
This module contains implementations of:
- DiLoCo: Distributed Low-Communication Training of Language Models
https://arxiv.org/abs/2311.08105
- Streaming DiLoCo with overlapping communication: Towards a Distributed Free Lunch
https://arxiv.org/abs/2501.18512
"""

import drjax
import jax
from jaxtyping import Array, Int32, Key, PyTree, UInt32
import optax
import pyconfig

from flax import struct
from typing import Any, Callable, Protocol, Tuple

Batch = Any
Params = PyTree
Metrics = PyTree
OptState = optax.OptState
InnerOptStates = optax.OptState
PRNGKey = Key[Array, ""] | UInt32[Array, "2"]
Step = Int32[Array, ""]


class StateProtocol(Protocol):
  """The protocol expected from the underlying train step state."""

  @property
  def params(self) -> Params:
    ...

  @property
  def step(self) -> Step:
    ...

class DiLoCoTrainState(struct.PyTreeNode):
  """The state of the DiLoCo training process.

  Attributes:
    inner_state: A PyTree of the state for each step of the inner optimization.
      All arrays are expected to have a leading dimension with size of the
      number of diloco replicas so that training steps can be mapped over this
      dimension.
    outer_params: A PyTree of the global model weights. These mirror the params
      sub-PyTree in `inner_state`, but without the leading replica dimension.
    outer_opt_state: The state for the outer Nesterov momentum optimizer.
    step: The step counter of the training process.
  """

  inner_state: StateProtocol
  outer_params: Params
  outer_opt_state: OptState
  step: Step


def build_diloco_train_step(
    config: pyconfig.HyperParameters,
    train_step: Callable[[StateProtocol, Batch, PRNGKey], tuple[StateProtocol, Metrics]],
    state: StateProtocol,
) -> Tuple[
    DiLoCoTrainState,
    Callable[[DiLoCoTrainState, Batch, PRNGKey], tuple[DiLoCoTrainState, Metrics]],
]:
  """Convert a local state and train step into DiLoCo-compatible versions.

  This is an implementation of the original (non-streaming) DiLoCo algorithm
  which syncs all model parameters across the replicas every
  `config.diloco_sync_period` steps, treating the difference accumulated over
  non-sync steps as a pseudo gradient and applying SGD with Nesterov momentum on
  the "global" model.

  Args:
    config: The config used to set up training.
    train_step: A local train step. This will be executed independently within
      each replica.
    state: The flax train state of the standard train loop. This will be
      modified to create DiLoCo state.
  """
  outer_optimizer = optax.sgd(
      config.diloco_outer_lr,
      momentum=config.diloco_outer_momentum,
      nesterov=True,
  )

  @drjax.program(placements={"diloco": config.num_diloco_replicas})
  def init_train_state(state: StateProtocol) -> DiLoCoTrainState:
    # Inner state must be broadcast across clients.
    inner_state = drjax.broadcast(state)
    # Outer state retains a single copy of the model parameters and optimizer
    # state.
    outer_params = state.params
    outer_opt_state = outer_optimizer.init(outer_params)
    return DiLoCoTrainState(
        inner_state=inner_state,
        outer_params=outer_params,
        outer_opt_state=outer_opt_state,
        step=state.step,
    )

  def synchronize(state):
    # Calculate the delta between the current replica's state and the global
    # state (since last synchronization).
    broadcast_outer_params = drjax.broadcast(state.outer_params)
    model_delta = jax.tree.map(lambda x, y: y - x, state.inner_state.params, broadcast_outer_params)
    # Treat the average delta as the outer optimizer's gradient and apply to
    # the global (outer) model params.
    averaged_pseudo_grad = drjax.reduce_mean(model_delta)
    updates, new_opt_state = outer_optimizer.update(averaged_pseudo_grad, state.outer_opt_state, state.outer_params)
    new_outer_params = optax.apply_updates(state.outer_params, updates)
    # Replace inner model params with the new global model params.
    # NOTE: inner optimizer state is retained despite the change in parameters,
    # see section 6.1 in https://arxiv.org/pdf/2311.08105.
    new_inner_state = drjax.map_fn(lambda state: state.replace(params=new_outer_params), state.inner_state)
    return state.replace(
        outer_params=new_outer_params,
        outer_opt_state=new_opt_state,
        inner_state=new_inner_state,
    )

  def typed_reduce_mean(in_tree):
    total = drjax.reduce_sum(in_tree)
    avg = jax.tree.map(lambda x: (x / config.num_diloco_replicas).astype(x.dtype), total)
    return avg

  @drjax.program(placements={"diloco": config.num_diloco_replicas})
  def diloco_train_step(state, batch, prng):
    # Broadcast the RNG across replicas.
    broadcast_rng = drjax.broadcast(prng)
    inner_state, metrics = drjax.map_fn(train_step, (state.inner_state, batch, broadcast_rng))
    avg_metrics = typed_reduce_mean(metrics)
    state = state.replace(
        inner_state=inner_state,
        step=inner_state.step[0],
    )
    # Either synchronize the model, or no-op, depending on whether the current
    # step falls on the synchronization period.
    state = jax.lax.cond(
        inner_state.step[0] % config.diloco_sync_period == 0,
        synchronize,
        lambda x: x,  # no-op
        state,
    )
    return state, avg_metrics

  return init_train_state(state), diloco_train_step
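A hypothetical usage sketch, not part of this commit, showing how the returned pair might be driven from a training loop. `config`, `base_train_step`, `base_state`, and `data_iterator` are placeholders for objects MaxText's standard train loop already builds; the batch is assumed to be laid out with a leading `diloco` replica dimension, and jit/sharding details are omitted.

import jax

diloco_state, diloco_step = build_diloco_train_step(config, base_train_step, base_state)

rng = jax.random.PRNGKey(config.init_weights_seed)
for _ in range(config.steps):
  rng, step_rng = jax.random.split(rng)
  batch = next(data_iterator)  # one shard per DiLoCo replica along the leading axis
  diloco_state, metrics = diloco_step(diloco_state, batch, step_rng)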
25 changes: 24 additions & 1 deletion MaxText/max_utils.py
@@ -598,6 +598,29 @@ def create_device_mesh(config, devices=None):

  multi_slice_env = num_slices > 1

  dcn_parallelism = [
      config.dcn_diloco_parallelism,
      config.dcn_data_parallelism,
      config.dcn_pipeline_parallelism,
      config.dcn_fsdp_parallelism,
      config.dcn_fsdp_transpose_parallelism,
      config.dcn_sequence_parallelism,
      config.dcn_tensor_parallelism,
      config.dcn_expert_parallelism,
      config.dcn_autoregressive_parallelism,
  ]
  ici_parallelism = [
      config.ici_diloco_parallelism,
      config.ici_data_parallelism,
      config.ici_pipeline_parallelism,
      config.ici_fsdp_parallelism,
      config.ici_fsdp_transpose_parallelism,
      config.ici_sequence_parallelism,
      config.ici_tensor_parallelism,
      config.ici_expert_parallelism,
      config.ici_autoregressive_parallelism,
  ]

  # Find possible unspecified parallelisms
  ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI")

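To make the axis bookkeeping concrete, here is a hypothetical sizing example, not from the commit. Per the comment in base.yml, the product of the DCN axes must equal the number of slices and the product of the ICI axes the number of devices per slice, so placing one DiLoCo replica per slice looks like this:

import math

# Hypothetical: 2 TPU slices with 4 chips each. Axis order follows the lists
# above: [diloco, data, pipeline, fsdp, fsdp_transpose, sequence, tensor,
# expert, autoregressive].
num_slices, devices_per_slice = 2, 4
dcn_parallelism = [2, 1, 1, 1, 1, 1, 1, 1, 1]  # dcn_diloco_parallelism=2 -> one replica per slice
ici_parallelism = [1, 1, 1, 4, 1, 1, 1, 1, 1]  # ici_fsdp_parallelism=4 -> shard params within a slice
assert math.prod(dcn_parallelism) == num_slices
assert math.prod(ici_parallelism) == devices_per_slice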
@@ -640,7 +663,7 @@ def create_device_mesh(config, devices=None):
  if config.optimize_mesh_for_tpu_v6e:
    mesh = optimize_mesh_for_tpu_v6e(mesh, devices)

  max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")
  max_logging.log(f"Num_devices: {num_devices}, shape: {mesh.shape}")

  return mesh

3 changes: 3 additions & 0 deletions MaxText/pyconfig.py
@@ -529,6 +529,7 @@ def update_model_vars(base_config_path, raw_keys, config_name: str):

def create_parallelisms_list(raw_keys):
  ici_parallelism = [
      raw_keys["ici_diloco_parallelism"],
      raw_keys["ici_data_parallelism"],
      raw_keys["ici_pipeline_parallelism"],
      raw_keys["ici_fsdp_parallelism"],
@@ -541,6 +542,7 @@ def create_parallelisms_list(raw_keys):
      raw_keys["ici_autoregressive_parallelism"],
  ]
  dcn_parallelism = [
      raw_keys["dcn_diloco_parallelism"],
      raw_keys["dcn_data_parallelism"],
      raw_keys["dcn_pipeline_parallelism"],
      raw_keys["dcn_fsdp_parallelism"],
@@ -554,6 +556,7 @@
  ]
  raw_keys["ici_parallelism"] = ici_parallelism
  raw_keys["dcn_parallelism"] = dcn_parallelism
  raw_keys["num_diloco_replicas"] = int(raw_keys["ici_diloco_parallelism"] * raw_keys["dcn_diloco_parallelism"])
  return raw_keys
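`num_diloco_replicas` is the value consumed by the `drjax.program(placements={"diloco": ...})` decorators in MaxText/diloco.py. A quick illustrative calculation with hypothetical values:

# e.g. a two-slice job with one DiLoCo replica per slice:
ici_diloco_parallelism = 1
dcn_diloco_parallelism = 2
num_diloco_replicas = int(ici_diloco_parallelism * dcn_diloco_parallelism)  # == 2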

