diff --git a/tf_agents/bandits/policies/ranking_policy.py b/tf_agents/bandits/policies/ranking_policy.py
index 4757d36ad..4ff535236 100644
--- a/tf_agents/bandits/policies/ranking_policy.py
+++ b/tf_agents/bandits/policies/ranking_policy.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 """Ranking policy."""
 
-from typing import Optional, Sequence, Text
+from typing import Optional, Text
 import numpy as np
 import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
@@ -62,15 +62,18 @@ def __init__(
   def _penalizer_fn(
       self,
       logits: types.Float,
-      features: types.Float,
-      slots: Sequence[types.Int],
+      slots: tf.Tensor,
+      num_slotted: tf.Tensor,
   ):
     """Downscores items by their similarity to already selected items.
 
     Args:
-      logits: The current logits of all items.
-      features: the feature vectors of the items.
-      slots: list of indices of already selected items.
+      logits: The current logits of all items, shaped as [batch_size,
+        num_items].
+      slots: A tensor of indices of the selected items, shaped as [batch_size,
+        num_slots]. Only the first `num_slotted` columns correspond to valid
+        indices.
+      num_slotted: The number of slots filled so far.
 
     Returns:
       New logits.
@@ -78,16 +81,66 @@ def _penalizer_fn(
     raise NotImplementedError()
 
   def _sample_n(self, n, seed=None):
+    # TODO(b/251139151): Support n > 1.
+    del n
+    # The scores (logits) of all items, shaped as [batch_size, num_items].
     logits = tf.convert_to_tensor(self.scores)
-    sample_shape = tf.concat([[n], tf.shape(logits)], axis=0)
-    slots = []
-    for _ in range(self._num_slots):
-      items = tfd.Categorical(logits=logits).sample()
-      slots.append(items)
-      logits -= tf.one_hot(items, sample_shape[-1], on_value=np.inf)
-      logits = self._penalizer_fn(logits, self._features, slots)
-    sample = tf.expand_dims(tf.stack(slots, axis=-1), axis=0)
-    return sample
+    # The index of the next slot to sample.
+    slot_idx = tf.constant(0, dtype=tf.int32)
+    # The indices of the items that have been sampled, shaped as
+    # [batch_size, num_slots].
+    slots = tf.zeros(
+        shape=(tf.shape(logits)[0], self._num_slots), dtype=tf.int32
+    )
+
+    def _sample_next_slot(slot_idx, slots, logits):
+      # Samples the batch of item indices for the next slot.
+      items = tf.ensure_shape(
+          tfd.Categorical(logits=logits).sample(),
+          (self._features.shape[0],),
+          name='ensure_shape_items',
+      )
+      slots = tf.ensure_shape(
+          slots,
+          (self._features.shape[0], self._num_slots),
+          name='ensure_shape_slots',
+      )
+      # Updates the indices of sampled items by incorporating the sampled item
+      # indices for the next slot.
+      slots = tf.ensure_shape(
+          slots
+          + tf.expand_dims(items, axis=-1)
+          * tf.expand_dims(
+              tf.one_hot(
+                  slot_idx,
+                  self._num_slots,
+                  dtype=tf.int32,
+                  name='one_hot_for_slot_idx',
+              ),
+              0,
+          ),
+          (self._features.shape[0], self._num_slots),
+          name='ensure_shape_slots_after_update',
+      )
+      # Sets the scores (logits) of the items that have been sampled to -inf,
+      # so they will not be selected again.
+      logits -= tf.one_hot(
+          items, logits.shape[-1], on_value=np.inf, name='one_hot_for_items'
+      )
+      # Applies the penalty function to the logits.
+      logits = tf.ensure_shape(
+          self._penalizer_fn(logits, slots, num_slotted=slot_idx + 1),
+          self.scores.shape,
+      )
+      return slot_idx + 1, slots, logits
+
+    _, slots, _ = tf.while_loop(
+        cond=lambda slot_idx, slots, logits: True,
+        body=_sample_next_slot,
+        loop_vars=(slot_idx, slots, logits),
+        maximum_iterations=self._num_slots,
+    )
+    return tf.expand_dims(slots, axis=0)
 
   def _event_shape(self, scores=None):
     return self._num_slots
@@ -96,31 +149,55 @@
 class CosinePenalizedPlackettLuce(PenalizedPlackettLuce):
   """A distribution that samples items based on scores and cosine similarity."""
 
-  def _penalizer_fn(self, logits, features, slots):
+  def __init__(
+      self,
+      features: types.Tensor,
+      num_slots: int,
+      logits: types.Tensor,
+      penalty_mixture_coefficient: float = 1.0,
+  ):
+    """Initializes an instance of CosinePenalizedPlackettLuce.
+
+    Args:
+      features: Item features based on which similarity is calculated, shaped
+        as [batch_size, num_items, feature_dim].
+      num_slots: The number of slots to fill; this many items will be sampled.
+      logits: Unnormalized log probabilities for the PlackettLuce distribution,
+        shaped as [batch_size, num_items].
+      penalty_mixture_coefficient: A parameter responsible for the balance
+        between selecting high-scoring items and enforcing diversity.
+    """
+    super().__init__(features, num_slots, logits, penalty_mixture_coefficient)
+    num_items = features.shape[1]
+    # Computes the cosine similarity matrix between all items, shaped as
+    # [batch_size, num_items, num_items].
+    self._sim_matrix = tf.reshape(
+        tf.keras.losses.cosine_similarity(
+            tf.repeat(features, num_items, axis=1, name='repeat_features'),
+            tf.tile(
+                features,
+                [1, num_items, 1],
+                name='tile_features',
+            ),
+        )
+        - 1,
+        shape=[-1, num_items, num_items],
+    )
+
+  def _penalizer_fn(self, logits, slots, num_slotted):
     num_items = logits.shape[-1]
-    num_slotted = len(slots)
-    slot_tensor = tf.stack(slots, axis=-1)
+    # Gathers the pairwise similarities between all items and the items that
+    # have been selected, shaped as [batch_size, num_slotted, num_items].
     # The tfd.Categorical distribution will give the sample `num_items` if all
     # the logits are `-inf`. Hence, we need to apply minimum. This happens when
     # `num_actions` is less than `num_slots`. To this end, the action taken by
     # the policy always has to be taken together with the `num_actions`
     # observation, to know how many slots are filled with valid items.
-    slotted_features = tf.gather(
-        features, tf.minimum(slot_tensor, num_items - 1), batch_dims=1
-    )
-
-    # Calculate the similarity between all pairs from
-    # `slotted_features x all_features`.
-    all_sims = (
-        tf.keras.losses.cosine_similarity(
-            tf.repeat(features, num_slotted, axis=1),
-            tf.tile(slotted_features, [1, num_items, 1]),
-        )
-        - 1
-    )
-
-    sim_matrix = tf.reshape(all_sims, shape=[-1, num_items, num_slotted])
-    similarity_boosts = tf.reduce_min(sim_matrix, axis=-1)
+    sim_matrix_against_slotted = tf.gather(
+        self._sim_matrix,
+        tf.minimum(slots[..., :num_slotted], num_items - 1),
+        batch_dims=1,
+    )
+    similarity_boosts = tf.reduce_min(sim_matrix_against_slotted, axis=1)
     adjusted_logits = logits + (
         self._penalty_mixture_coefficient * similarity_boosts
     )
diff --git a/tf_agents/bandits/policies/ranking_policy_test.py b/tf_agents/bandits/policies/ranking_policy_test.py
index 48ac8022f..61ebb7270 100644
--- a/tf_agents/bandits/policies/ranking_policy_test.py
+++ b/tf_agents/bandits/policies/ranking_policy_test.py
@@ -14,7 +14,9 @@
 # limitations under the License.
"""Tests for ranking_policy.""" + from absl.testing import parameterized +import numpy as np import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import from tf_agents.bandits.networks import global_and_arm_feature_network as arm_net from tf_agents.bandits.policies import ranking_policy @@ -84,8 +86,13 @@ def testPolicy(self, policy_class, batch_size, num_items, num_slots): ) time_spec = ts.restart(observation, batch_size=batch_size) action_step = policy.action(time_spec) + unique_item_counts = tf.map_fn( + lambda action: tf.unique_with_counts(action)[2], action_step.action + ) self.evaluate(tf.compat.v1.global_variables_initializer()) self.assertAllEqual(action_step.action.shape, [batch_size, num_slots]) + # All ranked items should appear exactly once in the ranked list. + self.assertAllEqual(unique_item_counts, np.ones((batch_size, num_slots))) def testTemperature(self): if not tf.executing_eagerly():