Skip to content

Commit

Permalink
Internal Code Change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 590739394
  • Loading branch information
rchen152 authored and copybara-github committed Dec 13, 2023
1 parent 1b0203e commit 90fdc49
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@ def train_sample(self, sample, optimizer, global_batch_size):
self.decode_batch(sample, inputs)

with tf.GradientTape() as t:
predictions = self.network(inputs['image']) # pytype: disable=key-error
predictions = self.network(inputs['image'])
outputs = self.decode_predictions(predictions, inputs, {})
outputs = self.compute_loss(inputs, outputs, global_batch_size)
network_gradients = t.gradient(outputs['loss/total'], # pytype: disable=key-error
network_gradients = t.gradient(outputs['loss/total'],
self.network.trainable_weights)
optimizer.apply_gradients(zip(network_gradients,
self.network.trainable_weights))
Expand Down

0 comments on commit 90fdc49

Please sign in to comment.