diff --git a/tf_agents/bandits/agents/examples/v2/train_eval_ranking.py b/tf_agents/bandits/agents/examples/v2/train_eval_ranking.py index f98f59730..2da34f3dd 100644 --- a/tf_agents/bandits/agents/examples/v2/train_eval_ranking.py +++ b/tf_agents/bandits/agents/examples/v2/train_eval_ranking.py @@ -64,12 +64,19 @@ 'bias_type', '', 'Whether the agent models the positional ' - 'bias with the basis or the exponent changes. If unset, the' + 'bias with the basis, the exponent or fixed bias weights. If unset, the' ' agent applies no positional bias.', ) flags.DEFINE_float( 'bias_severity', 1.0, 'The severity of the bias adjustment by the agent.' ) +flags.DEFINE_list( + 'bias_weights', + [], + 'The positional bias weights. For FIXED_BIAS_WEIGHTS type, the agent will' + ' use these weights to adjust the rewards. The length of the list must be' + ' equal to the number of slots.', +) flags.DEFINE_bool( 'bias_positive_only', False, @@ -174,12 +181,15 @@ def _relevance_fn(global_obs, item_obs): positional_bias_type = ranking_agent.PositionalBiasType.BASE elif FLAGS.positional_bias_type == 'exponent': positional_bias_type = ranking_agent.PositionalBiasType.EXPONENT + elif FLAGS.positional_bias_type == 'fixed_bias_weights': + positional_bias_type = ranking_agent.PositionalBiasType.FIXED_BIAS_WEIGHTS else: raise NotImplementedError( 'Positional bias type {} is not implemented'.format( FLAGS.positional_bias_type ) ) + positional_bias_weights = [float(w) for w in FLAGS.positional_bias_weights] agent = ranking_agent.RankingAgent( time_step_spec=environment.time_step_spec(), @@ -190,6 +200,7 @@ def _relevance_fn(global_obs, item_obs): feedback_model=feedback_model, positional_bias_type=positional_bias_type, positional_bias_severity=FLAGS.bias_severity, + positional_bias_weights=positional_bias_weights, positional_bias_positive_only=FLAGS.bias_positive_only, summarize_grads_and_vars=True, ) diff --git a/tf_agents/bandits/agents/ranking_agent.py b/tf_agents/bandits/agents/ranking_agent.py index 3c1b80f6b..9558fbcec 100644 --- a/tf_agents/bandits/agents/ranking_agent.py +++ b/tf_agents/bandits/agents/ranking_agent.py @@ -38,8 +38,9 @@ recommendation. The user is responsible for converting the observation to the syntax required by the agent. """ + import enum -from typing import Optional, Text +from typing import List, Optional, Text import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import from tf_agents.agents import tf_agent @@ -127,6 +128,9 @@ class PositionalBiasType(enum.Enum): # et al. `Correcting for Selection Bias in Learning-to-rank Systems` # (WWW 2020). EXPONENT = 2 + # The bias weight for each slot position is `bias_weights[k]`, where + # `bias_weights` is the given bias weight array and `k` is the position. + FIXED_BIAS_WEIGHTS = 3 class RankingAgent(tf_agent.TFAgent): @@ -144,6 +148,7 @@ def __init__( non_click_score: Optional[float] = None, positional_bias_type: PositionalBiasType = PositionalBiasType.UNSET, positional_bias_severity: Optional[float] = None, + positional_bias_weights: Optional[List[float]] = None, positional_bias_positive_only: bool = False, logits_temperature: float = 1.0, summarize_grads_and_vars: bool = False, @@ -178,6 +183,8 @@ def __init__( positional_bias_type: Type of positional bias to use when training. positional_bias_severity: (float) The severity `s`, used for the `BASE` positional bias type. + positional_bias_weights: (float array) The positional bias weight for each + slot position. positional_bias_positive_only: Whether to use the above defined bias weights only for positives (that is, clicked items). If `positional_bias_type` is unset, this parameter has no effect. @@ -230,6 +237,22 @@ def __init__( ) self._positional_bias_type = positional_bias_type self._positional_bias_severity = positional_bias_severity + # Validate positional_bias_weights for FIXED_BIAS_WEIGHTS PositionalBiasType + if self._positional_bias_type == PositionalBiasType.FIXED_BIAS_WEIGHTS: + if positional_bias_weights is None: + raise ValueError( + 'positional_bias_weights is None but should never be for' + ' FIXED_BIAS_WEIGHTS PositionalBiasType.' + ) + elif len(positional_bias_weights) != self._num_slots: + raise ValueError( + 'The length of positional_bias_weights should be the same as the' + ' number of slots. The length of positional_bias_weights is {} and' + ' the number of slots is {}.'.format( + len(positional_bias_weights), self._num_slots + ) + ) + self._positional_bias_weights = positional_bias_weights self._positional_bias_positive_only = positional_bias_positive_only if policy_type == RankingPolicyType.UNKNOWN: policy_type = RankingPolicyType.COSINE_DISTANCE @@ -409,19 +432,27 @@ def _construct_sample_weights(self, reward, observation, weights): chosen_index + 1, self._num_slots, dtype=tf.float32 ) weights = multiplier * weights - if self._positional_bias_type != PositionalBiasType.UNSET: - batched_range = tf.broadcast_to( - tf.range(self._num_slots, dtype=tf.float32), tf.shape(weights) + + if self._positional_bias_type == PositionalBiasType.UNSET: + return weights + + batched_range = tf.broadcast_to( + tf.range(self._num_slots, dtype=tf.float32), tf.shape(weights) + ) + if self._positional_bias_type == PositionalBiasType.BASE: + position_bias_multipliers = tf.pow( + batched_range + 1, self._positional_bias_severity ) - if self._positional_bias_type == PositionalBiasType.BASE: - position_bias_multipliers = tf.pow( - batched_range + 1, self._positional_bias_severity - ) - elif self._positional_bias_type == PositionalBiasType.EXPONENT: - position_bias_multipliers = tf.pow( - self._positional_bias_severity, batched_range - ) - else: - raise ValueError('non-existing positional bias type') - weights = position_bias_multipliers * weights + elif self._positional_bias_type == PositionalBiasType.EXPONENT: + position_bias_multipliers = tf.pow( + self._positional_bias_severity, batched_range + ) + elif self._positional_bias_type == PositionalBiasType.FIXED_BIAS_WEIGHTS: + position_bias_multipliers = tf.tile( + tf.expand_dims(self._positional_bias_weights, axis=0), + [batch_size, 1], + ) + else: + raise ValueError('non-existing positional bias type') + weights = position_bias_multipliers * weights return weights diff --git a/tf_agents/bandits/agents/ranking_agent_test.py b/tf_agents/bandits/agents/ranking_agent_test.py index 46b949228..421d59d7f 100644 --- a/tf_agents/bandits/agents/ranking_agent_test.py +++ b/tf_agents/bandits/agents/ranking_agent_test.py @@ -311,6 +311,8 @@ def testTrainAgentScoreFeedback( 'positional_bias_type': ranking_agent.PositionalBiasType.BASE, 'positional_bias_severity': 1.2, 'positional_bias_positive_only': False, + 'positional_bias_weights': None, + 'expected_second_weight': 2.2974, # 2**positional_bias_severity }, { 'feedback_model': ranking_agent.FeedbackModel.SCORE_VECTOR, @@ -323,6 +325,8 @@ def testTrainAgentScoreFeedback( 'positional_bias_type': ranking_agent.PositionalBiasType.EXPONENT, 'positional_bias_severity': 1.3, 'positional_bias_positive_only': False, + 'positional_bias_weights': None, + 'expected_second_weight': 1.3, # positional_bias_severity }, { 'feedback_model': ranking_agent.FeedbackModel.SCORE_VECTOR, @@ -335,6 +339,36 @@ def testTrainAgentScoreFeedback( 'positional_bias_type': ranking_agent.PositionalBiasType.BASE, 'positional_bias_severity': 1.0, 'positional_bias_positive_only': True, + 'positional_bias_weights': None, + 'expected_second_weight': 2.0, # 2**positional_bias_severity + }, + { + 'feedback_model': ranking_agent.FeedbackModel.SCORE_VECTOR, + 'policy_type': ranking_agent.RankingPolicyType.DESCENDING_SCORES, + 'batch_size': 2, + 'global_dim': 3, + 'item_dim': 4, + 'num_items': 13, + 'num_slots': 11, + 'positional_bias_type': ( + ranking_agent.PositionalBiasType.FIXED_BIAS_WEIGHTS + ), + 'positional_bias_severity': None, + 'positional_bias_positive_only': True, + 'positional_bias_weights': [ + 0.1, + 0.2, + 0.3, + 0.4, + 0.5, + 0.6, + 0.7, + 0.8, + 0.9, + 1.0, + 1.1, + ], + 'expected_second_weight': 0.2, # positional_bias_weights[1] }, ]) def testPositionalBiasParams( @@ -349,6 +383,8 @@ def testPositionalBiasParams( positional_bias_type, positional_bias_severity, positional_bias_positive_only, + positional_bias_weights, + expected_second_weight, ): if not tf.executing_eagerly(): self.skipTest('Only works in eager mode.') @@ -386,6 +422,7 @@ def testPositionalBiasParams( positional_bias_type=positional_bias_type, positional_bias_severity=positional_bias_severity, positional_bias_positive_only=positional_bias_positive_only, + positional_bias_weights=positional_bias_weights, optimizer=optimizer, ) global_obs = tf.reshape( @@ -426,12 +463,7 @@ def testPositionalBiasParams( agent.train(experience) weights = agent._construct_sample_weights(scores, observations, None) self.assertAllEqual(weights.shape, [batch_size, num_slots]) - expected = ( - 2**positional_bias_severity - if positional_bias_type == ranking_agent.PositionalBiasType.BASE - else positional_bias_severity - ) - self.assertAllClose(weights[-1, 1], expected) + self.assertAllClose(weights[-1, 1], expected_second_weight) if __name__ == '__main__':