diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index 27a949cacca..e2a14898607 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -140,6 +140,7 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]): self._process_function = process_function self.last_event_name: Optional[Events] = None self.should_terminate = False + self.skip_completed_after_termination = False self.should_terminate_single_epoch = False self.should_interrupt = False self.state = State() @@ -538,7 +539,7 @@ def call_interrupt(): self.logger.info("interrupt signaled. Engine will interrupt the run after current iteration is finished.") self.should_interrupt = True - def terminate(self) -> None: + def terminate(self, skip_completed: bool = False) -> None: """Sends terminate signal to the engine, so that it terminates completely the run. The run is terminated after the event on which ``terminate`` method was called. The following events are triggered: @@ -547,6 +548,9 @@ def terminate(self) -> None: - :attr:`~ignite.engine.events.Events.TERMINATE` - :attr:`~ignite.engine.events.Events.COMPLETED` + Args: + skip_completed: if True, the event :attr:`~ignite.engine.events.Events.COMPLETED` is not fired after + :attr:`~ignite.engine.events.Events.TERMINATE`. Default is False. Examples: .. testcode:: @@ -617,9 +621,12 @@ def terminate(): .. versionchanged:: 0.4.10 Behaviour changed, for details see https://github.com/pytorch/ignite/issues/2669 + .. versionchanged:: 0.5.2 + Added `skip_completed` flag """ self.logger.info("Terminate signaled. Engine will stop after current iteration is finished.") self.should_terminate = True + self.skip_completed_after_termination = skip_completed def terminate_epoch(self) -> None: """Sends terminate signal to the engine, so that it terminates the current epoch. The run @@ -993,13 +1000,17 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]: time_taken = time.time() - start_time # time is available for handlers but must be updated after fire self.state.times[Events.COMPLETED.name] = time_taken - handlers_start_time = time.time() - self._fire_event(Events.COMPLETED) - time_taken += time.time() - handlers_start_time - # update time wrt handlers - self.state.times[Events.COMPLETED.name] = time_taken + + # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True` + if not (self.should_terminate and self.skip_completed_after_termination): + handlers_start_time = time.time() + self._fire_event(Events.COMPLETED) + time_taken += time.time() - handlers_start_time + # update time wrt handlers + self.state.times[Events.COMPLETED.name] = time_taken + hours, mins, secs = _to_hours_mins_secs(time_taken) - self.logger.info(f"Engine run complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}") + self.logger.info(f"Engine run finished. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}") except BaseException as e: self._dataloader_iter = None @@ -1174,13 +1185,17 @@ def _internal_run_legacy(self) -> State: time_taken = time.time() - start_time # time is available for handlers but must be updated after fire self.state.times[Events.COMPLETED.name] = time_taken - handlers_start_time = time.time() - self._fire_event(Events.COMPLETED) - time_taken += time.time() - handlers_start_time - # update time wrt handlers - self.state.times[Events.COMPLETED.name] = time_taken + + # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True` + if not (self.should_terminate and self.skip_completed_after_termination): + handlers_start_time = time.time() + self._fire_event(Events.COMPLETED) + time_taken += time.time() - handlers_start_time + # update time wrt handlers + self.state.times[Events.COMPLETED.name] = time_taken + hours, mins, secs = _to_hours_mins_secs(time_taken) - self.logger.info(f"Engine run complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}") + self.logger.info(f"Engine run finished. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}") except BaseException as e: self._dataloader_iter = None diff --git a/ignite/engine/events.py b/ignite/engine/events.py index 9dd99348492..87622d3415c 100644 --- a/ignite/engine/events.py +++ b/ignite/engine/events.py @@ -259,36 +259,47 @@ class Events(EventEnum): - TERMINATE_SINGLE_EPOCH : triggered when the run is about to end the current epoch, after receiving a :meth:`~ignite.engine.engine.Engine.terminate_epoch()` or :meth:`~ignite.engine.engine.Engine.terminate()` call. + - EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even + when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called. - TERMINATE : triggered when the run is about to end completely, after receiving :meth:`~ignite.engine.engine.Engine.terminate()` call. - - EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even - when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called. - - COMPLETED : triggered when engine's run is completed + - COMPLETED : triggered when engine's run is completed or terminated with + :meth:`~ignite.engine.engine.Engine.terminate()`, unless the flag + `skip_completed` is set to True. The table below illustrates which events are triggered when various termination methods are called. .. list-table:: - :widths: 24 25 33 18 + :widths: 35 38 28 20 20 :header-rows: 1 * - Method - - EVENT_COMPLETED - TERMINATE_SINGLE_EPOCH + - EPOCH_COMPLETED - TERMINATE + - COMPLETED * - no termination - - ✔ - ✗ + - ✔ - ✗ + - ✔ * - :meth:`~ignite.engine.engine.Engine.terminate_epoch()` - ✔ - ✔ - ✗ + - ✔ * - :meth:`~ignite.engine.engine.Engine.terminate()` - ✗ - ✔ - ✔ + - ✔ + * - :meth:`~ignite.engine.engine.Engine.terminate()` with `skip_completed=True` + - ✗ + - ✔ + - ✔ + - ✗ Since v0.3.0, Events become more flexible and allow to pass an event filter to the Engine: @@ -357,7 +368,7 @@ class CustomEvents(EventEnum): STARTED = "started" """triggered when engine's run is started.""" COMPLETED = "completed" - """triggered when engine's run is completed""" + """triggered when engine's run is completed, or after receiving terminate() call.""" ITERATION_STARTED = "iteration_started" """triggered when an iteration is started.""" diff --git a/tests/ignite/contrib/engines/test_common.py b/tests/ignite/contrib/engines/test_common.py index d0100be9e8d..e14042e62c1 100644 --- a/tests/ignite/contrib/engines/test_common.py +++ b/tests/ignite/contrib/engines/test_common.py @@ -8,7 +8,6 @@ from torch.utils.data.distributed import DistributedSampler import ignite.distributed as idist - import ignite.handlers as handlers from ignite.contrib.engines.common import ( _setup_logging, diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py index 13021242650..fcb0299aa22 100644 --- a/tests/ignite/engine/test_engine.py +++ b/tests/ignite/engine/test_engine.py @@ -40,11 +40,14 @@ class TestEngine: def set_interrupt_resume_enabled(self, interrupt_resume_enabled): Engine.interrupt_resume_enabled = interrupt_resume_enabled - def test_terminate(self): + @pytest.mark.parametrize("skip_completed", [True, False]) + def test_terminate(self, skip_completed): engine = Engine(lambda e, b: 1) assert not engine.should_terminate - engine.terminate() + assert not engine.skip_completed_after_termination + engine.terminate(skip_completed) assert engine.should_terminate + assert engine.skip_completed_after_termination == skip_completed def test_invalid_process_raises_with_invalid_signature(self): with pytest.raises(ValueError, match=r"Engine must be given a processing function in order to run"): @@ -236,25 +239,32 @@ def check_iter_and_data(): assert num_calls_check_iter_epoch == 1 @pytest.mark.parametrize( - "terminate_event, e, i", + "terminate_event, e, i, skip_completed", [ - (Events.STARTED, 0, 0), - (Events.EPOCH_STARTED(once=2), 2, None), - (Events.EPOCH_COMPLETED(once=2), 2, None), - (Events.GET_BATCH_STARTED(once=12), None, 12), - (Events.GET_BATCH_COMPLETED(once=12), None, 12), - (Events.ITERATION_STARTED(once=14), None, 14), - (Events.ITERATION_COMPLETED(once=14), None, 14), + (Events.STARTED, 0, 0, True), + (Events.EPOCH_STARTED(once=2), 2, None, True), + (Events.EPOCH_COMPLETED(once=2), 2, None, True), + (Events.GET_BATCH_STARTED(once=12), None, 12, True), + (Events.GET_BATCH_COMPLETED(once=12), None, 12, False), + (Events.ITERATION_STARTED(once=14), None, 14, True), + (Events.ITERATION_COMPLETED(once=14), None, 14, True), + (Events.STARTED, 0, 0, False), + (Events.EPOCH_STARTED(once=2), 2, None, False), + (Events.EPOCH_COMPLETED(once=2), 2, None, False), + (Events.GET_BATCH_STARTED(once=12), None, 12, False), + (Events.GET_BATCH_COMPLETED(once=12), None, 12, False), + (Events.ITERATION_STARTED(once=14), None, 14, False), + (Events.ITERATION_COMPLETED(once=14), None, 14, False), ], ) - def test_terminate_events_sequence(self, terminate_event, e, i): + def test_terminate_events_sequence(self, terminate_event, e, i, skip_completed): engine = RecordedEngine(MagicMock(return_value=1)) data = range(10) max_epochs = 5 @engine.on(terminate_event) def call_terminate(): - engine.terminate() + engine.terminate(skip_completed) @engine.on(Events.EXCEPTION_RAISED) def assert_no_exceptions(ee): @@ -271,10 +281,15 @@ def assert_no_exceptions(ee): if e is None: e = i // len(data) + 1 + if skip_completed: + assert engine.called_events[-1] == (e, i, Events.TERMINATE) + assert engine.called_events[-2] == (e, i, terminate_event) + else: + assert engine.called_events[-1] == (e, i, Events.COMPLETED) + assert engine.called_events[-2] == (e, i, Events.TERMINATE) + assert engine.called_events[-3] == (e, i, terminate_event) + assert engine.called_events[0] == (0, 0, Events.STARTED) - assert engine.called_events[-1] == (e, i, Events.COMPLETED) - assert engine.called_events[-2] == (e, i, Events.TERMINATE) - assert engine.called_events[-3] == (e, i, terminate_event) assert engine._dataloader_iter is None @pytest.mark.parametrize("data, epoch_length", [(None, 10), (range(10), None)])