Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Give the option to terminate the engine without firing Events.COMPLET… #3309

Merged
merged 20 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
self._process_function = process_function
self.last_event_name: Optional[Events] = None
self.should_terminate = False
self.skip_completed_after_termination = False
self.should_terminate_single_epoch = False
self.should_interrupt = False
self.state = State()
Expand Down Expand Up @@ -538,7 +539,7 @@ def call_interrupt():
self.logger.info("interrupt signaled. Engine will interrupt the run after current iteration is finished.")
self.should_interrupt = True

def terminate(self) -> None:
def terminate(self, skip_completed: bool = False) -> None:
"""Sends terminate signal to the engine, so that it terminates completely the run. The run is
terminated after the event on which ``terminate`` method was called. The following events are triggered:

Expand All @@ -547,6 +548,9 @@ def terminate(self) -> None:
- :attr:`~ignite.engine.events.Events.TERMINATE`
- :attr:`~ignite.engine.events.Events.COMPLETED`

Args:
skip_completed: if True, the event :attr:`~ignite.engine.events.Events.COMPLETED` is not fired after
:attr:`~ignite.engine.events.Events.TERMINATE`. Default is False.

Examples:
.. testcode::
Expand Down Expand Up @@ -617,9 +621,12 @@ def terminate():
.. versionchanged:: 0.4.10
Behaviour changed, for details see https://github.com/pytorch/ignite/issues/2669

.. versionchanged:: 0.5.2
Added `skip_completed` flag
"""
self.logger.info("Terminate signaled. Engine will stop after current iteration is finished.")
self.should_terminate = True
self.skip_completed_after_termination = skip_completed

