From 5f2cd887328762288c1554f22228a5f04c908b8e Mon Sep 17 00:00:00 2001 From: Fabio Bonassi Date: Mon, 2 Dec 2024 23:32:39 +0100 Subject: [PATCH] - Fixed docs broken links. - Do not update self.state.times[Events.COMPLETED.name] if terminated - Fixed unit test --- ignite/engine/engine.py | 30 ++++++++++----------- ignite/engine/events.py | 2 +- tests/ignite/contrib/engines/test_common.py | 2 +- tests/ignite/engine/test_engine.py | 11 ++++---- 4 files changed, 22 insertions(+), 23 deletions(-) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index 2e689b389ee..14e88d48015 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -4,7 +4,7 @@ import time import warnings import weakref -from collections import OrderedDict, defaultdict +from collections import defaultdict, OrderedDict from collections.abc import Mapping from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union @@ -549,8 +549,8 @@ def terminate(self, skip_completed: bool = False) -> None: - :attr:`~ignite.engine.events.Events.COMPLETED` Args: - skip_completed: if True, the event `~ignite.engine.events.Events.COMPLETED` is not fired after - `~ignite.engine.events.Events.TERMINATE`. Default is False. + skip_completed: if True, the event :attr:`~ignite.engine.events.Events.COMPLETED` is not fired after + :attr:`~ignite.engine.events.Events.TERMINATE`. Default is False. Examples: .. testcode:: @@ -905,8 +905,6 @@ def switch_batch(engine): # If engine was terminated and now is resuming from terminated state # we need to initialize iter_counter as 0 self._init_iter = 0 - # And we reset the skip_completed_after_termination to its default value - self.skip_completed_after_termination = False if self._dataloader_iter is None: self.state.dataloader = data @@ -1002,17 +1000,19 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]: time_taken = time.time() - start_time # time is available for handlers but must be updated after fire self.state.times[Events.COMPLETED.name] = time_taken - handlers_start_time = time.time() - # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed_after_termination=True` - if self.should_terminate and not self.skip_completed_after_termination or not self.should_terminate: + if self.should_terminate and self.skip_completed_after_termination: + # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed_after_termination=True` + hours, mins, secs = _to_hours_mins_secs(time_taken) + self.logger.info(f"Engine run terminated. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}") + else: + handlers_start_time = time.time() self._fire_event(Events.COMPLETED) - - time_taken += time.time() - handlers_start_time - # update time wrt handlers - self.state.times[Events.COMPLETED.name] = time_taken - hours, mins, secs = _to_hours_mins_secs(time_taken) - self.logger.info(f"Engine run complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}") + time_taken += time.time() - handlers_start_time + # update time wrt handlers + self.state.times[Events.COMPLETED.name] = time_taken + hours, mins, secs = _to_hours_mins_secs(time_taken) + self.logger.info(f"Engine run completed. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}") except BaseException as e: self._dataloader_iter = None @@ -1192,7 +1192,7 @@ def _internal_run_legacy(self) -> State: # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed_after_termination=True` if self.should_terminate and not self.skip_completed_after_termination or not self.should_terminate: self._fire_event(Events.COMPLETED) - + time_taken += time.time() - handlers_start_time # update time wrt handlers self.state.times[Events.COMPLETED.name] = time_taken diff --git a/ignite/engine/events.py b/ignite/engine/events.py index 1f1ab6e6726..af5ebbbabe1 100644 --- a/ignite/engine/events.py +++ b/ignite/engine/events.py @@ -266,7 +266,7 @@ class Events(EventEnum): - EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called. - COMPLETED : triggered when engine's run is completed or terminated with :meth:`~ignite.engine.engine.Engine.terminate()`, - unless the flag `skip_event_completed` is set to True. + unless the flag `skip_completed` is set to True. The table below illustrates which events are triggered when various termination methods are called. diff --git a/tests/ignite/contrib/engines/test_common.py b/tests/ignite/contrib/engines/test_common.py index c93a081c754..e14042e62c1 100644 --- a/tests/ignite/contrib/engines/test_common.py +++ b/tests/ignite/contrib/engines/test_common.py @@ -1,6 +1,6 @@ import os import sys -from unittest.mock import MagicMock, call +from unittest.mock import call, MagicMock import pytest import torch diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py index 351e78d41ce..43957bf45d4 100644 --- a/tests/ignite/engine/test_engine.py +++ b/tests/ignite/engine/test_engine.py @@ -1,6 +1,6 @@ import os import time -from unittest.mock import MagicMock, Mock, call +from unittest.mock import call, MagicMock, Mock import numpy as np import pytest @@ -49,7 +49,7 @@ def test_terminate(self): def test_terminate_and_not_complete(self): engine = Engine(lambda e, b: 1) assert not engine.should_terminate and not engine.skip_completed_after_termination - engine.terminate(skip_event_completed=True) + engine.terminate(skip_completed=True) assert engine.should_terminate and engine.skip_completed_after_termination def test_invalid_process_raises_with_invalid_signature(self): @@ -260,14 +260,14 @@ def check_iter_and_data(): (Events.ITERATION_COMPLETED(once=14), None, 14, False), ], ) - def test_terminate_events_sequence(self, terminate_event, e, i, skip_event_completed): + def test_terminate_events_sequence(self, terminate_event, e, i, skip_completed): engine = RecordedEngine(MagicMock(return_value=1)) data = range(10) max_epochs = 5 @engine.on(terminate_event) def call_terminate(): - engine.terminate(skip_event_completed) + engine.terminate(skip_completed) @engine.on(Events.EXCEPTION_RAISED) def assert_no_exceptions(ee): @@ -284,7 +284,7 @@ def assert_no_exceptions(ee): if e is None: e = i // len(data) + 1 - if skip_event_completed: + if skip_completed: assert engine.called_events[-1] == (e, i, Events.TERMINATE) assert engine.called_events[-2] == (e, i, terminate_event) else: @@ -1425,4 +1425,3 @@ def check_iter_epoch(): state = engine.run(data, max_epochs=max_epochs) assert state.iteration == max_epochs * len(data) and state.epoch == max_epochs assert num_calls_check_iter_epoch == 1 -