Skip to content

Commit

Permalink
fix wrong tracing location of fetch data (#238)
Browse files Browse the repository at this point in the history
  • Loading branch information
rayandrew authored Oct 30, 2024
1 parent cda4df3 commit cc5abbc
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 5 deletions.
4 changes: 3 additions & 1 deletion dlio_benchmark/data_loader/native_dali_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def next(self):
pipeline.reset()
for step in range(num_samples // batch_size):
try:
for batch in self._dataset:
# TODO: @hariharan-devarajan: change below line when we bump the dftracer version to
# `dlp.iter(self._dataset, name=self.next.__qualname__)`
for batch in dlp.iter(self._dataset):
logging.debug(f"{utcnow()} Creating {len(batch)} batches by {self._args.my_rank} rank ")
yield batch
except StopIteration:
Expand Down
4 changes: 3 additions & 1 deletion dlio_benchmark/data_loader/tf_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ def read(self):
@dlp.log
def next(self):
super().next()
for batch in self._dataset:
# TODO: @hariharan-devarajan: change below line when we bump the dftracer version to
# `dlp.iter(self._dataset, name=self.next.__qualname__)`
for batch in dlp.iter(self._dataset):
yield batch
self.epoch_number += 1
dlp.update(epoch=self.epoch_number)
Expand Down
4 changes: 3 additions & 1 deletion dlio_benchmark/data_loader/torch_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ def next(self):
total = self._args.training_steps if self.dataset_type is DatasetType.TRAIN else self._args.eval_steps
logging.debug(f"{utcnow()} Rank {self._args.my_rank} should read {total} batches")
step = 1
for batch in self._dataset:
# TODO: @hariharan-devarajan: change below line when we bump the dftracer version to
# `dlp.iter(self._dataset, name=self.next.__qualname__)`
for batch in dlp.iter(self._dataset):
dlp.update(step = step)
step += 1
yield batch
Expand Down
4 changes: 2 additions & 2 deletions dlio_benchmark/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def _eval(self, epoch):
total = math.floor(self.num_samples * self.num_files_eval / self.batch_size_eval / self.comm_size)
loader = self.framework.get_loader(DatasetType.VALID)
t0 = time()
for batch in dlp.iter(loader.next()):
for batch in loader.next():
self.stats.eval_batch_loaded(epoch, step, t0)
eval_time = 0.0
if self.eval_time > 0:
Expand Down Expand Up @@ -256,7 +256,7 @@ def _train(self, epoch):

loader = self.framework.get_loader(dataset_type=DatasetType.TRAIN)
t0 = time()
for batch in dlp.iter(loader.next()):
for batch in loader.next():
if overall_step > max_steps or ((self.total_training_steps > 0) and (overall_step > self.total_training_steps)):
if self.args.my_rank == 0:
logging.info(f"{utcnow()} Maximum number of steps reached")
Expand Down

0 comments on commit cc5abbc

Please sign in to comment.