def terminate_epoch(self) -> None:
"""Sends terminate signal to the engine, so that it terminates the current epoch. The run
Expand Down Expand Up @@ -993,13 +1000,17 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]:
time_taken = time.time() - start_time
# time is available for handlers but must be updated after fire
self.state.times[Events.COMPLETED.name] = time_taken
handlers_start_time = time.time()
self._fire_event(Events.COMPLETED)
time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.COMPLETED.name] = time_taken

# do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
if not (self.should_terminate and self.skip_completed_after_termination):
handlers_start_time = time.time()
self._fire_event(Events.COMPLETED)
time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.COMPLETED.name] = time_taken

hours, mins, secs = _to_hours_mins_secs(time_taken)
self.logger.info(f"Engine run complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")
self.logger.info(f"Engine run finished. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")

except BaseException as e:
self._dataloader_iter = None
Expand Down Expand Up @@ -1174,13 +1185,17 @@ def _internal_run_legacy(self) -> State:
time_taken = time.time() - start_time
# time is available for handlers but must be updated after fire
self.state.times[Events.COMPLETED.name] = time_taken
handlers_start_time = time.time()
self._fire_event(Events.COMPLETED)
time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.COMPLETED.name] = time_taken

# do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
if not (self.should_terminate and self.skip_completed_after_termination):
handlers_start_time = time.time()
self._fire_event(Events.COMPLETED)
time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.COMPLETED.name] = time_taken

hours, mins, secs = _to_hours_mins_secs(time_taken)
self.logger.info(f"Engine run complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")
self.logger.info(f"Engine run finished. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")

except BaseException as e:
self._dataloader_iter = None
Expand Down
7 changes: 4 additions & 3 deletions ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,8 @@ class Events(EventEnum):

- EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even
when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called.
- COMPLETED : triggered when engine's run is completed
- COMPLETED : triggered when engine's run is completed or terminated with :meth:`~ignite.engine.engine.Engine.terminate()`,
unless the flag `skip_completed` is set to True.
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved

The table below illustrates which events are triggered when various termination methods are called.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bonassifabio actually, I also find this table very misleading and hard to understand.
I think we can improve it in the following way:

  • EVENT_COMPLETED -> EPOCH_COMPLETED
  • Let's add new column on the right after "TERMINATE" and call it "COMPLETED"
  • The check symbol in the terminate() row of the TERMINATE_SINGLE_EPOCH column is actually wrong and should be replaced with x.

By the way, here is how it is generated now: https://deploy-preview-3309--pytorch-ignite-preview.netlify.app/generated/ignite.engine.events.events#ignite.engine.events.Events

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I was interpreting the first column as Events.COMPLETED, but from your comment it would seem that's not the case.

What if we change the table to something like this?

Method TERMINATE_SINGLE_EPOCH EPOCH_COMPLETED TERMINATE COMPLETED
No termination x x
terminate_epoch() x x x
terminate() x x
terminate(skip_completed=True) x x x

Few comments:

  1. To my understanding (please correct me if I'm wrong), if terminate() is called, the epoch is not necessarily completed. If this statement is true, it would perhaps make more sense to move EPOCH_COMPLETED before TERMINATE in the list of Events, as I did in the table, since EPOCH_COMPLETED would almost never be fired after TERMINATE.
  2. The columns would thus be ordered so that the first two are "epoch wise", while the last two happen at the end of the engine run. Not sure I
  3. I included a new row for terminate(skip_completed=True). Not sure if that's really necessary.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the new table, looks good!
Few corrections however about its content:

In the terminate_epoch() row, the "EPOCH_COMPLETED" and "COMPLETED" columns should both be checked (v).

About your comments, I'm OK with all your suggestions, let's keep this order and the new line.

Checking code:

from ignite.engine import Engine, Events
from ignite.utils import setup_logger, logging


# Toy dataset of 10 items; the run is asked to go through it for up to 5 epochs.
train_data = range(10)
max_epochs = 5


# No-op processing function: this script only inspects which events the engine fires.
def train_step(engine, batch):
    pass

trainer = Engine(train_step)

# Enable trainer logger for a debug mode
trainer.logger = setup_logger("trainer", level=logging.DEBUG)

# Fires a single time, when the 12th iteration completes, and asks the engine
# to terminate only the current epoch (the run itself continues afterwards).
@trainer.on(Events.ITERATION_COMPLETED(once=12))
def call_terminate_epoch():
    trainer.terminate_epoch()

trainer.run(train_data, max_epochs=max_epochs)

Output:

...
2024-12-03 09:06:33,056 trainer INFO: Terminate current epoch is signaled. Current epoch iteration will stop after current iteration is finished.
2024-12-03 09:06:33,058 trainer DEBUG: 2 | 12, Firing handlers for event Events.TERMINATE_SINGLE_EPOCH
2024-12-03 09:06:33,060 trainer DEBUG: 2 | 12, Firing handlers for event Events.EPOCH_COMPLETED
2024-12-03 09:06:33,061 trainer INFO: Epoch[2] Complete. Time taken: 00:00:00.023
2024-12-03 09:06:33,064 trainer DEBUG: 3 | 12, Firing handlers for event Events.EPOCH_STARTED
2024-12-03 09:06:33,066 trainer DEBUG: 3 | 12, Firing handlers for event Events.GET_BATCH_STARTED
...

You can run it in https://pytorch-ignite.ai/playground

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for pointing this out!

  1. Do you think we should add an argument skip_completed to terminate_epoch() with the same behavior of suppressing the firing of Events.EPOCH_COMPLETED? That might be useful, for example, to avoid running the evaluator, the checkpointer, and the LR scheduler if an epoch has been terminated.
  2. I agree on checking EPOCH_COMPLETED for terminate_epoch(). However, I think that checking COMPLETED might seem to imply that Events.COMPLETED is also fired by terminate_epoch() right after Events.EPOCH_COMPLETED. What if we leave that cell empty?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we should add an argument skip_completed to terminate_epoch() with the same behavior of suppressing the firing of Events.EPOCH_COMPLETED?

Good idea! Let's make this in a separate PR such that this one does not become too large

I agree on checking EPOCH_COMPLETED for terminate_epoch(). However, I think that checking COMPLETED might seem to imply that Events.COMPLETED is also fired by terminate_epoch() right after Events.EPOCH_COMPLETED. What if we leave that cell empty?

well, I was seeing this table as showing which events are triggered when we call terminate*() functions. I did not understand v mark on COMPLETED for terminate_epoch as the sequence of triggered events, but just whether an event is triggered or not. To avoid misunderstanding we can add a column between EPOCH_COMPLETED and TERMINATE named as for example "Other events" and check it where it is appropriate. What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea! Let's make this in a separate PR such that this one does not become too large

Cool, I'll create a new PR when we're done with this one

I did not understand v mark on COMPLETED for terminate_epoch as the sequence of triggered events, but just whether an event is triggered or not.

I see the point, but isn't the firing of TERMINATE and COMPLETED independent of terminate_epoch()?
See for example this code.

from ignite.engine import Engine, Events
from ignite.utils import setup_logger, logging

# Toy dataset of 10 items; the run is asked to go through it for up to 5 epochs.
train_data = range(10)
max_epochs = 5


# No-op processing function: this script only inspects which events the engine fires.
def train_step(engine, batch):
    pass

trainer = Engine(train_step)

# Enable trainer logger for a debug mode
trainer.logger = setup_logger("trainer", level=logging.DEBUG)

# Fires a single time, when the 12th iteration completes, and asks the engine
# to terminate only the current epoch.
@trainer.on(Events.ITERATION_COMPLETED(once=12))
def call_terminate_epoch():
    trainer.terminate_epoch()

# Fires a single time, when the 15th iteration completes, and terminates the
# whole run; skip_completed=True requests that Events.COMPLETED is not fired
# after Events.TERMINATE.
@trainer.on(Events.ITERATION_COMPLETED(once=15))
def call_terminate():
    trainer.terminate(skip_completed=True)

trainer.run(train_data, max_epochs=max_epochs)

This is probably an edge case. As a newbie of ignite, it would honestly make more sense to see no entry for TERMINATE and COMPLETED for terminate_epoch(). However, you are for sure more experienced than me, so we can go ahead with your solution if you prefer 🙂

Solution 1

Method TERMINATE_SINGLE_EPOCH EPOCH_COMPLETED TERMINATE COMPLETED
No termination x x
terminate_epoch() x
terminate() x x
terminate(skip_completed=True) x x x

Solution 2

Method TERMINATE_SINGLE_EPOCH EPOCH_COMPLETED TERMINATE COMPLETED
No termination x x
terminate_epoch() x x
terminate() x x
terminate(skip_completed=True) x x x

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see the point, but isn't the firing of TERMINATE and COMPLETED independent of terminate_epoch()?

yes, they should be independent, i would say. If this is not the case, it should be a bug...

As a newbie of ignite, it would honestly make more sense to see no entry for TERMINATE and COMPLETED for terminate_epoch(). However, you are for sure more experienced than me, so we can go ahead with your solution if you prefer 🙂

Actually, your feedback is more important as other users may think the same (and those who know how it works do not read the docs :) ) !

Initially, the goal of the table was to provide a visual representation of which events are triggered in each case (when calling terminate*()), compared with a regular run.
Now, seeing its content I find this tabular representation more misleading than useful as triggered events depend on where the terminate function was called.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would propose to keep the table (either Table 1 or Table 2), and then to open a new issue to seek some input also from other developers/users and discuss the problem more in detail there.
We could, for example, think of a flow chart diagram or something like that...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally, I would remove it and maybe add some code output to show exactly how it works...
OK, up to you, either let's revert the changes or keep the most understandable version and update it later.


Expand All @@ -286,7 +287,7 @@ class Events(EventEnum):
- ✔
- ✗
* - :meth:`~ignite.engine.engine.Engine.terminate()`
-
- (✔)
- ✔
- ✔

Expand Down Expand Up @@ -357,7 +358,7 @@ class CustomEvents(EventEnum):
STARTED = "started"
"""triggered when engine's run is started."""
COMPLETED = "completed"
"""triggered when engine's run is completed"""
"""triggered when engine's run is completed, or after receiving terminate() call."""

