Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623979125
  • Loading branch information
tensorflower-gardener committed Apr 11, 2024
1 parent 3830799 commit 1731726
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions official/modeling/multitask/interleaving_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,6 @@ def __init__(self,
trainer_options=trainer_options)
self._task_sampler = task_sampler

# TODO(haozhangthu): Add taskwise step counter to train_loop_end for logging
# on TensorBoard.
self._task_step_counters = {
name: orbit.utils.create_global_step() for name in self.multi_task.tasks
}

# Build per task train step.
def _get_task_step(task_name, task):

Expand All @@ -63,8 +57,6 @@ def step_fn(inputs):
optimizer=self.optimizer,
metrics=self.training_metrics[task_name])
self.training_losses[task_name].update_state(task_logs[task.loss])
self.global_step.assign_add(1)
self.task_step_counter(task_name).assign_add(1)

return step_fn

Expand All @@ -73,6 +65,12 @@ def step_fn(inputs):
for name, task in self.multi_task.tasks.items()
}

# TODO(haozhangthu): Add taskwise step counter to train_loop_end for logging
# on TensorBoard.
self._task_step_counters = {
name: orbit.utils.create_global_step() for name in self.multi_task.tasks
}

# If the new Keras optimizer is used, we require all model variables are
# created before the training and let the optimizer to create the slot
# variable all together.
Expand All @@ -99,6 +97,8 @@ def train_step(self, iterator_map):
if rn >= begin and rn < end:
self._strategy.run(
self._task_train_step_map[name], args=(next(iterator_map[name]),))
self.global_step.assign_add(1)
self.task_step_counter(name).assign_add(1)

def train_loop_end(self):
"""Record loss and metric values per task."""
Expand Down

0 comments on commit 1731726

Please sign in to comment.