Updated elasticutils to how it will be structured in pathwaysutils
lukebaumann committed Feb 26, 2025
1 parent e1ca0d2 commit 0c975ce
Showing 5 changed files with 198 additions and 142 deletions.
71 changes: 71 additions & 0 deletions MaxText/elastic/reshard.py
@@ -0,0 +1,71 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Resharding API for elastic training."""

from typing import Any
from typing import Callable, Optional
import jax


def default_put_array(
    arr: jax.Array,
    dst_sharding: jax.sharding.Sharding,
    donate_input: bool,
):
  if not isinstance(dst_sharding, jax.sharding.Sharding):
    raise ValueError("`sharding` must contain only `Sharding` instances.")
  return jax.device_put(arr, dst_sharding, donate=donate_input)


def reshard(
    x: Any,
    sharding: jax.sharding.Sharding | Any,
    *,
    donate_input: bool = True,
    put_array: Optional[
        Callable[[jax.Array, jax.sharding.Sharding, bool], jax.Array]
    ] = None,
) -> Any:
  """Reshards `x` to the specified `sharding`.

  Args:
    x: An array, scalar, or a nested Python container thereof.
    sharding: A `Sharding` or a nested `Sharding` in a Python container (must
      match the structure of `x`), specifying the target sharding.
    donate_input: If `True`, donates the input arrays to reduce memory needed
      for resharding. Donated buffers should not be reused.
    put_array: A function that takes an array, a sharding, and a boolean
      indicating whether to donate the input, and returns a copy of the array
      with the specified sharding.

  Returns:
    A copy of `x` with the specified `sharding`.
  """
  if put_array is None:
    put_array = default_put_array

  flat_x, tree_def = jax.tree_util.tree_flatten(x)
  flat_sharding = jax.api_util.flatten_axes(
      "reshard sharding", tree_def, sharding
  )

  if len(flat_x) != len(flat_sharding):
    raise ValueError("Mismatched length between `x` and `sharding`.")

  arrays = [
      put_array(arr, dst_sharding, donate_input)
      for arr, dst_sharding in zip(flat_x, flat_sharding)
  ]
  return jax.tree_util.tree_unflatten(tree_def, arrays)
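A rough usage sketch of the new standalone reshard helper (not part of this commit; the import path, mesh layout, and parameter pytree below are assumptions for illustration, and the call pattern simply follows the signature above):

# Hypothetical usage of MaxText/elastic/reshard.py.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

from MaxText.elastic import reshard  # assumed import path

devices = np.asarray(jax.devices())
mesh = Mesh(devices, ("data",))

# Toy parameter pytree; dim 0 of "w" is sized to divide evenly across devices.
params = {
    "w": jnp.ones((devices.size * 2, 4)),
    "b": jnp.zeros((4,)),
}

# Target shardings must mirror the structure of `params`: shard "w" along the
# data axis and replicate "b".
shardings = {
    "w": NamedSharding(mesh, P("data", None)),
    "b": NamedSharding(mesh, P()),
}

resharded = reshard.reshard(params, shardings, donate_input=False)
print(jax.tree_util.tree_map(lambda a: a.sharding, resharded))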

70 changes: 70 additions & 0 deletions MaxText/elastic/simulator.py
@@ -0,0 +1,70 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for elastic training."""

import logging
from typing import Any, Optional, Sequence

import jax
from pathwaysutils.google_internal.elastic import utils


PyTree = Any

logger = logging.getLogger(__name__)

logger.setLevel(logging.INFO)

# pylint: disable=logging-fstring-interpolation


class ElasticUtilsSimulator(utils.ElasticUtils):
  """Utility class for elastic training.

  This class will simulate slices going down and coming back up.
  """

  simulated_good_slice_indices: set[int]

  def __init__(
      self,
      devices: Sequence[jax.Device],
      total_slice_count: int,
      save_period: Optional[int] = None,
      reshard_check_period: Optional[int] = None,
      max_failures: Optional[int] = None,
  ):
    self.simulated_good_slice_indices = set(d.slice_index for d in devices)

    super().__init__(
        devices,
        total_slice_count,
        save_period,
        reshard_check_period,
        max_failures,
    )

  def update_good_slice_indices(self, good_slice_indices: set[int]):
    """Updates the set of simulated good slice indices."""
    self.simulated_good_slice_indices = good_slice_indices
    logger.info(f"Updated: {self.simulated_good_slice_indices=}")

  @utils.timeit
  def get_slice_availability(self) -> set[int]:
    """Returns the set of simulated good slice indices."""
    good_slice_indices = self.simulated_good_slice_indices

    logger.info(f"{good_slice_indices=}")

    return good_slice_indices
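A rough sketch of how the simulator might be exercised (illustrative only; it assumes multi-slice TPU devices that expose a slice_index attribute, and the constructor arguments mirror the signature above):

# Hypothetical driver for ElasticUtilsSimulator; not part of this commit.
import jax
from pathwaysutils.google_internal.elastic import simulator

devices = jax.devices()  # assumed TPU devices spanning two slices
elastic = simulator.ElasticUtilsSimulator(
    devices=devices,
    total_slice_count=2,
    save_period=10,
    reshard_check_period=5,
    max_failures=3,
)

# Pretend slice 1 goes down, then comes back.
elastic.update_good_slice_indices({0})
assert elastic.get_slice_availability() == {0}

elastic.update_good_slice_indices({0, 1})
assert elastic.get_slice_availability() == {0, 1}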

112 changes: 56 additions & 56 deletions MaxText/elasticutils.py → MaxText/elastic/utils.py
@@ -1,6 +1,20 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for elastic training."""

import collections
from collections.abc import Mapping
import contextlib
import functools
import itertools
@@ -13,6 +13,7 @@

import jax
import numpy as np
from pathwaysutils.google_internal.elastic import reshard

jax._src.array.ArrayImpl._check_if_deleted = lambda _: False # pylint: disable=protected-access

@@ -74,35 +74,34 @@ def wrapper(*args, **kwargs):

class ElasticUtils:
"""Utility class for elastic training."""
_devices: Sequence[jax.Device]
slice_to_devices: Mapping[int, Sequence[jax.Device]]
total_slice_count: int
save_period: int
reshard_check_period: int
max_failure_count: Optional[int]
max_reshard_retry_count: Optional[int]
failure_count: int
reshard_retry_count: int
good_slice_indices: set[int]
data: Mapping[str, Any]

TEST_VALUE = 100

def __init__(
self,
devices: Sequence[jax.Device],
total_slice_count: int,
save_period: Optional[int] = None,
reshard_check_period: Optional[int] = None,
save_period: int = 1,
reshard_check_period: int = 1,
max_failure_count: Optional[int] = None,
max_reshard_retry_count: Optional[int] = None,
):
self.devices = devices
self.total_slice_count = total_slice_count

if save_period is None:
save_period = 1
self.save_period = save_period

if reshard_check_period is None:
reshard_check_period = 1
self.reshard_check_period = reshard_check_period

if max_failure_count is None:
max_failure_count = float("inf")
self.max_failure_count = max_failure_count

if max_reshard_retry_count is None:
max_reshard_retry_count = float("inf")
self.max_reshard_retry_count = max_reshard_retry_count

self.failure_count = 0
@@ -121,14 +121,20 @@ def slice_down(self, reshard_retry: bool = False):
self.reshard_retry_count = 0

logger.info(f"{self.failure_count=} {self.max_failure_count=}")
if self.failure_count >= self.max_failure_count:
logger.fatal(f"Max failure count reached {self.max_failure_count}")
if (
self.max_failure_count is not None
and self.failure_count >= self.max_failure_count
):
logger.critical(f"Max failure count reached {self.max_failure_count}")

logger.info(
f"{self.reshard_retry_count=} {self.max_reshard_retry_count=}"
)
if self.reshard_retry_count > self.max_reshard_retry_count:
logger.fatal(
if (
self.max_reshard_retry_count is not None
and self.reshard_retry_count > self.max_reshard_retry_count
):
logger.critical(
f"Max reshard retry count reached {self.max_reshard_retry_count}"
)

@@ -311,31 +311,24 @@ def reshard(
if put_array is None:
put_array = cls.default_put_array

flat_x, tree_def = jax.tree_util.tree_flatten(x)
flat_sharding = jax.api_util.flatten_axes(
"reshard sharding", tree_def, sharding
return reshard.reshard(
x, sharding, donate_input=donate_input, put_array=put_array
)

if len(flat_x) != len(flat_sharding):
raise ValueError("Mismatched length between `x` and `sharding`.")

arrays = [
put_array(arr, dst_sharding, donate_input)
for arr, dst_sharding in zip(flat_x, flat_sharding)
]
return jax.tree_util.tree_unflatten(tree_def, arrays)

@staticmethod
def put_array_device_put0(
arr: jax.Array,
dst_sharding: jax.sharding.Sharding,
donate_input: bool,
):
if not isinstance(dst_sharding, jax.sharding.Sharding):
raise ValueError("`sharding` must contain only `Sharding` instances.")
return jax.device_put(arr, dst_sharding, donate=donate_input)

default_put_array = put_array_device_put0
def scale_by_good_slices(self, x: int | float) -> int | float:
"""Scale x by the number of good slices."""
if isinstance(x, int):
ret, remainder = divmod(x * self.good_slice_count, self.total_slice_count)
if remainder:
raise ValueError(
f"Cannot scale {x=} by good slices because it will result in a "
f"remainder of {remainder=}."
)
return ret
elif isinstance(x, float):
return x * self.good_slice_count / self.total_slice_count
else:
raise ValueError(f"Unsupported type: {type(x)}")
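For illustration (numbers not from the diff): with 3 of 4 slices healthy, scale_by_good_slices(16) computes divmod(16 * 3, 4) = (12, 0) and returns 12, while scale_by_good_slices(10) raises ValueError because divmod(10 * 3, 4) leaves a remainder of 2; a float such as 0.1 scales directly to 0.1 * 3 / 4 = 0.075.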

def put_array_device_put1(
self,
@@ -502,20 +502,7 @@ def put_array_device_put3(
arr.shape, dst_sharding, arrays
)

def scale_by_good_slices(self, x: int | float) -> int | float:
"""Scale x by the number of good slices."""
if isinstance(x, int):
ret, remainder = divmod(x * self.good_slice_count, self.total_slice_count)
if remainder:
raise ValueError(
f"Cannot scale {x=} by good slices because it will result in a "
f"remainder of {remainder=}."
)
return ret
elif isinstance(x, float):
return x * self.good_slice_count / self.total_slice_count
else:
raise ValueError(f"Unsupported type: {type(x)}")
default_put_array = put_array_device_put1


@contextlib.contextmanager
@@ -552,7 +552,7 @@ def handler():
logger.info(f"Error print traceback for {thread.ident=}")
pass
finally:
# logger.fatal("Timeout from timebomb!")
# logger.critical("Timeout from timebomb!")
# os.abort()
pass
