diff --git a/tf_agents/bandits/agents/examples/v2/train_eval_ranking.py b/tf_agents/bandits/agents/examples/v2/train_eval_ranking.py
index 71f3283ee..f98f59730 100644
--- a/tf_agents/bandits/agents/examples/v2/train_eval_ranking.py
+++ b/tf_agents/bandits/agents/examples/v2/train_eval_ranking.py
@@ -170,7 +170,17 @@ def _relevance_fn(global_obs, item_obs):
     raise NotImplementedError(
         'Policy type {} is not implemented'.format(FLAGS.policy_type)
     )
-  positional_bias_type = FLAGS.bias_type or None
+  if FLAGS.positional_bias_type == 'base':
+    positional_bias_type = ranking_agent.PositionalBiasType.BASE
+  elif FLAGS.positional_bias_type == 'exponent':
+    positional_bias_type = ranking_agent.PositionalBiasType.EXPONENT
+  else:
+    raise NotImplementedError(
+        'Positional bias type {} is not implemented'.format(
+            FLAGS.positional_bias_type
+        )
+    )
+
   agent = ranking_agent.RankingAgent(
       time_step_spec=environment.time_step_spec(),
       action_spec=environment.action_spec(),
diff --git a/tf_agents/bandits/agents/ranking_agent.py b/tf_agents/bandits/agents/ranking_agent.py
index 0d2706821..3c1b80f6b 100644
--- a/tf_agents/bandits/agents/ranking_agent.py
+++ b/tf_agents/bandits/agents/ranking_agent.py
@@ -116,6 +116,19 @@ class FeedbackModel(enum.Enum):
   SCORE_VECTOR = 2
 
 
+class PositionalBiasType(enum.Enum):
+  """Enumeration of positional bias types."""
+
+  UNSET = 0
+  # The bias weight for each slot position is `k^s`, where `s` is the bias
+  # severity and `k` is the position.
+  BASE = 1
+  # The weights are `s^k`. These bias adjustment types are inspired by Ovaisi
+  # et al. `Correcting for Selection Bias in Learning-to-rank Systems`
+  # (WWW 2020).
+  EXPONENT = 2
+
+
 class RankingAgent(tf_agent.TFAgent):
   """Ranking agent class."""
 
@@ -129,7 +142,7 @@ def __init__(
       error_loss_fn: types.LossFn = tf.compat.v1.losses.mean_squared_error,
       feedback_model: FeedbackModel = FeedbackModel.CASCADING,
       non_click_score: Optional[float] = None,
-      positional_bias_type: Optional[Text] = None,
+      positional_bias_type: PositionalBiasType = PositionalBiasType.UNSET,
       positional_bias_severity: Optional[float] = None,
       positional_bias_positive_only: bool = False,
       logits_temperature: float = 1.0,
@@ -162,16 +175,9 @@ def __init__(
       non_click_score: (float) For the cascading feedback model, this is the
         score value for items lying "before" the clicked item. If not set, -1
         is used. It is recommended (but not enforced) to use a negative value.
-      positional_bias_type: (string) If not set (or set to `None`), the agent
-        does not apply bias adjustment. If set to either `base` or `exponent`,
-        it parameter determines what way the positional bias is accounted for.
-        `base`: The bias weight for each slot position is `k^s`, where `s` is
-        the bias severity (set in the next parameter), and `k` is the position.
-        `exponent`: The weights are `s^k`. These bias adjustment types are
-        inspired by Ovaisi et al. `Correcting for Selection Bias in
-        Learning-to-rank Systems` (WWW 2020).
-      positional_bias_severity: (float) The severity `s`, used as explained
-        above. If `positional_bias_type` is unset, this parameter has no effect.
+      positional_bias_type: Type of positional bias to use when training.
+      positional_bias_severity: (float) The severity `s`, used for the `BASE`
+        positional bias type.
       positional_bias_positive_only: Whether to use the above defined bias
         weights only for positives (that is, clicked items). If
         `positional_bias_type` is unset, this parameter has no effect.
@@ -403,21 +409,19 @@ def _construct_sample_weights(self, reward, observation, weights):
           chosen_index + 1, self._num_slots, dtype=tf.float32
       )
       weights = multiplier * weights
-    if self._positional_bias_type is not None:
+    if self._positional_bias_type != PositionalBiasType.UNSET:
       batched_range = tf.broadcast_to(
           tf.range(self._num_slots, dtype=tf.float32), tf.shape(weights)
       )
-      if self._positional_bias_type == 'base':
+      if self._positional_bias_type == PositionalBiasType.BASE:
         position_bias_multipliers = tf.pow(
             batched_range + 1, self._positional_bias_severity
         )
-      elif self._positional_bias_type == 'exponent':
+      elif self._positional_bias_type == PositionalBiasType.EXPONENT:
         position_bias_multipliers = tf.pow(
             self._positional_bias_severity, batched_range
         )
       else:
-        raise ValueError(
-            'non-existing bias type: ' + self._positional_bias_type
-        )
+        raise ValueError('non-existing positional bias type')
       weights = position_bias_multipliers * weights
     return weights
diff --git a/tf_agents/bandits/agents/ranking_agent_test.py b/tf_agents/bandits/agents/ranking_agent_test.py
index dffe72907..46b949228 100644
--- a/tf_agents/bandits/agents/ranking_agent_test.py
+++ b/tf_agents/bandits/agents/ranking_agent_test.py
@@ -308,7 +308,7 @@ def testTrainAgentScoreFeedback(
           'item_dim': 3,
           'num_items': 10,
           'num_slots': 5,
-          'positional_bias_type': 'base',
+          'positional_bias_type': ranking_agent.PositionalBiasType.BASE,
           'positional_bias_severity': 1.2,
           'positional_bias_positive_only': False,
       },
@@ -320,7 +320,7 @@ def testTrainAgentScoreFeedback(
           'item_dim': 5,
           'num_items': 21,
           'num_slots': 17,
-          'positional_bias_type': 'exponent',
+          'positional_bias_type': ranking_agent.PositionalBiasType.EXPONENT,
           'positional_bias_severity': 1.3,
           'positional_bias_positive_only': False,
       },
@@ -332,19 +332,7 @@ def testTrainAgentScoreFeedback(
           'item_dim': 4,
           'num_items': 13,
           'num_slots': 11,
-          'positional_bias_type': 'base',
-          'positional_bias_severity': 1.0,
-          'positional_bias_positive_only': True,
-      },
-      {
-          'feedback_model': ranking_agent.FeedbackModel.SCORE_VECTOR,
-          'policy_type': ranking_agent.RankingPolicyType.DESCENDING_SCORES,
-          'batch_size': 2,
-          'global_dim': 3,
-          'item_dim': 4,
-          'num_items': 13,
-          'num_slots': 11,
-          'positional_bias_type': 'invalid',
+          'positional_bias_type': ranking_agent.PositionalBiasType.BASE,
           'positional_bias_severity': 1.0,
           'positional_bias_positive_only': True,
       },
@@ -435,19 +423,15 @@ def testPositionalBiasParams(
         ),
     )
     experience = _get_experience(initial_step, action_step, final_step)
-    if positional_bias_type == 'invalid':
-      with self.assertRaisesRegex(ValueError, 'non-existing bias type'):
-        agent.train(experience)
-    else:
-      agent.train(experience)
-      weights = agent._construct_sample_weights(scores, observations, None)
-      self.assertAllEqual(weights.shape, [batch_size, num_slots])
-      expected = (
-          2**positional_bias_severity
-          if positional_bias_type == 'base'
-          else positional_bias_severity
-      )
-      self.assertAllClose(weights[-1, 1], expected)
+    agent.train(experience)
+    weights = agent._construct_sample_weights(scores, observations, None)
+    self.assertAllEqual(weights.shape, [batch_size, num_slots])
+    expected = (
+        2**positional_bias_severity
+        if positional_bias_type == ranking_agent.PositionalBiasType.BASE
+        else positional_bias_severity
+    )
+    self.assertAllClose(weights[-1, 1], expected)
 
 
 if __name__ == '__main__':
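Note (not part of the patch): a minimal NumPy sketch of the per-slot bias weights that the new enum selects in `_construct_sample_weights` above. The helper name `positional_bias_multipliers` is hypothetical; for 0-indexed slot `k` and severity `s`, BASE yields `(k + 1)**s` and EXPONENT yields `s**k`.

import numpy as np

def positional_bias_multipliers(bias_type, num_slots, severity):
  """Per-slot bias weights; bias_type is 'base' or 'exponent' (hypothetical helper)."""
  positions = np.arange(num_slots, dtype=np.float32)  # k = 0, 1, ..., num_slots - 1
  if bias_type == 'base':
    return (positions + 1.0) ** severity  # BASE: (k + 1)^s
  if bias_type == 'exponent':
    return severity ** positions  # EXPONENT: s^k
  raise ValueError('non-existing positional bias type')

# Example with severity s = 1.2 and 5 slots:
#   positional_bias_multipliers('base', 5, 1.2)     -> [1., 2**1.2, 3**1.2, 4**1.2, 5**1.2]
#   positional_bias_multipliers('exponent', 5, 1.2) -> [1., 1.2, 1.44, 1.728, 2.0736]

This is consistent with the assertion in testPositionalBiasParams: at slot index 1, BASE gives `2**positional_bias_severity` and EXPONENT gives `positional_bias_severity`.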