Internal change. #647

Open · wants to merge 1 commit into main
4 changes: 3 additions & 1 deletion seqio/test_utils.py
@@ -1294,6 +1294,7 @@ def clear_mixtures():
     "FakeLazyTfds",
     [
         "name",
+        "tfds_splits",
         "resolved_tfds_name",
         "data_dir",
         "load",
@@ -1411,7 +1412,8 @@ def _load_shard(shard_instruction, shuffle_files, seed):
 
   fake_tfds = FakeLazyTfds(
       name="fake:0.0.0",
-      resolved_tfds_name="fake:0.0.0",
+      tfds_splits=None,
+      resolved_tfds_name=lambda: "fake:0.0.0",
       data_dir="/tfds",
       load=get_fake_dataset,
       load_shard=_load_shard,
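A note on why the fake changes shape: in seqio/utils.py below, resolved_tfds_name turns from a property into a method and a tfds_splits property is added, so this namedtuple-based test double must hold a callable for the former and a plain value for the latter. A minimal sketch of the pattern, with the field list abridged and all values hypothetical:

import collections

# Namedtuple test double: "method" slots simply hold callables.
FakeLazyTfds = collections.namedtuple(
    "FakeLazyTfds", ["name", "tfds_splits", "resolved_tfds_name", "data_dir"]
)

fake = FakeLazyTfds(
    name="fake:0.0.0",
    tfds_splits=None,                         # mirrors the new property
    resolved_tfds_name=lambda: "fake:0.0.0",  # property became a method
    data_dir="/tfds",
)

assert fake.resolved_tfds_name() == "fake:0.0.0"  # now called, not read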
98 changes: 53 additions & 45 deletions seqio/utils.py
@@ -102,7 +102,7 @@ class TfdsSplit:
 
   dataset: str
   split: Optional[str]
-  data_dir: Optional[tfds.typing.PathLike] = None
+  data_dir: Optional[str] = None
 
   def __post_init__(self):
     _validate_tfds_name(self.dataset)
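For context on the one-line type change above: a TfdsSplit now carries a plain string data_dir instead of a tfds.typing.PathLike. A hedged construction sketch, with a made-up dataset name and path (the dataset string must still satisfy _validate_tfds_name):

from seqio import utils

split_spec = utils.TfdsSplit(
    dataset="my_dataset:1.0.0",  # hypothetical TFDS name
    split="train",
    data_dir="/path/to/tfds",    # plain str now, not tfds.typing.PathLike
)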
@@ -120,13 +120,8 @@ class LazyTfdsLoader(object):
   def __init__(
       self,
       name: Optional[str] = None,
-      data_dir=None,
-      split_map: Optional[
-          Union[
-              Mapping[str, str],
-              Mapping[str, "TfdsSplit"],
-          ]
-      ] = None,
+      data_dir: Optional[str] = None,
+      split_map: Union[Mapping[str, str], Mapping[str, TfdsSplit], None] = None,
       decoders=None,
   ):
     """LazyTfdsLoader constructor.
@@ -164,16 +159,31 @@ def name(self) -> Optional[str]:
     return self._name
 
   @property
-  def resolved_tfds_name(self) -> Optional[str]:
+  def tfds_splits(self) -> Optional[Mapping[str, TfdsSplit]]:
+    return self._split_map if self._is_custom_split_map else None
+
+  def resolved_tfds_name(self, split: Optional[str] = None) -> Optional[str]:
     """Returns the resolved TFDS dataset name.
 
     When the specified TFDS name doesn't specify everything, e.g. the version
     has a wildcard or the config is not specified, then this function returns
     the complete TFDS name if the dataset has already been loaded.
+
+    Args:
+      split: optional split name.
+
+    Returns:
+      complete TFDS name.
     """
-    if self.is_memoized:
-      return self.builder.get_reference().tfds_name(include_version=True)
-    return self.name
+    if self.is_memoized(split):
+      return (
+          self._get_builder(split)
+          .get_reference()
+          .tfds_name(include_version=True)
+      )
+    else:
+      dataset, _ = self.get_split_params(split)
+      return dataset
 
   def __str__(self):
     return (
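Since resolved_tfds_name is now a method that can resolve per split, existing call sites need parentheses, and loaders built from a custom split map should pass the split explicitly. A before/after sketch with a hypothetical loader:

from seqio import utils

loader = utils.LazyTfdsLoader(name="my_dataset:1.0.0")  # hypothetical name
# Before this change: resolved = loader.resolved_tfds_name
resolved = loader.resolved_tfds_name()           # single-dataset loader
# resolved = loader.resolved_tfds_name("train")  # custom split map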
@@ -189,9 +199,28 @@ def __repr__(self):
         f" decoders={self._decoders})"
     )
 
+  def get_split_params(
+      self, split: Optional[str] = None
+  ) -> Tuple[Optional[str], Optional[str]]:
+    """Returns a tuple of (dataset, data_dir) for the given split."""
+    if self._is_custom_split_map:
+      if mapped_split := self._split_map.get(split):
+        dataset = mapped_split.dataset
+        data_dir = mapped_split.data_dir
+      else:
+        raise ValueError(
+            "`LazyTfdsLoader` refers to multiple datasets, pass `split` to"
+            " `get_split_params()`."
+        )
+    else:
+      dataset = self.name
+      data_dir = self.data_dir
+
+    return dataset, data_dir
+
   @property
-  def data_dir(self):
+  def data_dir(self) -> Optional[str]:
+    """Returns the data directory for this TFDS dataset."""
 
     if self._is_custom_split_map:
@@ -225,38 +254,22 @@ def _get_builder_key(
   ) -> Tuple[Optional[str], Optional[str]]:
     return (dataset, data_dir)
 
-  @property
-  def is_memoized(self) -> bool:
-    if self._is_custom_split_map:
-      return all(
-          self._get_builder_key(tfds_split.dataset, tfds_split.data_dir)
-          in LazyTfdsLoader._MEMOIZED_BUILDERS
-          for tfds_split in self._split_map.values()
-      )
-    else:
-      return (
-          self._get_builder_key(self.name, self.data_dir)
-          in LazyTfdsLoader._MEMOIZED_BUILDERS
-      )
+  def is_memoized(self, split: Optional[str] = None) -> bool:
+    """Returns true if the dataset is memoized."""
+    dataset, data_dir = self.get_split_params(split)
+
+    return (
+        self._get_builder_key(dataset, data_dir)
+        in LazyTfdsLoader._MEMOIZED_BUILDERS
+    )
 
   @property
   def builder(self):
     return self._get_builder()
 
   def _get_builder(self, split: Optional[str] = None):
     """Returns the DatasetBuilder for this TFDS dataset."""
-    if self._is_custom_split_map:
-      if mapped_split := self._split_map.get(split):
-        dataset = mapped_split.dataset
-        data_dir = mapped_split.data_dir
-      else:
-        raise ValueError(
-            "`LazyTfdsLoader` refers to multiple datasets, pass `split` to"
-            " `_get_builder()`."
-        )
-    else:
-      dataset = self.name
-      data_dir = self.data_dir
+    dataset, data_dir = self.get_split_params(split)
     builder_key = self._get_builder_key(dataset, data_dir)
     if builder_key not in LazyTfdsLoader._MEMOIZED_BUILDERS:
       if dataset:
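is_memoized gets the same property-to-method conversion, and both it and _get_builder now delegate the (dataset, data_dir) lookup to get_split_params. A hedged sketch of the new contract (hypothetical dataset name; building requires the dataset to actually exist):

from seqio import utils

loader = utils.LazyTfdsLoader(name="my_dataset:1.0.0")
assert not loader.is_memoized()  # a call now, not an attribute read
_ = loader.builder               # building memoizes the DatasetBuilder
assert loader.is_memoized()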
@@ -311,12 +324,7 @@ def load(
   ):
     """Returns a tf.data.Dataset for the given split."""
     dataset_split = self._map_split(split)
-    if self._is_custom_split_map:
-      name = self._split_map[split].dataset
-      data_dir = self._split_map[split].data_dir
-    else:
-      name = self.name
-      data_dir = self.data_dir
+    dataset, data_dir = self.get_split_params(split)
     read_config = self.read_config
     read_config.input_context = (
         tf.distribute.InputContext(  # pylint: disable=g-long-ternary
@@ -329,7 +337,7 @@ def load(
     read_config.shuffle_seed = seed
     read_config.skip_prefetch = True
     return tfds.load(
-        name,
+        dataset,
         split=dataset_split,
         data_dir=data_dir,
         shuffle_files=shuffle_files,
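With load() routed through get_split_params(), a multi-dataset loader picks the right (dataset, data_dir) pair for each split before delegating to tfds.load. An end-to-end sketch under hypothetical names and paths:

from seqio import utils

loader = utils.LazyTfdsLoader(
    split_map={
        "train": utils.TfdsSplit(
            dataset="ds_a:1.0.0", split="train", data_dir="/data/a"
        ),
        "test": utils.TfdsSplit(
            dataset="ds_b:1.0.0", split="test", data_dir="/data/b"
        ),
    }
)

dataset, data_dir = loader.get_split_params("train")  # ("ds_a:1.0.0", "/data/a")
ds = loader.load("train", shuffle_files=False)        # delegates to tfds.load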
4 changes: 2 additions & 2 deletions seqio/utils_test.py
@@ -75,14 +75,14 @@ def test_no_tfds_version(self):
   @mock.patch("tensorflow_datasets.builder")
   def test_wildcard_in_version(self, mock_tfds_builder):
     loader = utils.LazyTfdsLoader(name="a/b:1.0.*")
-    self.assertEqual("a/b:1.0.*", loader.resolved_tfds_name)
+    self.assertEqual("a/b:1.0.*", loader.resolved_tfds_name())
     mock_tfds_builder.assert_not_called()
     # Get the builder to make sure it's memoized
     _ = loader.builder
     mock_reference = mock.MagicMock()
     mock_tfds_builder.return_value.get_reference.return_value = mock_reference
     mock_reference.tfds_name.return_value = "a/b:1.0.2"
-    self.assertEqual("a/b:1.0.2", loader.resolved_tfds_name)
+    self.assertEqual("a/b:1.0.2", loader.resolved_tfds_name())
 
   @mock.patch("tensorflow_datasets.builder")
   def test_builder_memoization(self, mock_tfds_builder):