Support DiLoCo training. #1353

Open · wants to merge 1 commit into main
11 changes: 9 additions & 2 deletions MaxText/configs/base.yml
@@ -257,7 +257,7 @@ jax_cache_dir: "~/jax_cache"
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu'

# Parallelism
mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
@@ -308,7 +308,7 @@ logical_axis_rules: [
['paged_kv_head_dim_size', []],
]
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
data_sharding: [['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]

# sharding tolerance: float between 0.0 and 1.0 representing the allowed percentage of non-sharded parameters.
sharding_tolerance: 0.02
@@ -317,6 +317,7 @@ sharding_tolerance: 0.02
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
dcn_diloco_parallelism: 1
dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: 1
dcn_fsdp_transpose_parallelism: 1
@@ -327,6 +328,7 @@ dcn_tensor_sequence_parallelism: 1 # never recommended
dcn_pipeline_parallelism: 1
dcn_expert_parallelism: 1
dcn_autoregressive_parallelism: 1 # never recommended
ici_diloco_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_fsdp_transpose_parallelism: 1
@@ -450,6 +452,11 @@ enable_data_shuffling: True
data_shuffle_seed: 0
init_weights_seed: 0

# DiLoCo params https://arxiv.org/pdf/2311.08105
diloco_sync_period: 20
diloco_outer_lr: 0.7
diloco_outer_momentum: 0.9

# You may disable clipping by setting gradient_clipping_threshold to zero.
gradient_clipping_threshold: 1.0

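The three keys above are the only DiLoCo-specific hyperparameters. As a rough illustration (a sketch using the default values shown in the diff, not code from this PR), `diloco_sync_period` sets how often the replicas synchronize, while `diloco_outer_lr` and `diloco_outer_momentum` parameterize the outer Nesterov-momentum SGD that `MaxText/diloco.py` (below) applies to the averaged pseudo-gradient.

```python
# Sketch only: how the new base.yml keys feed the outer optimizer and the
# synchronization cadence. Values are the defaults shown in the diff above.
import optax

diloco_sync_period = 20
diloco_outer_lr = 0.7
diloco_outer_momentum = 0.9

# Outer ("global") optimizer, mirroring the construction in MaxText/diloco.py.
outer_optimizer = optax.sgd(
    diloco_outer_lr, momentum=diloco_outer_momentum, nesterov=True
)

# Replicas train independently and only synchronize on steps where
# step % diloco_sync_period == 0, i.e. steps 20, 40, 60, ...
step = 40
synchronize_now = step % diloco_sync_period == 0  # True
```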
164 changes: 164 additions & 0 deletions MaxText/diloco.py
@@ -0,0 +1,164 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An implementation of Distributed Low-Communication (DiLoCo) training.

This module contains implementations of:

- DiLoCo: Distributed Low-Communication Training of Language Models
https://arxiv.org/abs/2311.08105
- Streaming DiLoCo with overlapping communication: Towards a Distributed Free Lunch
https://arxiv.org/abs/2501.18512
"""

import drjax
import jax
from jaxtyping import Array, Int32, Key, PyTree, UInt32
import optax
import pyconfig

from flax import struct
from typing import Any, Callable, Protocol, Tuple

Batch = Any
Params = PyTree
Metrics = PyTree
OptState = optax.OptState
InnerOptStates = optax.OptState
PRNGKey = Key[Array, ""] | UInt32[Array, "2"]
Step = Int32[Array, ""]


class StateProtocol(Protocol):
  """The protocol expected from the underlying train step state."""

  @property
  def params(self) -> Params:
    ...

  @property
  def step(self) -> Step:
    ...


class DiLoCoTrainState(struct.PyTreeNode):
  """The state of the DiLoCo training process.

  Attributes:
    inner_state: A PyTree of the per-replica state of the inner optimization.
      All arrays are expected to have a leading dimension whose size is the
      number of DiLoCo replicas so that training steps can be mapped over this
      dimension.
    outer_params: A PyTree of the global model weights. These mirror the params
      sub-PyTree of `inner_state`, but without the leading replica dimension
      (i.e. with rank reduced by one).
    outer_opt_state: The state for the outer Nesterov momentum optimizer.
    step: The step counter of the training process.
  """

  inner_state: StateProtocol
  outer_params: Params
  outer_opt_state: OptState
  step: Step


def build_diloco_train_step(
    config: pyconfig.HyperParameters,
    train_step: Callable[[StateProtocol, Batch, PRNGKey], tuple[StateProtocol, Metrics]],
    state: StateProtocol,
) -> Tuple[
    DiLoCoTrainState,
    Callable[[DiLoCoTrainState, Batch, PRNGKey], tuple[DiLoCoTrainState, Metrics]],
]:
  """Convert a local state and train step into DiLoCo-compatible versions.

  This is an implementation of the original (non-streaming) DiLoCo algorithm
  which syncs all model parameters across the replicas every
  `config.diloco_sync_period` steps, treating the difference accumulated over
  the non-sync steps as a pseudo-gradient and applying SGD with Nesterov
  momentum on the "global" model.

  Args:
    config: The config used to set up training.
    train_step: A local train step. This will be executed independently within
      each replica.
    state: The flax train state of the standard train loop. This will be
      modified to create the DiLoCo state.

  Returns:
    A tuple of the initial `DiLoCoTrainState` and a train step that runs
    `train_step` within each replica and periodically synchronizes the
    replicas.
  """
  outer_optimizer = optax.sgd(
      config.diloco_outer_lr,
      momentum=config.diloco_outer_momentum,
      nesterov=True,
  )

  @drjax.program(placements={"diloco": config.num_diloco_replicas})
  def init_train_state(state: StateProtocol) -> DiLoCoTrainState:
    # Inner state must be broadcast across clients.
    inner_state = drjax.broadcast(state)
    # Outer state retains a single copy of the model parameters and optimizer
    # state.
    outer_params = state.params
    outer_opt_state = outer_optimizer.init(outer_params)
    return DiLoCoTrainState(
        inner_state=inner_state,
        outer_params=outer_params,
        outer_opt_state=outer_opt_state,
        step=state.step,
    )

  def synchronize(state):
    # Calculate the delta between the current replica's state and the global
    # state (since last synchronization).
    broadcast_outer_params = drjax.broadcast(state.outer_params)
    model_delta = jax.tree.map(lambda x, y: y - x, state.inner_state.params, broadcast_outer_params)
    # Treat the average delta as the outer optimizer's gradient and apply to
    # the global (outer) model params.
    averaged_pseudo_grad = drjax.reduce_mean(model_delta)
    updates, new_opt_state = outer_optimizer.update(averaged_pseudo_grad, state.outer_opt_state, state.outer_params)
    new_outer_params = optax.apply_updates(state.outer_params, updates)
    # Replace inner model params with the new global model params.
    # NOTE: inner optimizer state is retained despite the change in parameters,
    # see section 6.1 in https://arxiv.org/pdf/2311.08105.
    new_inner_state = drjax.map_fn(lambda state: state.replace(params=new_outer_params), state.inner_state)
    return state.replace(
        outer_params=new_outer_params,
        outer_opt_state=new_opt_state,
        inner_state=new_inner_state,
    )

  def typed_reduce_mean(in_tree):
    # Mean across replicas that preserves the input dtype of each leaf.
    total = drjax.reduce_sum(in_tree)
    avg = jax.tree.map(lambda x: (x / config.num_diloco_replicas).astype(x.dtype), total)
    return avg

  @drjax.program(placements={"diloco": config.num_diloco_replicas})
  def diloco_train_step(state, batch, prng):
    # Broadcast the RNG across replicas.
    broadcast_rng = drjax.broadcast(prng)
    inner_state, metrics = drjax.map_fn(train_step, (state.inner_state, batch, broadcast_rng))
    avg_metrics = typed_reduce_mean(metrics)
    state = state.replace(
        inner_state=inner_state,
        step=inner_state.step[0],
    )
    # Either synchronize the model, or no-op, depending on whether the current
    # step falls on the synchronization period.
    state = jax.lax.cond(
        inner_state.step[0] % config.diloco_sync_period == 0,
        synchronize,
        lambda x: x,  # no-op
        state,
    )
    return state, avg_metrics

  return init_train_state(state), diloco_train_step
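A self-contained toy illustration (hypothetical values, not part of this PR) of what `synchronize` computes: the averaged difference between the global params and each replica's params is treated as a pseudo-gradient and fed to the Nesterov-momentum outer optimizer. In the real code, `build_diloco_train_step` returns the initial `DiLoCoTrainState` plus a step function that performs this automatically every `diloco_sync_period` steps.

```python
# Toy example of the outer update in `synchronize` (assumed values, 2 replicas).
import jax.numpy as jnp
import optax

outer_params = jnp.array([1.0, 2.0])                  # global model weights
replica_params = jnp.array([[1.2, 1.8], [0.8, 2.4]])  # per-replica weights after local steps

outer_optimizer = optax.sgd(0.7, momentum=0.9, nesterov=True)
opt_state = outer_optimizer.init(outer_params)

# Pseudo-gradient: (outer - inner), averaged over replicas; this matches
# `jax.tree.map(lambda x, y: y - x, inner, outer)` followed by reduce_mean.
pseudo_grad = jnp.mean(outer_params[None, :] - replica_params, axis=0)

updates, opt_state = outer_optimizer.update(pseudo_grad, opt_state, outer_params)
new_outer_params = optax.apply_updates(outer_params, updates)
# new_outer_params moves toward (and, with momentum, slightly past) the replica
# average [1.0, 2.1]; in DiLoCo it then replaces every replica's params.
print(new_outer_params)
```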
25 changes: 24 additions & 1 deletion MaxText/max_utils.py
@@ -598,6 +598,29 @@ def create_device_mesh(config, devices=None):

multi_slice_env = num_slices > 1

dcn_parallelism = [
config.dcn_diloco_parallelism,
config.dcn_data_parallelism,
config.dcn_pipeline_parallelism,
config.dcn_fsdp_parallelism,
config.dcn_fsdp_transpose_parallelism,
config.dcn_sequence_parallelism,
config.dcn_tensor_parallelism,
config.dcn_expert_parallelism,
config.dcn_autoregressive_parallelism,
]
ici_parallelism = [
config.ici_diloco_parallelism,
config.ici_data_parallelism,
config.ici_pipeline_parallelism,
config.ici_fsdp_parallelism,
config.ici_fsdp_transpose_parallelism,
config.ici_sequence_parallelism,
config.ici_tensor_parallelism,
config.ici_expert_parallelism,
config.ici_autoregressive_parallelism,
]

# Find possible unspecified parallelisms
ici_parallelism = fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI")

@@ -640,7 +663,7 @@ def create_device_mesh(config, devices=None):
if config.optimize_mesh_for_tpu_v6e:
mesh = optimize_mesh_for_tpu_v6e(mesh, devices)

max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}")
max_logging.log(f"Num_devices: {num_devices}, shape: {mesh.shape}")

return mesh

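The new 'diloco' axis is added at the front of both the DCN and ICI parallelism lists, matching its position in `mesh_axes`. A small sanity-check sketch (hypothetical topology, not from this diff) of the constraint documented in base.yml: the product of the DCN axes should equal the number of slices, and the product of the ICI axes the number of devices per slice.

```python
# Hypothetical 4-slice topology with 16 devices per slice: one DiLoCo replica
# per slice, FSDP within each slice. Axis products must match the topology.
import math

num_slices, devices_per_slice = 4, 16

dcn = {"diloco": 4, "data": 1, "fsdp": 1}   # product == num_slices
ici = {"diloco": 1, "data": 1, "fsdp": 16}  # product == devices_per_slice

assert math.prod(dcn.values()) == num_slices
assert math.prod(ici.values()) == devices_per_slice
```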
3 changes: 3 additions & 0 deletions MaxText/pyconfig.py
@@ -529,6 +529,7 @@ def update_model_vars(base_config_path, raw_keys, config_name: str):

def create_parallelisms_list(raw_keys):
ici_parallelism = [
raw_keys["ici_diloco_parallelism"],
raw_keys["ici_data_parallelism"],
raw_keys["ici_pipeline_parallelism"],
raw_keys["ici_fsdp_parallelism"],
@@ -541,6 +542,7 @@ def create_parallelisms_list(raw_keys):
raw_keys["ici_autoregressive_parallelism"],
]
dcn_parallelism = [
raw_keys["dcn_diloco_parallelism"],
raw_keys["dcn_data_parallelism"],
raw_keys["dcn_pipeline_parallelism"],
raw_keys["dcn_fsdp_parallelism"],
@@ -554,6 +556,7 @@
]
raw_keys["ici_parallelism"] = ici_parallelism
raw_keys["dcn_parallelism"] = dcn_parallelism
raw_keys["num_diloco_replicas"] = int(raw_keys["ici_diloco_parallelism"] * raw_keys["dcn_diloco_parallelism"])
return raw_keys


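The number of DiLoCo replicas is simply the product of the two new parallelism knobs; a worked example with hypothetical overrides (not values from this diff):

```python
# Mirrors the num_diloco_replicas computation added in create_parallelisms_list.
ici_diloco_parallelism = 1  # DiLoCo replicas within a slice
dcn_diloco_parallelism = 4  # DiLoCo replicas across slices
num_diloco_replicas = int(ici_diloco_parallelism * dcn_diloco_parallelism)
assert num_diloco_replicas == 4  # drjax maps the inner train step over 4 replicas
```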