Skip to content

Commit

Permalink
Review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Emilio Castillo committed Nov 17, 2023
1 parent 13c1746 commit 57774e7
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pytorch_pfn_extras/profiler/_time_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __init__(
self._summary = DictSummary()
self._additional_stats: Dict[str, float] = {}

self._cpu_worker = _util._QueueWorker(
self._cpu_worker = _util.QueueWorker(
self._add_from_worker, max_queue_size
)
self._cuda_worker: Optional[_CUDAWorker] = None
Expand Down
13 changes: 11 additions & 2 deletions pytorch_pfn_extras/profiler/_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import time
from typing import Any, Dict, Generator, List, Optional, Type, Union, cast

import torch
import torch.cuda
import torch.utils.data
from pytorch_pfn_extras.profiler import _util
from pytorch_pfn_extras.writing import Writer

Expand All @@ -27,6 +28,9 @@ def flush(self, filename: str, writer: Writer) -> None:
def enable(self, enable_flag: bool) -> None:
raise NotImplementedError("Tracers must implement enable")

def finalize(self) -> None:
raise NotImplementedError("Tracers must implement finalize")


class DummyTracer(Tracer):
@contextlib.contextmanager
Expand Down Expand Up @@ -68,7 +72,7 @@ def __init__(
# Detect if i am a forked process, in such case I send the event to
# The parent process
self._is_cuda_available = torch.cuda.is_available()
self._tracer_queue: _util._QueueWorker = _util._QueueWorker(
self._tracer_queue: _util.QueueWorker = _util.QueueWorker(
self.add_remote_event, 1000
)
self._tracer_queue.initialize()
Expand Down Expand Up @@ -123,6 +127,7 @@ def add_remote_event(
def flush(self, filename: str, writer: Writer) -> None:
if not self._enable:
return
self._tracer_queue.synchronize()
# TODO(ecastill): try to work on some append mode manipulating the
# file pointer and with json.dumps?
savefun = ChromeTracingSaveFunc()
Expand Down Expand Up @@ -151,9 +156,13 @@ def load_state_dict(self, to_load: Dict[str, Any]) -> None:
self._event_count = to_load["_event_count"]

def clear(self) -> None:
self._tracer_queue.synchronize()
self._event_list = []
self._event_count = 0

def finalize(self) -> None:
self._tracer_queue.synchronize()


_tracer: Optional[Tracer] = None
_main_pid = os.getpid()
Expand Down
2 changes: 1 addition & 1 deletion pytorch_pfn_extras/profiler/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Callable, Optional


class _QueueWorker:
class QueueWorker:
def __init__(
self,
add: Callable[[str, Any], None],
Expand Down
10 changes: 7 additions & 3 deletions pytorch_pfn_extras/training/extensions/timeline_trace.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, Optional

from pytorch_pfn_extras.profiler._tracing import get_tracer
from pytorch_pfn_extras.profiler._tracing import get_tracer, Tracer
from pytorch_pfn_extras.training import extension
from pytorch_pfn_extras.training import trigger as trigger_module
from pytorch_pfn_extras.training._manager_protocol import (
Expand Down Expand Up @@ -55,9 +55,10 @@ def __init__(
filename: Optional[str] = None,
enable: Optional[trigger_module.TriggerLike] = None,
disable: Optional[trigger_module.TriggerLike] = None,
tracer: Optional[Tracer] = None,
**kwargs: Any,
):
self._tracer = kwargs.get("tracer", get_tracer())
self._tracer = tracer if tracer is not None else get_tracer()
self._enable = None
if enable is not None:
self._enable = trigger_module.get_trigger(enable)
Expand All @@ -72,6 +73,8 @@ def __init__(
self._writer = kwargs.get("writer", None)

def _flush_trace(self, manager: ExtensionsManagerProtocol) -> None:
# TODO(kaku) It would be nice to be able to select a mode to
# synchronize the tracer and then flush it as a strict flush_trace()
writer = manager.writer if self._writer is None else self._writer

# write to the log file
Expand Down Expand Up @@ -103,4 +106,5 @@ def finalize(self, manager: ExtensionsManagerProtocol) -> None:
self._flush_trace(manager)
if self._writer is not None:
self._writer.finalize()
self._Tracer.clear()
self._tracer.clear()
self._tracer.finalize()

0 comments on commit 57774e7

Please sign in to comment.