Skip to content

Commit

Permalink
feat: support non streamable arrow file binary format (#7025)
Browse files Browse the repository at this point in the history
* feat: support non streamable arrow file binary format

Signed-off-by: Mehant Kammakomati <[email protected]>

* use generator

Co-authored-by: Quentin Lhoest <[email protected]>

* feat: add unit test to load data in both arrow formats

Signed-off-by: Mehant Kammakomati <[email protected]>

---------

Signed-off-by: Mehant Kammakomati <[email protected]>
Co-authored-by: Quentin Lhoest <[email protected]>
  • Loading branch information
kmehant and lhoestq authored Jul 31, 2024
1 parent 65b9499 commit ce4a0c5
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 3 deletions.
13 changes: 11 additions & 2 deletions src/datasets/packaged_modules/arrow/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ def _split_generators(self, dl_manager):
if self.info.features is None:
for file in itertools.chain.from_iterable(files):
with open(file, "rb") as f:
self.info.features = datasets.Features.from_arrow_schema(pa.ipc.open_stream(f).schema)
try:
reader = pa.ipc.open_stream(f)
except pa.lib.ArrowInvalid:
reader = pa.ipc.open_file(f)
self.info.features = datasets.Features.from_arrow_schema(reader.schema)
break
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
return splits
Expand All @@ -59,7 +63,12 @@ def _generate_tables(self, files):
for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
with open(file, "rb") as f:
try:
for batch_idx, record_batch in enumerate(pa.ipc.open_stream(f)):
try:
batches = pa.ipc.open_stream(f)
except pa.lib.ArrowInvalid:
reader = pa.ipc.open_file(f)
batches = (reader.get_batch(i) for i in range(reader.num_record_batches))
for batch_idx, record_batch in enumerate(batches):
pa_table = pa.Table.from_batches([record_batch])
# Uncomment for debugging (will print the Arrow table size and elements)
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
Expand Down
47 changes: 46 additions & 1 deletion tests/packaged_modules/test_arrow.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,53 @@
import pyarrow as pa
import pytest

from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesList
from datasets.packaged_modules.arrow.arrow import ArrowConfig
from datasets.packaged_modules.arrow.arrow import Arrow, ArrowConfig


@pytest.fixture
def arrow_file_streaming_format(tmp_path):
filename = tmp_path / "stream.arrow"
testdata = [[1, 1, 1], [0, 100, 6], [1, 90, 900]]

schema = pa.schema([pa.field("input_ids", pa.list_(pa.int32()))])
array = pa.array(testdata, type=pa.list_(pa.int32()))
table = pa.Table.from_arrays([array], schema=schema)
with open(filename, "wb") as f:
with pa.ipc.new_stream(f, schema) as writer:
writer.write_table(table)
return str(filename)


@pytest.fixture
def arrow_file_file_format(tmp_path):
filename = tmp_path / "file.arrow"
testdata = [[1, 1, 1], [0, 100, 6], [1, 90, 900]]

schema = pa.schema([pa.field("input_ids", pa.list_(pa.int32()))])
array = pa.array(testdata, type=pa.list_(pa.int32()))
table = pa.Table.from_arrays([array], schema=schema)
with open(filename, "wb") as f:
with pa.ipc.new_file(f, schema) as writer:
writer.write_table(table)
return str(filename)


@pytest.mark.parametrize(
"file_fixture, config_kwargs",
[
("arrow_file_streaming_format", {}),
("arrow_file_file_format", {}),
],
)
def test_arrow_generate_tables(file_fixture, config_kwargs, request):
arrow = Arrow(**config_kwargs)
generator = arrow._generate_tables([[request.getfixturevalue(file_fixture)]])
pa_table = pa.concat_tables([table for _, table in generator])

expected = {"input_ids": [[1, 1, 1], [0, 100, 6], [1, 90, 900]]}
assert pa_table.to_pydict() == expected


def test_config_raises_when_invalid_name() -> None:
Expand Down

0 comments on commit ce4a0c5

Please sign in to comment.