
Commit

fixed some edge case bugs
clearhanhui committed Jul 27, 2024
1 parent 7dceb80 commit a14bd43
Showing 10 changed files with 138 additions and 81 deletions.
2 changes: 1 addition & 1 deletion benchmarks/benchmark_asyncload.py
@@ -8,7 +8,7 @@
from parquet_loader import ParquetDataset, ParquetDataLoader
from parquet_loader.reader import AsyncParquetReader, SyncParquetReader

path = 'synthetic_data'
path = '../synthetic_data'
delay_in_seconds = 0.01


2 changes: 1 addition & 1 deletion benchmarks/benckmark_streamload.py
@@ -10,7 +10,7 @@
from memory_profiler import profile as mem_profile

## config
path = 'synthetic_data'
path = '../synthetic_data'
num_workers = 4
batch_size = 66

11 changes: 7 additions & 4 deletions parquet_loader/dataloader.py
@@ -25,18 +25,19 @@ def squeeze_first_dim(data):
else:
return data


class _SingleProcessDataLoaderIter(TorchSingleProcessDataLoaderIter):
def _next_data(self):
data = super()._next_data()
return squeeze_first_dim(data)



class _MultiProcessingDataLoaderIter(TorchMultiProcessingDataLoaderIter):
def _next_data(self):
data = super()._next_data()
return squeeze_first_dim(data)



class ParquetDataLoader(DataLoader):
def __init__(
self,
@@ -60,7 +61,7 @@ def __init__(
):
dataset.set_shuffle(shuffle)
dataset.set_drop_last(drop_last)
self._num_batches = dataset.get_num_batches(batch_size)
dataset.set_batch_size(batch_size)

# reset arguments
shuffle = False
@@ -91,7 +92,9 @@ def __init__(


def __len__(self) -> int:
return self._num_batches
if not hasattr(self, 'num_batches'):
self.num_batches = self.dataset.get_num_batches()
return self.num_batches

def _get_iterator(self) -> '_BaseDataLoaderIter':
if self.num_workers == 0:
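For context, a minimal usage sketch of the reworked loader (the path and keyword arguments are assumptions based on the benchmarks above, not part of this commit). len(loader) is now computed lazily from the dataset instead of being cached in __init__:

from parquet_loader import ParquetDataset, ParquetDataLoader

dataset = ParquetDataset('synthetic_data')   # hypothetical path
loader = ParquetDataLoader(dataset, batch_size=66, shuffle=True, drop_last=False)

# __len__ defers to dataset.get_num_batches(), which uses the batch size
# registered via dataset.set_batch_size(batch_size) in the constructor.
num_batches = len(loader)
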
95 changes: 59 additions & 36 deletions parquet_loader/dataset.py
@@ -1,3 +1,4 @@
import logging
from typing import List, Optional

import pandas as pd
@@ -15,38 +16,30 @@
from .reader import SyncParquetReader, AsyncParquetReader


logger = logging.getLogger(__name__)


class ParquetDataset(IterableDataset):
def __init__(
self,
path: str,
column_names: Optional[List[str]] = None,
async_read: bool = False,
max_preload: int = 1
):
_ds = ds.dataset(path)
self.metas = [
ParquetMetadata(
file_path=f.path,
num_rows=f.metadata.num_rows,
num_row_groups=f.metadata.num_row_groups,
num_rows_per_row_group=[
f.metadata.row_group(i).num_rows
for i in range(f.metadata.num_row_groups)
]
)
for f in _ds.get_fragments()
]
self.column_names = column_names or _ds.schema.names
del _ds
self.world_size, self.global_rank, self.num_nodes = detect_distributed_env()
):
self.path = path
self.column_names = column_names
self.async_read = async_read
self.max_preload = max_preload
Reader = AsyncParquetReader if async_read and max_preload > 0 else SyncParquetReader
self.reader = Reader(self.max_preload)
self.reader = AsyncParquetReader(self.max_preload) \
if async_read and max_preload > 0 else \
SyncParquetReader(self.max_preload)
self.world_size, self.global_rank, self.num_nodes = detect_distributed_env()
self.batch_size = 1


def __len__(self):
if not getattr(self, 'num_rows', None):
if not hasattr(self, 'num_rows'):
self._associate_to_workers()
return self.num_rows

@@ -56,18 +49,42 @@ def set_shuffle(self, shuffle: bool) -> None:

def set_drop_last(self, drop_last: bool) -> None:
self.drop_last = drop_last


def get_num_batches(self, batch_size=1):
assert hasattr(self, 'drop_last'), 'call `set_drop_last` before call `get_num_batches`'
def set_batch_size(self, batch_size: int) -> None:
self.batch_size = batch_size

def get_num_batches(self, batch_size=None):
assert hasattr(self, 'drop_last'), 'call `set_drop_last` before call `get_num_batches`'
batch_size = batch_size or self.batch_size
if self.drop_last:
return len(self) // batch_size
else:
return (len(self) + batch_size - 1) // batch_size
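# Worked example (illustration only, not in the diff): with len(self) == 1000
# rows and batch_size == 66, drop_last=True gives 1000 // 66 = 15 full batches,
# while drop_last=False gives (1000 + 66 - 1) // 66 = 16, the last one partial.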


def _try_fetch_metadata(self):
if hasattr(self, 'metas'):
return

_ds = ds.dataset(self.path)
self.metas = [
ParquetMetadata(
file_path=f.path,
num_rows=f.metadata.num_rows,
num_row_groups=f.metadata.num_row_groups,
num_rows_per_row_group=[
f.metadata.row_group(i).num_rows
for i in range(f.metadata.num_row_groups)
]
)
for f in _ds.get_fragments()
]
self.column_names = self.column_names or _ds.schema.names
del _ds


def _associate_to_workers(self):
self._try_fetch_metadata()
self.num_workers, self.worker_rank = detect_worker_env()
self.intervals, self.num_rows = self.shuffler.associate_to_workers(
metas=self.metas,
@@ -88,16 +105,19 @@ def __iter__(self):


def __getitem__(self, index: int) -> pd.DataFrame:
logger.warning("call `__getitem__` is inefficient, only for test usage.")
self._try_fetch_metadata()
global_index = 0
for itv in self.intervals:
offset = itv.local_row_end - itv.local_row_start
global_index += offset
if global_index > index:
f = pq.ParquetFile(self.metas[itv.file_index])
table = f.read_row_group(itv.row_group_index).select(self.column_names)
f.close()
return table.slice(index - (global_index - offset), 1)\
.to_pandas(split_blocks=True, self_destruct=True).to_numpy()
for itvs in self.intervals:
for itv in itvs:
offset = itv.local_row_end - itv.local_row_start
global_index += offset
if global_index > index:
f = pq.ParquetFile(self.metas[itv.file_index])
table = f.read_row_group(itv.row_group_index).select(self.column_names)
f.close()
return table.slice(index - (global_index - offset), 1)\
.to_pandas(split_blocks=True, self_destruct=True).to_numpy()


def iter_batch(self):
@@ -123,8 +143,11 @@ def iter_batch(self):
tables = [table_left] # reset tables

# last batch, it may not be full batch
table_left = pa.concat_tables(tables)
batch_data = table_left.slice(0, self.batch_size)\
.to_pandas(split_blocks=True, self_destruct=True).to_numpy()
yield batch_data
if len(tables) > 0:
table_left = pa.concat_tables(tables)
if table_left.shape[0] == 0:
return
batch_data = table_left.slice(0, self.batch_size)\
.to_pandas(split_blocks=True, self_destruct=True).to_numpy()
yield batch_data
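
The fixes in __getitem__ and iter_batch both assume the same nested interval layout: an outer list per parquet file and an inner list of row-group intervals, replacing the old Dict[int, List[RowGroupInterval]]. A rough standalone sketch of the index walk (plain tuples stand in for RowGroupInterval; the data is made up):

# (file_index, row_group_index, local_row_start, local_row_end)
intervals = [
    [(0, 0, 0, 100), (0, 1, 0, 50)],   # file 0: two row groups
    [(1, 0, 0, 80)],                   # file 1: one row group
]

def locate(index):
    """Map a global row index to (file_index, row_group_index, row_in_group)."""
    global_index = 0
    for itvs in intervals:                  # outer loop over files
        for fi, rg, start, end in itvs:     # inner loop over row groups
            offset = end - start
            global_index += offset
            if global_index > index:
                return fi, rg, index - (global_index - offset)

print(locate(120))   # -> (0, 1, 20)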

17 changes: 11 additions & 6 deletions parquet_loader/reader.py
@@ -16,7 +16,7 @@ class Reader:
def __init__(self, max_preload: int = 1):
self.max_preload = max_preload

def setup(self, metas: List[ParquetMetadata], intervals: Dict[int, List[RowGroupInterval]]):
def setup(self, metas: List[ParquetMetadata], intervals: List[List[RowGroupInterval]]):
self.metas = metas
self.intervals = intervals

@@ -39,8 +39,8 @@ class SyncParquetReader(Reader):

@property
def table_iterator(self):
for fi, itvs in self.intervals.items():
with self._open_parquet_file(self.metas[fi].file_path) as pf:
for itvs in self.intervals:
with self._open_parquet_file(self.metas[itvs[0].file_index].file_path) as pf:
for itv in itvs:
offset = itv.local_row_end - itv.local_row_start
yield pf.read_row_group(itv.row_group_index).slice(itv.local_row_start, offset)
@@ -52,10 +52,15 @@ class AsyncParquetReader(Reader):
_END_TOKEN = "_END"
_DEFAULT_TIMEOUT = 1

def _preload(self, metas, intervals, queue):
def _preload(
self,
metas: List[ParquetMetadata],
intervals: List[List[RowGroupInterval]],
queue: Queue
):
try:
for fi, itvs in intervals.items():
with self._open_parquet_file(metas[fi].file_path) as pf:
for itvs in intervals:
with self._open_parquet_file(metas[itvs[0].file_index].file_path) as pf:
for itv in itvs:
offset = itv.local_row_end - itv.local_row_start
table = pf.read_row_group(itv.row_group_index).slice(itv.local_row_start, offset)
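The async reader's _preload (the rest of which is collapsed above) follows a standard producer/consumer shape: a background worker pushes tables into a bounded queue and signals completion with an end token. A generic sketch using only the standard library (names and sizes here are illustrative, not the reader's actual attributes):

import queue
import threading

_END_TOKEN = "_END"

def producer(items, q):
    for item in items:        # in the reader, items are row-group tables
        q.put(item)           # blocks when the queue is full (bounded preload)
    q.put(_END_TOKEN)         # tell the consumer no more items will arrive

q = queue.Queue(maxsize=1)    # maxsize plays the role of max_preload
threading.Thread(target=producer, args=(range(5), q), daemon=True).start()

while True:
    item = q.get(timeout=1)   # the reader also polls with a timeout
    if item == _END_TOKEN:
        break
    print(item)
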
8 changes: 4 additions & 4 deletions parquet_loader/shuffle.py
@@ -22,7 +22,7 @@ def associate_to_workers(
current_worker_rank: int = 0,
drop_last: bool = False,
batch_size: int = 1,
) -> Tuple[Dict[int, List[RowGroupInterval]], int]:
) -> Tuple[List[List[RowGroupInterval]], int]:
raise NotImplementedError

def shuffle(self, data: np.ndarray) -> np.ndarray:
@@ -41,7 +41,7 @@ def associate_to_workers(
current_worker_rank: int = 0,
drop_last: bool = False,
batch_size: int = 1,
) -> Tuple[Dict[int, List[RowGroupInterval]], int]:
) -> Tuple[List[List[RowGroupInterval]], int]:

return associate_to_workers(
metas=metas,
@@ -68,7 +68,7 @@ def associate_to_workers(
current_worker_rank: int = 0,
drop_last: bool = False,
batch_size: int = 1,
) -> Tuple[Dict[int, List[RowGroupInterval]], int]:
) -> Tuple[List[List[RowGroupInterval]], int]:

# shuffle files
metas = copy.deepcopy(metas)
@@ -85,7 +85,7 @@
)

# shuffle row groups
for fi, itvs in intervals.items():
for itvs in intervals:
self.rng.shuffle(itvs)

return intervals, num_rows
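Conceptually, the full shuffle randomises file order (done on the deep-copied metas before association) and then row-group order within each file (the loop above, which now iterates the nested list directly). A minimal sketch of the same two-level idea with made-up data (the real code shuffles ParquetMetadata objects and RowGroupInterval lists):

import numpy as np

rng = np.random.default_rng(42)

intervals = [['f0-rg0', 'f0-rg1', 'f0-rg2'], ['f1-rg0', 'f1-rg1']]
rng.shuffle(intervals)        # shuffle file order
for itvs in intervals:        # then shuffle row groups within each file
    rng.shuffle(itvs)
print(intervals)
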
71 changes: 47 additions & 24 deletions parquet_loader/utils.py
@@ -1,4 +1,5 @@
import os
import copy
from typing import List, Tuple, Dict

import torch
@@ -34,27 +35,32 @@ def associate_to_workers(
current_worker_rank: int = 0,
drop_last: bool = False,
batch_size: int = 1,
) -> Tuple[Dict[int, List[RowGroupInterval]], int]:
) -> Tuple[List[List[RowGroupInterval]], int]:

total_num_rows = sum([meta.num_rows for meta in metas])
rank0_extra_rows = total_num_rows % world_size # rank 0 may get extra rows
num_rows_per_ranks = [
total_num_rows // world_size + total_num_rows % world_size # TODO: check its correctness
if rank == world_size - 1 and not drop_last
total_num_rows // world_size + rank0_extra_rows
if rank == 0 and not drop_last
else total_num_rows // world_size
for rank in range(world_size)
]
ratio = num_workers * batch_size
if drop_last:
ratio = num_workers * batch_size
num_rows_per_ranks = [ratio * int(item // ratio) for item in num_rows_per_ranks]
num_rows_per_workers = np.array([
[
num_rows_per_ranks[rank] // num_workers + num_rows_per_ranks[rank] % num_workers
if worker_rank == 0 and not drop_last # worker 0 gets the remainder
else num_rows_per_ranks[rank] // num_workers
for worker_rank in range(num_workers)
]
for rank in range(world_size)
])
num_rows_per_ranks = [ratio * int(rows // ratio) for rows in num_rows_per_ranks]

num_rows_per_rank_worker = []
for rank in range(world_size):
rank_extra_rows = num_rows_per_ranks[rank] % ratio
num_rows_per_worker = []
for worker_rank in range(num_workers):
worker_num_rows = num_rows_per_ranks[rank] // ratio * batch_size
if rank_extra_rows > 0:
worker_num_rows += min(rank_extra_rows, batch_size)
rank_extra_rows -= batch_size
num_rows_per_worker.append(worker_num_rows)
num_rows_per_rank_worker.append(num_rows_per_worker)
num_rows_per_rank_worker = np.array(num_rows_per_rank_worker)

row_group_intervals = []
global_rows = 0
@@ -68,11 +74,12 @@
row_group_intervals.append(intervals_per_file)


start_row_index_per_workers = [0] + np.cumsum(num_rows_per_workers).tolist()
current_global_row_start = start_row_index_per_workers[current_rank * num_workers + current_worker_rank]
current_global_row_end = start_row_index_per_workers[current_rank * num_workers + current_worker_rank + 1]
current_intervals = {}
start_row_index_per_rank_worker = [0] + np.cumsum(num_rows_per_rank_worker).tolist()
current_global_row_start = start_row_index_per_rank_worker[current_rank * num_workers + current_worker_rank]
current_global_row_end = start_row_index_per_rank_worker[current_rank * num_workers + current_worker_rank + 1]
current_intervals = []
for intervals_per_file in row_group_intervals:
current_file_itvs = []
for itv in intervals_per_file:
if itv.global_row_end <= current_global_row_start:
continue
@@ -88,13 +95,29 @@
current_global_row_end < itv.global_row_end:
itv.local_row_end = current_global_row_end - itv.global_row_start
itv.global_row_end = current_global_row_end

if itv.file_index not in current_intervals:
current_intervals[itv.file_index] = [itv]
else:
current_intervals[itv.file_index].append(itv)


current_file_itvs.append(itv)
if len(current_file_itvs) > 0:
current_intervals.append(current_file_itvs)

if not drop_last and rank0_extra_rows > 0:
for itvs in current_intervals:
current_file_itvs = []
for itv in itvs:
offset = itv.local_row_end - itv.local_row_start
rank0_extra_rows -= offset
itv = copy.deepcopy(itv)
if rank0_extra_rows > 0:
current_file_itvs.append(itv)
else:
itv.local_row_end += rank0_extra_rows
itv.global_row_end += rank0_extra_rows
current_file_itvs.append(itv)
break
current_intervals.append(current_file_itvs)
if rank0_extra_rows <= 0:
break

return current_intervals, current_global_row_end-current_global_row_start


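To make the new bookkeeping concrete, here is a standalone arithmetic sketch (values picked for illustration) of how rows are first split across ranks and then across one rank's workers in batch_size chunks, mirroring the drop_last=False path above:

total_num_rows, world_size, num_workers, batch_size = 1000, 3, 2, 10

rank0_extra_rows = total_num_rows % world_size              # 1 extra row -> rank 0
num_rows_per_ranks = [
    total_num_rows // world_size + (rank0_extra_rows if rank == 0 else 0)
    for rank in range(world_size)
]                                                           # [334, 333, 333]

ratio = num_workers * batch_size                            # 20
rank = 0
rank_extra_rows = num_rows_per_ranks[rank] % ratio          # 334 % 20 = 14
num_rows_per_worker = []
for worker_rank in range(num_workers):
    rows = num_rows_per_ranks[rank] // ratio * batch_size   # 16 * 10 = 160
    if rank_extra_rows > 0:
        rows += min(rank_extra_rows, batch_size)            # +10 for worker 0, +4 for worker 1
        rank_extra_rows -= batch_size
    num_rows_per_worker.append(rows)
print(num_rows_per_worker)                                  # [170, 164], summing to 334
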
5 changes: 3 additions & 2 deletions tests/test_associate.py
@@ -19,5 +19,6 @@
batch_size=11,
)
print(num_rows)
for idx, itv in intervals.items():
print(itv)
for itvs in intervals:
for itv in itvs:
print(itv)
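
Finally, a hypothetical end-to-end check (not part of this commit; the directory, file name, and keyword arguments are made up) that exercises the fixed batch accounting: with 100 rows, batch_size=11 and drop_last=False, the loader should report and yield 10 batches, the last one partial.

import os
import pandas as pd
from parquet_loader import ParquetDataset, ParquetDataLoader

os.makedirs('tiny_data', exist_ok=True)
pd.DataFrame({'x': range(100)}).to_parquet('tiny_data/part0.parquet')

dataset = ParquetDataset('tiny_data')
loader = ParquetDataLoader(dataset, batch_size=11, drop_last=False)

batches = list(loader)
assert len(loader) == 10          # (100 + 11 - 1) // 11
assert len(batches) == len(loader)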