ITERATION_STARTED = "iteration_started"
"""triggered when an iteration is started."""
Expand Down
1 change: 0 additions & 1 deletion tests/ignite/contrib/engines/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from torch.utils.data.distributed import DistributedSampler

import ignite.distributed as idist

import ignite.handlers as handlers
from ignite.contrib.engines.common import (
_setup_logging,
Expand Down
48 changes: 33 additions & 15 deletions tests/ignite/engine/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,15 @@ def set_interrupt_resume_enabled(self, interrupt_resume_enabled):

def test_terminate(self):
engine = Engine(lambda e, b: 1)
assert not engine.should_terminate
assert not engine.should_terminate and not engine.skip_completed_after_termination
bonassifabio marked this conversation as resolved.
Show resolved Hide resolved
engine.terminate()
assert engine.should_terminate
assert engine.should_terminate and not engine.skip_completed_after_termination
bonassifabio marked this conversation as resolved.
Show resolved Hide resolved

def test_terminate_and_not_complete(self):
engine = Engine(lambda e, b: 1)
assert not engine.should_terminate and not engine.skip_completed_after_termination
engine.terminate(skip_completed=True)
assert engine.should_terminate and engine.skip_completed_after_termination
bonassifabio marked this conversation as resolved.
Show resolved Hide resolved

def test_invalid_process_raises_with_invalid_signature(self):
with pytest.raises(ValueError, match=r"Engine must be given a processing function in order to run"):
Expand Down Expand Up @@ -236,25 +242,32 @@ def check_iter_and_data():
assert num_calls_check_iter_epoch == 1

@pytest.mark.parametrize(
"terminate_event, e, i",
"terminate_event, e, i, skip_event_completed",
[
(Events.STARTED, 0, 0),
(Events.EPOCH_STARTED(once=2), 2, None),
(Events.EPOCH_COMPLETED(once=2), 2, None),
(Events.GET_BATCH_STARTED(once=12), None, 12),
(Events.GET_BATCH_COMPLETED(once=12), None, 12),
(Events.ITERATION_STARTED(once=14), None, 14),
(Events.ITERATION_COMPLETED(once=14), None, 14),
(Events.STARTED, 0, 0, True),
(Events.EPOCH_STARTED(once=2), 2, None, True),
(Events.EPOCH_COMPLETED(once=2), 2, None, True),
(Events.GET_BATCH_STARTED(once=12), None, 12, True),
(Events.GET_BATCH_COMPLETED(once=12), None, 12, False),
(Events.ITERATION_STARTED(once=14), None, 14, True),
(Events.ITERATION_COMPLETED(once=14), None, 14, True),
(Events.STARTED, 0, 0, False),
(Events.EPOCH_STARTED(once=2), 2, None, False),
(Events.EPOCH_COMPLETED(once=2), 2, None, False),
(Events.GET_BATCH_STARTED(once=12), None, 12, False),
(Events.GET_BATCH_COMPLETED(once=12), None, 12, False),
(Events.ITERATION_STARTED(once=14), None, 14, False),
(Events.ITERATION_COMPLETED(once=14), None, 14, False),
],
)
def test_terminate_events_sequence(self, terminate_event, e, i):
def test_terminate_events_sequence(self, terminate_event, e, i, skip_completed):
engine = RecordedEngine(MagicMock(return_value=1))
data = range(10)
max_epochs = 5

@engine.on(terminate_event)
def call_terminate():
engine.terminate()
engine.terminate(skip_completed)

@engine.on(Events.EXCEPTION_RAISED)
def assert_no_exceptions(ee):
Expand All @@ -271,10 +284,15 @@ def assert_no_exceptions(ee):
if e is None:
e = i // len(data) + 1

if skip_completed:
assert engine.called_events[-1] == (e, i, Events.TERMINATE)
assert engine.called_events[-2] == (e, i, terminate_event)
else:
assert engine.called_events[-1] == (e, i, Events.COMPLETED)
assert engine.called_events[-2] == (e, i, Events.TERMINATE)
assert engine.called_events[-3] == (e, i, terminate_event)

assert engine.called_events[0] == (0, 0, Events.STARTED)
assert engine.called_events[-1] == (e, i, Events.COMPLETED)
assert engine.called_events[-2] == (e, i, Events.TERMINATE)
assert engine.called_events[-3] == (e, i, terminate_event)
assert engine._dataloader_iter is None

@pytest.mark.parametrize("data, epoch_length", [(None, 10), (range(10), None)])
Expand Down
Loading