tear_down for train extensions; fixes #1467 #1502

Open · wants to merge 1 commit into base: master
173 changes: 98 additions & 75 deletions pylearn2/train.py
@@ -94,6 +94,15 @@ def setup_extensions(self):
for ext in self.extensions:
ext.setup(self.model, self.dataset, self.algorithm)

def tear_down_extensions(self):
""" Calls tear_down on all extensions."""
for ext in self.extensions:
try:
ext.tear_down(self.model, self.dataset, self.algorithm)
except Exception:
log.debug('%s train extension failed to terminate gracefully',
ext, exc_info=True)

def exceeded_time_budget(self, t0, time_budget):
"""
.. todo::
@@ -126,6 +135,12 @@ def setup(self):
# make sure the constraints are enforced from the start.
self.model.enforce_constraints()

def tear_down(self):
"""
Called at the end of the main loop; calls tear_down on all train extensions.
"""
self.tear_down_extensions()

def main_loop(self, time_budget=None):
"""
Repeatedly runs an epoch of the training algorithm, runs any
@@ -139,95 +154,103 @@ def main_loop(self, time_budget=None):
"""
t0 = datetime.now()
self.setup()
if self.algorithm is None:
self.run_callbacks_and_monitoring()
while True:
if self.exceeded_time_budget(t0, time_budget):
break
try:
if self.algorithm is None:
self.run_callbacks_and_monitoring()
while True:
if self.exceeded_time_budget(t0, time_budget):
break

rval = self.model.train_all(dataset=self.dataset)
if rval is not None:
raise ValueError("Model.train_all should not return " +
"anything. Use Model.continue_learning " +
"to control whether learning continues.")
self.model.monitor.report_epoch()
extension_continue = self.run_callbacks_and_monitoring()
freq = self.save_freq
if freq > 0 and self.model.monitor.get_epochs_seen() % freq == 0:
self.save()
continue_learning = (self.model.continue_learning() and
extension_continue)
assert continue_learning in [True, False, 0, 1]
if not continue_learning:
break
else:
if not hasattr(self.model, 'monitor'):
# TODO: is this really necessary? I just put this error here
# to prevent an AttributeError later, but I think we could
# rewrite to avoid the AttributeError
raise RuntimeError("The algorithm is responsible for setting"
" up the Monitor, but failed to.")
if len(self.model.monitor._datasets) > 0:
# This monitoring channel keeps track of a shared variable,
# which does not need inputs nor data.
self.training_seconds.__doc__ = """\
rval = self.model.train_all(dataset=self.dataset)
if rval is not None:
raise ValueError("Model.train_all should not return " +
"anything. Use Model.continue_learning " +
"to control whether learning continues.")
self.model.monitor.report_epoch()
extension_continue = self.run_callbacks_and_monitoring()
freq = self.save_freq
if freq > 0 and self.model.monitor.get_epochs_seen() % freq == 0:
self.save()
continue_learning = (self.model.continue_learning() and
extension_continue)
assert continue_learning in [True, False, 0, 1]
if not continue_learning:
break
else:
if not hasattr(self.model, 'monitor'):
# TODO: is this really necessary? I just put this error here
# to prevent an AttributeError later, but I think we could
# rewrite to avoid the AttributeError
raise RuntimeError("The algorithm is responsible for setting"
" up the Monitor, but failed to.")
if len(self.model.monitor._datasets) > 0:
# This monitoring channel keeps track of a shared variable,
# which does not need inputs nor data.
self.training_seconds.__doc__ = """\
The number of seconds that were spent in actual training during the most
recent epoch. This excludes seconds that were spent running callbacks for
the extensions, computing monitoring channels, etc."""
self.model.monitor.add_channel(
name="training_seconds_this_epoch",
ipt=None,
val=self.training_seconds,
data_specs=(NullSpace(), ''),
dataset=self.model.monitor._datasets[0])
self.total_seconds.__doc__ = """\
self.model.monitor.add_channel(
name="training_seconds_this_epoch",
ipt=None,
val=self.training_seconds,
data_specs=(NullSpace(), ''),
dataset=self.model.monitor._datasets[0])
self.total_seconds.__doc__ = """\
The number of seconds that were spent on the entirety of processing for the
previous epoch. This includes not only training but also the computation of
the monitoring channels, running TrainExtension callbacks, etc. This value
is reported for the *previous* epoch because the amount of time spent on
monitoring for this epoch is not known until the monitoring channels have
already been reported."""
self.model.monitor.add_channel(
name="total_seconds_last_epoch",
ipt=None,
val=self.total_seconds,
data_specs=(NullSpace(), ''),
dataset=self.model.monitor._datasets[0])
self.run_callbacks_and_monitoring()
self.model.monitor.add_channel(
name="total_seconds_last_epoch",
ipt=None,
val=self.total_seconds,
data_specs=(NullSpace(), ''),
dataset=self.model.monitor._datasets[0])
self.run_callbacks_and_monitoring()

while True:
if self.exceeded_time_budget(t0, time_budget):
break
while True:
if self.exceeded_time_budget(t0, time_budget):
break

with log_timing(log, None, level=logging.DEBUG,
callbacks=[self.total_seconds.set_value]):
with log_timing(
log, None, final_msg='Time this epoch:',
callbacks=[self.training_seconds.set_value]):
rval = self.algorithm.train(dataset=self.dataset)
if rval is not None:
raise ValueError("TrainingAlgorithm.train should not "
"return anything. Use "
"TrainingAlgorithm.continue_learning "
"to control whether learning "
"continues.")
self.model.monitor.report_epoch()
extension_continue = self.run_callbacks_and_monitoring()
if self.save_freq > 0 and \
self.model.monitor.get_epochs_seen() % self.save_freq == 0:
self.save()
continue_learning = (
self.algorithm.continue_learning(self.model) and
extension_continue
)
assert continue_learning in [True, False, 0, 1]
if not continue_learning:
break
with log_timing(log, None, level=logging.DEBUG,
callbacks=[self.total_seconds.set_value]):
with log_timing(
log, None, final_msg='Time this epoch:',
callbacks=[self.training_seconds.set_value]):
rval = self.algorithm.train(dataset=self.dataset)
if rval is not None:
raise ValueError("TrainingAlgorithm.train should not "
"return anything. Use "
"TrainingAlgorithm.continue_learning "
"to control whether learning "
"continues.")
self.model.monitor.report_epoch()
extension_continue = self.run_callbacks_and_monitoring()
if self.save_freq > 0 and \
self.model.monitor.get_epochs_seen() % self.save_freq == 0:
self.save()
continue_learning = (
self.algorithm.continue_learning(self.model) and
extension_continue
)
assert continue_learning in [True, False, 0, 1]
if not continue_learning:
break

self.model.monitor.training_succeeded = True
self.model.monitor.training_succeeded = True

if self.save_freq > 0:
self.save()
if self.save_freq > 0:
self.save()
except Exception:
self.tear_down()
log.error("Uncaught exception in Train's main loop",
exc_info=True)
raise
else:
self.tear_down()

def run_callbacks_and_monitoring(self):
"""
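For reviewers skimming the reindented hunk above: the net control-flow change in `Train.main_loop` is that the whole loop body now sits inside `try`/`except`/`else`, so `tear_down()` (and through it every extension's `tear_down`) runs whether the loop ends normally or with an exception. A minimal standalone sketch of that shape, with illustrative names (`Trainer` and `run_epoch` are not pylearn2 API):

```python
import logging

log = logging.getLogger(__name__)


class Trainer(object):
    """Toy stand-in for pylearn2.train.Train, for illustration only."""

    def __init__(self, extensions):
        self.extensions = extensions

    def tear_down(self):
        # Mirrors Train.tear_down_extensions: one failing extension must
        # not stop the remaining extensions from cleaning up.
        for ext in self.extensions:
            try:
                ext.tear_down()
            except Exception:
                log.debug('%s failed to terminate gracefully',
                          ext, exc_info=True)

    def main_loop(self, run_epoch, max_epochs=10):
        try:
            for _ in range(max_epochs):
                run_epoch()
        except Exception:
            # Clean up first, then re-raise so the caller still sees
            # the original error.
            self.tear_down()
            log.error("Uncaught exception in main loop", exc_info=True)
            raise
        else:
            self.tear_down()
```
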
17 changes: 17 additions & 0 deletions pylearn2/train_extensions/__init__.py
@@ -78,6 +78,23 @@ def setup(self, model, dataset, algorithm):
used to train the model.
"""

def tear_down(self, model, dataset, algorithm):
"""
Train calls this after the main loop.

Parameters
----------
model : pylearn2.models.Model
The model object being trained.

dataset : pylearn2.datasets.Dataset
The dataset object being trained.

algorithm : pylearn2.training_algorithms.TrainingAlgorithm
The object representing the training algorithm being
used to train the model.
"""

class SharedSetter(TrainExtension):
"""
Sets shared variables to take on the specified values after the
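A hypothetical extension built on the new hook could pair `setup` with `tear_down` as below. `LogFileExtension` is invented for illustration (it is not part of this PR or of pylearn2); only the hook signatures and `model.monitor.get_epochs_seen()` come from the existing API:

```python
from pylearn2.train_extensions import TrainExtension


class LogFileExtension(TrainExtension):
    """Writes one line per monitor call and closes the file at the end."""

    def __init__(self, path):
        self.path = path
        self._fh = None

    def setup(self, model, dataset, algorithm):
        # Acquire the resource when training starts.
        self._fh = open(self.path, 'w')

    def on_monitor(self, model, dataset, algorithm):
        self._fh.write('finished epoch %d\n' %
                       model.monitor.get_epochs_seen())
        self._fh.flush()

    def tear_down(self, model, dataset, algorithm):
        # Called by Train.tear_down_extensions after the main loop,
        # even if training stopped with an exception.
        if self._fh is not None:
            self._fh.close()
            self._fh = None
```
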
28 changes: 19 additions & 9 deletions pylearn2/train_extensions/live_monitoring.py
@@ -170,22 +170,32 @@ def __init__(self, address='*', req_port=5555, pub_port=5556):
assert(pub_port > 1024 and pub_port < 65536)
self.pub_port = pub_port

address_template = self.address + ':%d'
self.address_template = self.address + ':%d'
self.context = zmq.Context()

self.req_sock = None
if self.req_port > 0:
self.req_sock = self.context.socket(zmq.REP)
self.req_sock.bind(address_template % self.req_port)

self.pub_sock = None
if self.pub_port > 0:
self.pub_sock = self.context.socket(zmq.PUB)
self.req_sock.bind(address_template % self.pub_port)

# Tracks the number of times on_monitor has been called
self.counter = 0

@wraps(TrainExtension.setup)
def setup(self, model, dataset, algorithm):
if self.req_port > 0:
self.req_sock = self.context.socket(zmq.REP)
self.req_sock.bind(self.address_template % self.req_port)
if self.pub_port > 0:
self.pub_sock = self.context.socket(zmq.PUB)
self.pub_sock.bind(self.address_template % self.pub_port)

@wraps(TrainExtension.tear_down)
def tear_down(self, model, dataset, algorithm):
if self.req_sock:
self.req_sock.unbind(self.address_template % self.req_port)
self.req_sock = None
if self.pub_sock:
self.pub_sock.unbind(self.address_template % self.pub_port)
self.pub_sock = None

@wraps(TrainExtension.on_monitor)
def on_monitor(self, model, dataset, algorithm):
monitor = Monitor.get_monitor(model)
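The relocation above is why the hook matters here: previously the sockets were bound in `__init__`, so the ports stayed claimed for the life of the process even after training finished. A standalone sketch of the intended lifecycle (the concrete address is illustrative; the REP/PUB pairing and ports 5555/5556 mirror the LiveMonitoring defaults):

```python
import zmq

context = zmq.Context()
template = 'tcp://127.0.0.1:%d'

# setup(): create and bind the sockets when training actually starts.
req_sock = context.socket(zmq.REP)
req_sock.bind(template % 5555)
pub_sock = context.socket(zmq.PUB)
pub_sock.bind(template % 5556)

# ... on_monitor() would service requests / publish channels here ...

# tear_down(): release the ports so a later Train run (or a second
# LiveMonitoring instance in a test) can bind them again.
req_sock.unbind(template % 5555)
pub_sock.unbind(template % 5556)
req_sock.close()
pub_sock.close()
context.term()
```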