Skip to content

Commit

Permalink
Add option to allow newlines in captions
Browse files Browse the repository at this point in the history
The YFCC-15M descriptions can have new lines in the caption, which
causes parquet's csv module to error by default. This commit allows
passing --newlines-in-captions True to img2dataset, which will tell
parquet to allow newlines in CSV values.
  • Loading branch information
achalddave committed Mar 8, 2023
1 parent fc3fb2e commit 0e15d4a
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 3 deletions.
2 changes: 2 additions & 0 deletions img2dataset/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def download(
max_shard_retry: int = 1,
user_agent_token: Optional[str] = None,
disallowed_header_directives: Optional[List[str]] = None,
newlines_in_captions: bool = False,
):
"""Download is the main entry point of img2dataset, it uses multiple processes and download multiple files"""
if disallowed_header_directives is None:
Expand Down Expand Up @@ -183,6 +184,7 @@ def signal_handler(signal_arg, frame): # pylint: disable=unused-argument
number_sample_per_shard,
done_shards,
tmp_path,
newlines_in_captions,
)

if output_format == "webdataset":
Expand Down
17 changes: 14 additions & 3 deletions img2dataset/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
number_sample_per_shard,
done_shards,
tmp_path,
newlines_in_captions,
) -> None:
self.input_format = input_format
self.url_col = url_col
Expand All @@ -47,6 +48,7 @@ def __init__(
self.save_additional_columns = save_additional_columns
self.number_sample_per_shard = number_sample_per_shard
self.done_shards = done_shards
self.newlines_in_captions = newlines_in_captions

fs, url_path = fsspec.core.url_to_fs(url_list)
self.fs = fs
Expand Down Expand Up @@ -79,13 +81,22 @@ def _save_to_arrow(self, input_file, start_shard_id):
if self.input_format in ["txt", "json", "csv", "tsv"]:
with self.fs.open(input_file, mode="rb") as file:
if self.input_format == "txt":
df = csv_pq.read_csv(file, read_options=csv_pq.ReadOptions(column_names=["url"]))
df = csv_pq.read_csv(
file,
read_options=csv_pq.ReadOptions(column_names=["url"]),
parse_options=csv_pq.ParseOptions(newlines_in_values=self.newlines_in_captions),
)
elif self.input_format == "json":
df = pa.Table.from_pandas(pd.read_json(file))
elif self.input_format == "csv":
df = csv_pq.read_csv(file)
df = csv_pq.read_csv(
file, parse_options=csv_pq.ParseOptions(newlines_in_values=self.newlines_in_captions)
)
elif self.input_format == "tsv":
df = csv_pq.read_csv(file, parse_options=csv_pq.ParseOptions(delimiter="\t"))
df = csv_pq.read_csv(
file,
parse_options=csv_pq.ParseOptions(delimiter="\t", newlines_in_values=self.newlines_in_captions),
)
else:
raise ValueError(f"Unknown input format {self.input_format}")
elif self.input_format == "tsv.gz":
Expand Down
1 change: 1 addition & 0 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def test_reader(input_format, tmp_path):
number_sample_per_shard=batch_size,
done_shards=done_shards,
tmp_path=test_folder,
newlines_in_captions=False,
)

if input_format == "txt":
Expand Down

0 comments on commit 0e15d4a

Please sign in to comment.