Skip to content

Commit

Permalink
- Fixed docs broken links.
Browse files Browse the repository at this point in the history
- Do not update self.state.times[Events.COMPLETED.name]  if terminated
- Fixed unit test
  • Loading branch information
bonassifabio committed Dec 2, 2024
1 parent f1d194a commit 5f2cd88
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 23 deletions.
30 changes: 15 additions & 15 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time
import warnings
import weakref
from collections import OrderedDict, defaultdict
from collections import defaultdict, OrderedDict
from collections.abc import Mapping
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union

Expand Down Expand Up @@ -549,8 +549,8 @@ def terminate(self, skip_completed: bool = False) -> None:
- :attr:`~ignite.engine.events.Events.COMPLETED`
Args:
skip_completed: if True, the event `~ignite.engine.events.Events.COMPLETED` is not fired after
`~ignite.engine.events.Events.TERMINATE`. Default is False.
skip_completed: if True, the event :attr:`~ignite.engine.events.Events.COMPLETED` is not fired after
:attr:`~ignite.engine.events.Events.TERMINATE`. Default is False.
Examples:
.. testcode::
Expand Down Expand Up @@ -905,8 +905,6 @@ def switch_batch(engine):
# If engine was terminated and now is resuming from terminated state
# we need to initialize iter_counter as 0
self._init_iter = 0
# And we reset the skip_completed_after_termination to its default value
self.skip_completed_after_termination = False

if self._dataloader_iter is None:
self.state.dataloader = data
Expand Down Expand Up @@ -1002,17 +1000,19 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]:
time_taken = time.time() - start_time
# time is available for handlers but must be updated after fire
self.state.times[Events.COMPLETED.name] = time_taken
handlers_start_time = time.time()

# do not fire Events.COMPLETED if we terminated the run with flag `skip_completed_after_termination=True`
if self.should_terminate and not self.skip_completed_after_termination or not self.should_terminate:
if self.should_terminate and self.skip_completed_after_termination:
# do not fire Events.COMPLETED if we terminated the run with flag `skip_completed_after_termination=True`
hours, mins, secs = _to_hours_mins_secs(time_taken)
self.logger.info(f"Engine run terminated. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")
else:
handlers_start_time = time.time()
self._fire_event(Events.COMPLETED)

time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.COMPLETED.name] = time_taken
hours, mins, secs = _to_hours_mins_secs(time_taken)
self.logger.info(f"Engine run complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")
time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.COMPLETED.name] = time_taken
hours, mins, secs = _to_hours_mins_secs(time_taken)
self.logger.info(f"Engine run completed. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")

except BaseException as e:
self._dataloader_iter = None
Expand Down Expand Up @@ -1192,7 +1192,7 @@ def _internal_run_legacy(self) -> State:
# do not fire Events.COMPLETED if we terminated the run with flag `skip_completed_after_termination=True`
if self.should_terminate and not self.skip_completed_after_termination or not self.should_terminate:
self._fire_event(Events.COMPLETED)

time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.COMPLETED.name] = time_taken
Expand Down
2 changes: 1 addition & 1 deletion ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ class Events(EventEnum):
- EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even
when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called.
- COMPLETED : triggered when engine's run is completed or terminated with :meth:`~ignite.engine.engine.Engine.terminate()`,
unless the flag `skip_event_completed` is set to True.
unless the flag `skip_completed` is set to True.
The table below illustrates which events are triggered when various termination methods are called.
Expand Down
2 changes: 1 addition & 1 deletion tests/ignite/contrib/engines/test_common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import sys
from unittest.mock import MagicMock, call
from unittest.mock import call, MagicMock

import pytest
import torch
Expand Down
11 changes: 5 additions & 6 deletions tests/ignite/engine/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import time
from unittest.mock import MagicMock, Mock, call
from unittest.mock import call, MagicMock, Mock

import numpy as np
import pytest
Expand Down Expand Up @@ -49,7 +49,7 @@ def test_terminate(self):
def test_terminate_and_not_complete(self):
engine = Engine(lambda e, b: 1)
assert not engine.should_terminate and not engine.skip_completed_after_termination
engine.terminate(skip_event_completed=True)
engine.terminate(skip_completed=True)
assert engine.should_terminate and engine.skip_completed_after_termination

def test_invalid_process_raises_with_invalid_signature(self):
Expand Down Expand Up @@ -260,14 +260,14 @@ def check_iter_and_data():
(Events.ITERATION_COMPLETED(once=14), None, 14, False),
],
)
def test_terminate_events_sequence(self, terminate_event, e, i, skip_event_completed):
def test_terminate_events_sequence(self, terminate_event, e, i, skip_completed):
engine = RecordedEngine(MagicMock(return_value=1))
data = range(10)
max_epochs = 5

@engine.on(terminate_event)
def call_terminate():
engine.terminate(skip_event_completed)
engine.terminate(skip_completed)

@engine.on(Events.EXCEPTION_RAISED)
def assert_no_exceptions(ee):
Expand All @@ -284,7 +284,7 @@ def assert_no_exceptions(ee):
if e is None:
e = i // len(data) + 1

if skip_event_completed:
if skip_completed:
assert engine.called_events[-1] == (e, i, Events.TERMINATE)
assert engine.called_events[-2] == (e, i, terminate_event)
else:
Expand Down Expand Up @@ -1425,4 +1425,3 @@ def check_iter_epoch():
state = engine.run(data, max_epochs=max_epochs)
assert state.iteration == max_epochs * len(data) and state.epoch == max_epochs
assert num_calls_check_iter_epoch == 1

0 comments on commit 5f2cd88

Please sign in to comment.