From 528cef7c4aedf54158a0564fdca446fe9942aa2a Mon Sep 17 00:00:00 2001
From: TF-Agents Team
Date: Thu, 12 Dec 2024 05:38:41 -0800
Subject: [PATCH] Add per example reward loss in LossInfo extra

PiperOrigin-RevId: 705472790
Change-Id: I6ef2f76efd06edae950a1754b5ad0fe2773b455a
---
 .../greedy_multi_objective_neural_agent.py   | 36 ++++++++++---------
 .../agents/greedy_reward_prediction_agent.py | 33 ++++++++---------
 2 files changed, 33 insertions(+), 36 deletions(-)
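Note on the intended behavior: the reward-loss helpers now return an unreduced
tensor of shape [batch_size] rather than a scalar; _loss averages it with
tf.reduce_mean, so the scalar reported in LossInfo.loss should be unchanged
(the mean of a sum equals the sum of the means), while the unreduced tensor is
exposed through LossInfo.extra. A minimal sketch of that equivalence, using
hypothetical per-example errors and sample weights rather than anything taken
from the patch:

    import tensorflow as tf

    errors = tf.constant([1.0, 4.0, 9.0])          # hypothetical per-example errors
    sample_weights = tf.constant([1.0, 0.5, 1.0])  # hypothetical sample weights

    # Old behavior: reduce inside the loss helper.
    old_loss = tf.reduce_mean(tf.multiply(errors, sample_weights))

    # New behavior: keep the per-example tensor (what LossInfo.extra now carries)
    # and reduce it later in _loss.
    per_example_loss = tf.multiply(errors, sample_weights)
    new_loss = tf.reduce_mean(per_example_loss)

    # old_loss == new_loss == 4.0, so the reported scalar loss is unchanged.
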
diff --git a/tf_agents/bandits/agents/greedy_multi_objective_neural_agent.py b/tf_agents/bandits/agents/greedy_multi_objective_neural_agent.py
index 133536f46..df67a4af0 100644
--- a/tf_agents/bandits/agents/greedy_multi_objective_neural_agent.py
+++ b/tf_agents/bandits/agents/greedy_multi_objective_neural_agent.py
@@ -314,14 +314,14 @@ def _single_objective_loss(
       sample_weights = (
           sample_weights * 0.5 * tf.exp(-action_predicted_log_variance)
       )
-      loss = 0.5 * tf.reduce_mean(action_predicted_log_variance)
+      loss = 0.5 * action_predicted_log_variance
       # loss = 1/(2 * var(x)) * (y - f(x))^2 + 1/2 * log var(x)
       # Kendall, Alex, and Yarin Gal. "What Uncertainties Do We Need in
       # Bayesian Deep Learning for Computer Vision?." Advances in Neural
       # Information Processing Systems. 2017. https://arxiv.org/abs/1703.04977
     else:
       predicted_values, _ = objective_network(observations, training=training)
-      loss = tf.constant(0.0)
+      loss = tf.zeros_like(single_objective_values)
 
     action_predicted_values = common.index_with_actions(
         predicted_values, tf.cast(actions, dtype=tf.int32)
@@ -338,18 +338,13 @@
           objective_idx
       ] * tf.reduce_mean(smoothness_batched * sample_weights)
 
-    # Reduction is done outside of the loss function because non-scalar
-    # weights with unknown shapes may trigger shape validation that fails
-    # XLA compilation.
-    loss += tf.reduce_mean(
-        tf.multiply(
-            self._error_loss_fns[objective_idx](
-                single_objective_values,
-                action_predicted_values,
-                reduction=tf.compat.v1.losses.Reduction.NONE,
-            ),
-            sample_weights,
-        )
+    loss += tf.multiply(
+        self._error_loss_fns[objective_idx](
+            single_objective_values,
+            action_predicted_values,
+            reduction=tf.compat.v1.losses.Reduction.NONE,
+        ),
+        sample_weights,
     )
     return loss
 
@@ -401,10 +396,10 @@ def _loss(
         )
     )
 
-    objective_losses = []
+    per_example_objective_losses = []
     for idx in range(self._num_objectives):
       single_objective_values = objective_values[:, idx]
-      objective_losses.append(
+      per_example_objective_losses.append(
          self._single_objective_loss(
              idx,
              observations,
@@ -414,10 +409,17 @@
              training,
          )
      )
+    per_example_loss = tf.reduce_sum(
+        tf.stack(per_example_objective_losses), axis=0
+    )
+    objective_losses = [
+        tf.reduce_mean(per_example_objective_losses[idx])
+        for idx in range(self._num_objectives)
+    ]
 
     self.compute_summaries(objective_losses)
     total_loss = tf.reduce_sum(objective_losses)
-    return tf_agent.LossInfo(total_loss, extra=())
+    return tf_agent.LossInfo(total_loss, extra=per_example_loss)
 
   def compute_summaries(self, losses: Sequence[tf.Tensor]):
     if self._num_objectives != len(losses):
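A sketch (not part of the patch) of how the new per-example quantities in the
multi-objective _loss above relate to the previous scalars; the two
per-objective loss tensors below are hypothetical stand-ins:

    import tensorflow as tf

    # Hypothetical per-example losses for two objectives over a batch of three.
    per_example_objective_losses = [
        tf.constant([0.1, 0.2, 0.3]),  # objective 0
        tf.constant([1.0, 1.0, 1.0]),  # objective 1
    ]

    # Combined per-example loss, shape [batch_size]; this is what lands in extra.
    per_example_loss = tf.reduce_sum(
        tf.stack(per_example_objective_losses), axis=0
    )  # -> [1.1, 1.2, 1.3]

    # Per-objective scalars for compute_summaries, and the total loss as before.
    objective_losses = [tf.reduce_mean(l) for l in per_example_objective_losses]
    total_loss = tf.reduce_sum(objective_losses)
    # total_loss == tf.reduce_mean(per_example_loss) == 1.2
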
diff --git a/tf_agents/bandits/agents/greedy_reward_prediction_agent.py b/tf_agents/bandits/agents/greedy_reward_prediction_agent.py
index 155558fbf..c54307c94 100644
--- a/tf_agents/bandits/agents/greedy_reward_prediction_agent.py
+++ b/tf_agents/bandits/agents/greedy_reward_prediction_agent.py
@@ -287,7 +287,7 @@ def _train(self, experience, weights):
 
     return loss_info
 
-  def reward_loss(
+  def per_example_reward_loss(
       self,
       observations: types.NestedTensor,
       actions: types.Tensor,
@@ -325,7 +325,7 @@
           sample_weights * 0.5 * tf.exp(-action_predicted_log_variance)
       )
 
-      loss = 0.5 * tf.reduce_mean(action_predicted_log_variance)
+      loss = 0.5 * action_predicted_log_variance
       # loss = 1/(2 * var(x)) * (y - f(x))^2 + 1/2 * log var(x)
       # Kendall, Alex, and Yarin Gal. "What Uncertainties Do We Need in
       # Bayesian Deep Learning for Computer Vision?." Advances in Neural
@@ -334,7 +334,7 @@
       predicted_values, _ = self._reward_network(
           observations, training=training
       )
-      loss = tf.constant(0.0)
+      loss = tf.zeros_like(rewards)
 
     action_predicted_values = common.index_with_actions(
         predicted_values, tf.cast(actions, dtype=tf.int32)
@@ -348,24 +348,18 @@
             self._laplacian_matrix, predicted_values, transpose_b=True
         ),
     )
-    loss += self._laplacian_smoothing_weight * tf.reduce_mean(
+    loss += self._laplacian_smoothing_weight * (
         tf.linalg.tensor_diag_part(smoothness_batched) * sample_weights
     )
 
-    # Reduction is done outside of the loss function because non-scalar
-    # weights with unknown shapes may trigger shape validation that fails
-    # XLA compilation.
-    loss += tf.reduce_mean(
-        tf.multiply(
-            self._error_loss_fn(
-                rewards,
-                action_predicted_values,
-                reduction=tf.compat.v1.losses.Reduction.NONE,
-            ),
-            sample_weights,
-        )
+    loss += tf.multiply(
+        self._error_loss_fn(
+            rewards,
+            action_predicted_values,
+            reduction=tf.compat.v1.losses.Reduction.NONE,
+        ),
+        sample_weights,
     )
-
     return loss
 
   def _loss(
@@ -404,9 +398,10 @@
       rewards_tensor = rewards[bandit_spec_utils.REWARD_SPEC_KEY]
     else:
       rewards_tensor = rewards
-    reward_loss = self.reward_loss(
+    per_example_reward_loss = self.per_example_reward_loss(
         observations, actions, rewards_tensor, weights, training
     )
+    reward_loss = tf.reduce_mean(per_example_reward_loss)
 
     constraint_loss = tf.constant(0.0)
     for i, c in enumerate(self._constraints, 0):
@@ -424,7 +419,7 @@ def _loss(
     total_loss = reward_loss
     if self._constraints:
       total_loss += constraint_loss
-    return tf_agent.LossInfo(total_loss, extra=())
+    return tf_agent.LossInfo(total_loss, extra=per_example_reward_loss)
 
   def compute_summaries(
       self, loss: types.Tensor, constraint_loss: Optional[types.Tensor] = None
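A sketch (not part of the patch) of what a consumer of either agent's LossInfo
might now do with the extra field; the batch values and the argsort-based
selection are hypothetical:

    import tensorflow as tf
    from tf_agents.agents import tf_agent

    # Hypothetical per-example reward losses for a batch of four transitions.
    per_example = tf.constant([0.3, 1.2, 0.7, 0.1])

    # The structure that _loss now returns.
    info = tf_agent.LossInfo(loss=tf.reduce_mean(per_example), extra=per_example)

    # One possible use: pick out the transitions with the largest loss.
    hardest = tf.argsort(info.extra, direction='DESCENDING')[:2]  # -> [1, 2]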