Skip to content

Commit

Permalink
fix filtered formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-hh committed Oct 8, 2024
1 parent 3b65d99 commit d906b9f
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,7 +1243,6 @@ def _iter(self):
num_examples_to_skip = self._state_dict["num_examples_since_previous_state"]
else:
num_examples_to_skip = 0
iterator = iter(self.ex_iterable)

formatter = get_formatter(self.formatting.format_type) if self.formatting else None
if self.formatting and self.ex_iterable.iter_arrow:
Expand Down Expand Up @@ -1276,9 +1275,13 @@ def _iter(self):
function_args.append([current_idx + i for i in range(batch_len)])
mask = self.function(*function_args, **self.fn_kwargs)
# yield one example at a time from the batch
example_keys = combined_key.split("_")
examples = _batch_to_examples(batch)
for key, example, to_keep in zip(example_keys, examples, mask):
# TODO: nicer way to handle keys?
if not self.formatting:
keys = combined_key.split("_")
else:
keys = [combined_key] * len(mask)
for key, example, to_keep in zip(keys, examples, mask):
current_idx += 1
if self._state_dict:
self._state_dict["num_examples_since_previous_state"] += 1
Expand All @@ -1292,7 +1295,7 @@ def _iter(self):
self._state_dict["num_examples_since_previous_state"] = 0
self._state_dict["previous_state_example_idx"] = current_idx
else:
for key, example in iterator:
for key, example in batched_examples_iterator:
# If not batched, we can apply the filtering function direcly
example = dict(example)
inputs = example
Expand Down

0 comments on commit d906b9f

Please sign in to comment.