Skip to content

Commit

Permalink
Added a max reshard retry count
Browse files Browse the repository at this point in the history
  • Loading branch information
lukebaumann committed Feb 21, 2025
1 parent d7b7eb3 commit 506da52
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 11 deletions.
33 changes: 24 additions & 9 deletions MaxText/elasticutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def __init__(
total_slice_count: int,
save_period: Optional[int] = None,
reshard_check_period: Optional[int] = None,
max_failures: Optional[int] = None,
max_failure_count: Optional[int] = None,
max_reshard_retry_count: Optional[int] = None,
):
self.devices = devices
self.total_slice_count = total_slice_count
Expand All @@ -96,25 +97,40 @@ def __init__(
reshard_check_period = 1
self.reshard_check_period = reshard_check_period

if max_failures is None:
max_failures = float("inf")
self.max_failures = max_failures
if max_failure_count is None:
max_failure_count = float("inf")
self.max_failure_count = max_failure_count

if max_reshard_retry_count is None:
max_reshard_retry_count = float("inf")
self.max_reshard_retry_count = max_reshard_retry_count

self.failure_count = 0
self.reshard_retry_count = 0
self.good_slice_indices = self.get_slice_availability()
self.data = {}

def slice_down(self):
def slice_down(self, reshard_retry: bool = False):
"""Slice down."""
logger.info("Slice down")
self.good_slice_indices = self.get_slice_availability()
self.failure_count += 1
if reshard_retry:
self.reshard_retry_count += 1
else:
self.reshard_retry_count = 0

logger.info(f"{self.failure_count=} {self.max_failure_count=}")
if self.failure_count >= self.max_failure_count:
logger.fatal(f"Max failure count reached {self.max_failure_count}")

logger.info(
f"Failure count: {self.failure_count} with max {self.max_failures}"
f"{self.reshard_retry_count=} {self.max_reshard_retry_count=}"
)
if self.failure_count >= self.max_failures:
logger.fatal(f"Max failures reached {self.max_failures}")
if self.reshard_retry_count > self.max_reshard_retry_count:
logger.fatal(
f"Max reshard retry count reached {self.max_reshard_retry_count}"
)

@timeit
def save(self, save_step: int, blocking: bool = True, **kwargs):
Expand Down Expand Up @@ -486,7 +502,6 @@ def put_array_device_put3(
arr.shape, dst_sharding, arrays
)


def scale_by_good_slices(self, x: int | float) -> int | float:
"""Scale x by the number of good slices."""
if isinstance(x, int):
Expand Down
4 changes: 3 additions & 1 deletion MaxText/elasticutils_fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def __init__(
total_slice_count: int,
save_period: Optional[int] = None,
reshard_check_period: Optional[int] = None,
max_failures: Optional[int] = None,
max_failure_count: Optional[int] = None,
max_reshard_retry_count: Optional[int] = None,
):
self.fake_good_slice_indices = set(d.slice_index for d in devices)

Expand All @@ -38,6 +39,7 @@ def __init__(
save_period,
reshard_check_period,
max_failures,
max_reshard_retry_count,
)

def update_good_slice_indices(self, good_slice_indices: set[int]):
Expand Down
2 changes: 1 addition & 1 deletion MaxText/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,7 +863,7 @@ def reshard(arr):
max_logging.log("Unknown JaxRuntimeError during resharding!")
raise

config.eu.slice_down()
config.eu.slice_down(reshard_retry=True)

return (
restore_step,
Expand Down

0 comments on commit 506da52

Please sign in to comment.