diff --git a/official/modeling/multitask/interleaving_trainer.py b/official/modeling/multitask/interleaving_trainer.py
index 3d9c8554fec..27d5aef7757 100644
--- a/official/modeling/multitask/interleaving_trainer.py
+++ b/official/modeling/multitask/interleaving_trainer.py
@@ -43,12 +43,6 @@ def __init__(self,
         trainer_options=trainer_options)
     self._task_sampler = task_sampler
 
-    # TODO(haozhangthu): Add taskwise step counter to train_loop_end for logging
-    # on TensorBoard.
-    self._task_step_counters = {
-        name: orbit.utils.create_global_step() for name in self.multi_task.tasks
-    }
-
     # Build per task train step.
     def _get_task_step(task_name, task):
 
@@ -63,8 +57,6 @@ def step_fn(inputs):
             optimizer=self.optimizer,
             metrics=self.training_metrics[task_name])
         self.training_losses[task_name].update_state(task_logs[task.loss])
-        self.global_step.assign_add(1)
-        self.task_step_counter(task_name).assign_add(1)
 
       return step_fn
 
@@ -73,6 +65,12 @@ def step_fn(inputs):
         for name, task in self.multi_task.tasks.items()
     }
 
+    # TODO(haozhangthu): Add taskwise step counter to train_loop_end for logging
+    # on TensorBoard.
+    self._task_step_counters = {
+        name: orbit.utils.create_global_step() for name in self.multi_task.tasks
+    }
+
     # If the new Keras optimizer is used, we require all model variables are
     # created before the training and let the optimizer to create the slot
     # variable all together.
@@ -99,6 +97,8 @@ def train_step(self, iterator_map):
       if rn >= begin and rn < end:
         self._strategy.run(
             self._task_train_step_map[name], args=(next(iterator_map[name]),))
+        self.global_step.assign_add(1)
+        self.task_step_counter(name).assign_add(1)
 
   def train_loop_end(self):
     """Record loss and metric values per task."""
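
Note on the change (commentary, not part of the patch text): the `global_step` and per-task step counter increments move out of the replicated `step_fn` and into `train_step`, after the `strategy.run(...)` call, so the counter bookkeeping now happens in the cross-replica (driver) context, exactly once per sampled task step. Creation of `self._task_step_counters` is likewise deferred until after the per-task train-step map is built. Since `orbit.utils.create_global_step()` creates its variable with `ONLY_FIRST_REPLICA` aggregation, the old in-replica increments should also have counted each step once; the move appears to consolidate all step accounting in one place in `train_step`.

Below is a minimal sketch of the resulting counting pattern, assuming a `MirroredStrategy` and a toy `step_fn` (both illustrative assumptions, not taken from the patch):

import tensorflow as tf
import orbit

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  # Same helper the trainer uses; creates an int64, non-trainable counter.
  global_step = orbit.utils.create_global_step()

@tf.function
def train_step(iterator):

  def step_fn(inputs):
    # Per-replica work (forward/backward pass) would run here.
    del inputs

  strategy.run(step_fn, args=(next(iterator),))
  # As in the patch, the counter advances here in the cross-replica
  # context: one increment per train_step call, independent of replicas.
  global_step.assign_add(1)

dataset = tf.data.Dataset.range(8).batch(2)
iterator = iter(strategy.experimental_distribute_dataset(dataset))
train_step(iterator)
print(global_step.numpy())  # -> 1

The key design point the sketch isolates: updates placed after `strategy.run()` execute once in the driver loop, whereas code inside `step_fn` is traced and run per replica, so keeping counter updates outside the replicated function makes their semantics independent of the distribution strategy.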