Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Poc elastic training #1310

Draft
wants to merge 27 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a410df6
Add Pathways Benchmarking Recipes for Scale Testing
SujeethJinesh Jan 31, 2025
ccd8c26
Updated input pipeline to use values from elastic utils.
lukebaumann Nov 21, 2024
9fdab26
Checking in benchmark changes
lukebaumann Mar 1, 2025
62b6eb0
checking in. async checkpointing does not work. no checkpointing or s…
lukebaumann Mar 1, 2025
33516b9
Wait until all slices are available before starting. Fixed a bug with…
lukebaumann Mar 3, 2025
0298e13
Updated to three maybe_* functions
lukebaumann Mar 4, 2025
ef3d7c5
Added an initial save at the beginning
lukebaumann Mar 4, 2025
7d1cf08
Removed restoring from checkpoint during reshard_handler
lukebaumann Mar 4, 2025
d5eb1b4
Updates for the refactor
lukebaumann Mar 5, 2025
f05b93c
Using shaurya's images, 20 checkpoint period, and sync checkpointing
lukebaumann Mar 5, 2025
4f3c822
Testing on smaller cluster and turning off profiler
lukebaumann Mar 5, 2025
6cb73e9
Updated simulator init args
lukebaumann Mar 5, 2025
2511d6e
Using simulator and simplified maybe_reshard_up/down return values
lukebaumann Mar 5, 2025
8b6c8ff
Fixed the bug. Indented to the wrong level so jit was within the cont…
lukebaumann Mar 5, 2025
f4110ed
Removed the simulator
lukebaumann Mar 5, 2025
60c3583
Removed an unnecessary block
lukebaumann Mar 5, 2025
7a779ab
Using the big cluster
lukebaumann Mar 5, 2025
e5ccfe9
Added back one block
lukebaumann Mar 5, 2025
c7d3492
Turning off profiler
lukebaumann Mar 5, 2025
220c70f
Added NOT_FOUND to exception types that we try to handle with reshard…
lukebaumann Mar 7, 2025
ba07bf4
Fix to maybe_reshard_up for handling the retry logic
lukebaumann Mar 7, 2025
c0ad90c
Fixed the retry logic for reshard up (down was fixed prior). Fixed se…
lukebaumann Mar 8, 2025
8fd5628
Fixed another bug for reshard down retry logic
lukebaumann Mar 8, 2025
b575d37
Checking in. Things are not working
lukebaumann Mar 8, 2025
6b34aac
Cleaned up some random logs. I believe there is a race condition for …
lukebaumann Mar 8, 2025
ce94037
Bumping the reshard check period to 15
lukebaumann Mar 8, 2025
1fedd26
Changed name from save to snapshot. Added a snapshot deque. Refactore…
lukebaumann Mar 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added MaxText/elastic/__init__.py
Empty file.
72 changes: 72 additions & 0 deletions MaxText/elastic/reshard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Resharding API for elastic training."""

from typing import Any
from typing import Callable, Sequence
import jax


def default_put_array(
arr: jax.Array,
dst_sharding: jax.sharding.Sharding,
donate_input: bool,
):
if not isinstance(dst_sharding, jax.sharding.Sharding):
raise ValueError("`sharding` must contain only `Sharding` instances.")
return jax.device_put(arr, dst_sharding, donate=donate_input)


def reshard(
x: Any,
sharding: jax.sharding.Sharding | Any,
*,
donate_input: bool = False,
put_array: (
Callable[[jax.Array, Sequence[jax.sharding.Sharding], bool], jax.Array]
| None
) = None,
) -> Any:
"""Reshards `x` to the specified `sharding`.

Args:
x: An array, scalar, or a nested Python container thereof.
sharding: A `Sharding` or a nested `Sharding` in a Python container (must
match the structure of `x`), specifying the target sharding.
donate_input: If `True`, donates the input arrays to reduce memory needed
for resharding. Donated buffers should not be reused.
put_array: A function that takes an array, a sharding, and a boolean
indicating whether to donate the input, and returns a copy of the array
with the specified sharding.

Returns:
A copy of `x` with the specified `sharding`.
"""
if put_array is None:
put_array = default_put_array

flat_x, tree_def = jax.tree_util.tree_flatten(x)
flat_sharding = jax.api_util.flatten_axes(
"reshard sharding", tree_def, sharding
)

if len(flat_x) != len(flat_sharding):
raise ValueError("Mismatched length between `x` and `sharding`.")

arrays = [
put_array(arr, dst_sharding, donate_input)
for arr, dst_sharding in zip(flat_x, flat_sharding)
]
return jax.tree_util.tree_unflatten(tree_def, arrays)

72 changes: 72 additions & 0 deletions MaxText/elastic/simulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for elastic training."""

import logging
from typing import Any, Sequence

import jax
from elastic import utils


# Loose alias for annotating arbitrary (pytree-like) values; it is plain `Any`.
PyTree = Any

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Explicitly set this logger's level to INFO at import time.
logger.setLevel(logging.INFO)

# pylint: disable=logging-fstring-interpolation


class ElasticUtilsSimulator(utils.ElasticUtils):
  """`ElasticUtils` subclass with caller-controlled slice availability.

  `get_slice_availability` reports the set most recently stored through
  `update_good_slice_indices`, which makes it possible to simulate slices
  going down and coming back up.
  """

  simulated_good_slice_indices: set[int]

  def __init__(
      self,
      devices: Sequence[jax.Device],
      total_slice_count: int,
      snapshot_period: int | None = None,
      reshard_check_period: int | None = None,
      max_failure_count: int | None = None,
      max_reshard_retry_count: int | None = None,
  ):
    """Seeds the simulated good set from `devices`, then initializes the base.

    Args:
      devices: Devices whose `slice_index` values form the initial simulated
        set of good slices.
      total_slice_count: Forwarded positionally to `utils.ElasticUtils`.
      snapshot_period: Forwarded positionally to `utils.ElasticUtils`.
      reshard_check_period: Forwarded positionally to `utils.ElasticUtils`.
      max_failure_count: Forwarded positionally to `utils.ElasticUtils`.
      max_reshard_retry_count: Forwarded positionally to `utils.ElasticUtils`.
    """
    # Assigned before the base initializer runs. NOTE(review): presumably the
    # base `__init__` may call `get_slice_availability` -- confirm against
    # `utils.ElasticUtils` before reordering.
    self.simulated_good_slice_indices = {d.slice_index for d in devices}

    super().__init__(
        devices,
        total_slice_count,
        snapshot_period,
        reshard_check_period,
        max_failure_count,
        max_reshard_retry_count,
    )

  def update_good_slice_indices(self, good_slice_indices: set[int]):
    """Replaces the simulated set of good slice indices and logs it.

    Args:
      good_slice_indices: The new set; it is stored as-is (not copied).
    """
    self.simulated_good_slice_indices = good_slice_indices
    logger.info(f"Updated: {self.simulated_good_slice_indices=}")

  @utils.timeit
  def get_slice_availability(self) -> set[int]:
    """Returns the simulated set of good slice indices, logging it first."""
    good_slice_indices = self.simulated_good_slice_indices

    logger.info(f"{good_slice_indices=}")

    return good_slice_indices

Loading