-
Notifications
You must be signed in to change notification settings - Fork 326
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Updated elasticutils to how it will be structured in pathwaysutils
- Loading branch information
1 parent
e1ca0d2
commit 0c975ce
Showing
5 changed files
with
198 additions
and
142 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# Copyright 2025 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Resharding API for elastic training.""" | ||
|
||
from typing import Any | ||
from typing import Callable, Optional, Sequence | ||
import jax | ||
|
||
|
||
def default_put_array( | ||
arr: jax.Array, | ||
dst_sharding: jax.sharding.Sharding, | ||
donate_input: bool, | ||
): | ||
if not isinstance(dst_sharding, jax.sharding.Sharding): | ||
raise ValueError("`sharding` must contain only `Sharding` instances.") | ||
return jax.device_put(arr, dst_sharding, donate=donate_input) | ||
|
||
|
||
def reshard( | ||
x: Any, | ||
sharding: jax.sharding.Sharding | Any, | ||
*, | ||
donate_input: bool = True, | ||
put_array: Optional[ | ||
Callable[[jax.Array, Sequence[jax.sharding.Sharding], bool], jax.Array] | ||
] = None, | ||
) -> Any: | ||
"""Reshards `x` to the specified `sharding`. | ||
Args: | ||
x: An array, scalar, or a nested Python container thereof. | ||
sharding: A `Sharding` or a nested `Sharding` in a Python container (must | ||
match the structure of `x`), specifying the target sharding. | ||
donate_input: If `True`, donates the input arrays to reduce memory needed | ||
for resharding. Donated buffers should not be reused. | ||
put_array: A function that takes an array, a sharding, and a boolean | ||
indicating whether to donate the input, and returns a copy of the array | ||
with the specified sharding. | ||
Returns: | ||
A copy of `x` with the specified `sharding`. | ||
""" | ||
if put_array is None: | ||
put_array = default_put_array | ||
|
||
flat_x, tree_def = jax.tree_util.tree_flatten(x) | ||
flat_sharding = jax.api_util.flatten_axes( | ||
"reshard sharding", tree_def, sharding | ||
) | ||
|
||
if len(flat_x) != len(flat_sharding): | ||
raise ValueError("Mismatched length between `x` and `sharding`.") | ||
|
||
arrays = [ | ||
put_array(arr, dst_sharding, donate_input) | ||
for arr, dst_sharding in zip(flat_x, flat_sharding) | ||
] | ||
return jax.tree_util.tree_unflatten(tree_def, arrays) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Copyright 2025 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Utilities for elastic training.""" | ||
|
||
import logging | ||
from typing import Any, Optional, Sequence | ||
|
||
import jax | ||
from pathwaysutils.google_internal.elastic import utils | ||
|
||
|
||
PyTree = Any | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
logger.setLevel(logging.INFO) | ||
|
||
# pylint: disable=logging-fstring-interpolation | ||
|
||
|
||
class ElasticUtilsSimulator(utils.ElasticUtils): | ||
"""Utility class for elastic training. | ||
This class will simulate slices going down and coming back up. | ||
""" | ||
simulated_good_slice_indices: set[int] | ||
|
||
def __init__( | ||
self, | ||
devices: Sequence[jax.Device], | ||
total_slice_count: int, | ||
save_period: Optional[int] = None, | ||
reshard_check_period: Optional[int] = None, | ||
max_failures: Optional[int] = None, | ||
): | ||
self.simulated_good_slice_indices = set(d.slice_index for d in devices) | ||
|
||
super().__init__( | ||
devices, | ||
total_slice_count, | ||
save_period, | ||
reshard_check_period, | ||
max_failures, | ||
) | ||
|
||
def update_good_slice_indices(self, good_slice_indices: set[int]): | ||
"""Start step handler.""" | ||
self.simulated_good_slice_indices = good_slice_indices | ||
logger.info(f"Updated: {self.simulated_good_slice_indices=}") | ||
|
||
@utils.timeit | ||
def get_slice_availability(self) -> set[int]: | ||
"""Returns the set of good and bad slices.""" | ||
good_slice_indices = self.simulated_good_slice_indices | ||
|
||
logger.info(f"{good_slice_indices=}") | ||
|
||
return good_slice_indices | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.