Skip to content

Commit

Permalink
Merge branch 'main' into ethantang-db/stream_clients_refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
snarayan21 authored Nov 1, 2024
2 parents 3a59ab8 + fd9b55a commit 9bb255a
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 22 deletions.
2 changes: 1 addition & 1 deletion streaming/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def __init__(self,
]
self._shm_prefix_int, self._locals_shm = get_shm_prefix(streams_local, streams_remote,
self._unique_rank_world)
self._filelock_root = os.path.join(gettempdir(), 'streaming')
self._filelock_root = gettempdir()
os.makedirs(self._filelock_root, exist_ok=True)

# Create the shared memory-backed barrier, without its lock, which is unpickleable.
Expand Down
71 changes: 53 additions & 18 deletions streaming/base/shared/prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
prevent shared resources like shared memory from colliding.
"""

import os
from collections import Counter
from tempfile import gettempdir
from time import sleep
from typing import Iterator, Union

import numpy as np
from torch import distributed as dist

from streaming.base.constant import LOCALS, TICK
from streaming.base.constant import BARRIER_FILELOCK, CACHE_FILELOCK, LOCALS, SHM_TO_CLEAN, TICK
from streaming.base.shared import SharedMemory
from streaming.base.world import World

Expand Down Expand Up @@ -91,7 +93,8 @@ def _check_self(streams_local: list[str]) -> None:
f'Reused local directory: {duplicate_local_dirs}. Provide a different one.')


def _check_and_find(streams_local: list[str], streams_remote: list[Union[str, None]]) -> int:
def _check_and_find(streams_local: list[str], streams_remote: list[Union[str, None]],
shm_name: str) -> int:
"""Find the next available prefix while checking existing local dirs for overlap.
Local leader walks the existing shm prefixes starting from zero, verifying that there is no
Expand All @@ -101,18 +104,40 @@ def _check_and_find(streams_local: list[str], streams_remote: list[Union[str, No
Args:
streams_local (List[str]): Our local working directories.
streams_remote (List[Union[str, None]]): Our remote working directories.
shm_name (str): The shared memory file name, e.g., LOCALS, BARRIER etc.
Returns:
int: Next available prefix int.
"""
prefix_int = 0

for prefix_int in _each_prefix_int():
name = _get_path(prefix_int, LOCALS)

name = _get_path(prefix_int, shm_name)

# Check if any shared memory filelocks exist for the current prefix
try:
filelock_exists = any(
os.path.exists(os.path.join(gettempdir(), _get_path(prefix_int, filelock_name)))
for filelock_name in [BARRIER_FILELOCK, CACHE_FILELOCK])
if filelock_exists:
continue
except PermissionError:
continue

# Attempt to access shared memory by name. Use prefix_int if files do not exist
try:
shm = SharedMemory(name, False)
except PermissionError:
continue
except FileNotFoundError:
break

if shm_name != LOCALS:
continue

their_locals, _ = _unpack_locals(bytes(shm.buf))

# Do not check for a conflicting local directories across existing shared memory if
# remote directories are None. Get the next prefix.
if any(streams_remote):
Expand All @@ -135,7 +160,7 @@ def _check_and_find(streams_local: list[str], streams_remote: list[Union[str, No


def _check_and_find_retrying(streams_local: list[str], streams_remote: list[Union[str, None]],
retry: int) -> int:
shm_name: str, retry: int) -> int:
"""Find the next available prefix while checking existing dirs for overlap.
If an overlap is found, sleeps for a tick and then tries again, up to "retry" times. We allow
Expand All @@ -145,6 +170,7 @@ def _check_and_find_retrying(streams_local: list[str], streams_remote: list[Unio
Args:
streams_local (List[str]): Our local working directories.
streams_remote (List[Union[str, None]]): Our remote working directories.
shm_name (str): The shared memory file name, e.g., LOCALS, BARRIER etc.
retry (int): Number of retries upon failure before raising an exception.
Returns:
Expand All @@ -155,7 +181,7 @@ def _check_and_find_retrying(streams_local: list[str], streams_remote: list[Unio
errs = []
for _ in range(1 + retry):
try:
return _check_and_find(streams_local, streams_remote)
return _check_and_find(streams_local, streams_remote, shm_name)
except ValueError as err:
errs.append(err)
sleep(TICK)
Expand Down Expand Up @@ -184,9 +210,16 @@ def get_shm_prefix(streams_local: list[str],
# Check my locals for overlap.
_check_self(streams_local)

prefix_int = max([
_check_and_find_retrying(streams_local, streams_remote, shm_name=shm_name, retry=retry)
for shm_name in SHM_TO_CLEAN
])

if dist.is_available() and dist.is_initialized():
dist.barrier()

# First, the local leader registers the first available shm prefix, recording its locals.
if world.is_local_leader:
prefix_int = _check_and_find_retrying(streams_local, streams_remote, retry)
name = _get_path(prefix_int, LOCALS)
data = _pack_locals(streams_local, prefix_int)
shm = SharedMemory(name, True, len(data))
Expand All @@ -197,16 +230,18 @@ def get_shm_prefix(streams_local: list[str],

# Non-local leaders go next, searching for match.
if not world.is_local_leader:
for prefix_int in _each_prefix_int():
name = _get_path(prefix_int, LOCALS)
try:
shm = SharedMemory(name, False)
except FileNotFoundError:
raise RuntimeError(f'Internal error: shared memory prefix was not registered by ' +
f'local leader. This may be because you specified ' +
f'different ``local`` parameters from different ranks.')
their_locals, their_prefix_int = _unpack_locals(bytes(shm.buf))
if streams_local == their_locals and prefix_int == their_prefix_int:
break

name = _get_path(prefix_int, LOCALS)
try:
shm = SharedMemory(name, False)
except FileNotFoundError:
raise RuntimeError(f'Internal error: shared memory prefix={prefix_int} was not ' +
f'registered by local leader. This may be because you specified ' +
f'different ``local`` parameters from different ranks.')

their_locals, their_prefix_int = _unpack_locals(bytes(shm.buf))
if streams_local != their_locals or prefix_int != their_prefix_int:
raise RuntimeError(f'Internal error: shared memory registered does not match ' +
f'local leader as streams_local or prefix_int not match. ' +
f'local leader: {their_locals} and {their_prefix_int}. ' +
f'expected: {streams_local} and {prefix_int}.')
return prefix_int, shm # pyright: ignore
9 changes: 6 additions & 3 deletions streaming/base/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,12 @@ def clean_stale_shared_memory() -> None:
try:
shm = BuiltinSharedMemory(name, True, 4)
except FileExistsError:
shm = BuiltinSharedMemory(name, False, 4)
leaked_shm = True
finally:
try:
shm = BuiltinSharedMemory(name, False, 4)
leaked_shm = True
except PermissionError:
continue
if shm:
shm.close() # pyright: ignore
shm.unlink()
# Come out of loop if no leaked shared memory
Expand Down
33 changes: 33 additions & 0 deletions tests/test_shared.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

import os
import tempfile
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from streaming.base import StreamingDataset
from streaming.base.constant import LOCALS
from streaming.base.shared import SharedArray, get_shm_prefix
from streaming.base.shared.memory import SharedMemory
from streaming.base.shared.prefix import _check_and_find
from streaming.base.util import clean_stale_shared_memory
from streaming.base.world import World
from tests.common.utils import convert_to_mds

Expand Down Expand Up @@ -157,3 +163,30 @@ def test_shared_array_size_is_integer(mock_shared_memory: MagicMock, dtype: type
mock_shared_memory.assert_called_once() # pyright: ignore
size_arg = mock_shared_memory.call_args[1]['size']
assert isinstance(size_arg, int), 'Size passed to SharedMemory is not an integer'


def test_check_and_find_skips_filelock_conflict():
"""Test _check_and_find skips prefix due to file lock conflict."""
clean_stale_shared_memory()

with patch('os.path.exists') as mock_exists, \
patch('multiprocessing.shared_memory.SharedMemory', side_effect=FileNotFoundError):
# Simulate that `/000000.barrier_filelock` exists, indicating a lock conflict
bf_path = os.path.join(tempfile.gettempdir(), '000000_barrier_filelock')
mock_exists.side_effect = lambda path: path == bf_path

# Expect _check_and_find to return 1 as the next available prefix
next_prefix = _check_and_find(['local_dir'], [None], LOCALS)
assert next_prefix == 1


@patch.object(SharedMemory,
'__init__',
side_effect=[
PermissionError('Mocked permission error'),
FileNotFoundError('Mocked file not found error')
])
def test_shared_memory_permission_error(mock_shared_memory_class: MagicMock):
with patch('os.path.exists', return_value=False):
next_prefix = _check_and_find(['local'], [None], LOCALS)
assert next_prefix == 1

0 comments on commit 9bb255a

Please sign in to comment.