diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index 8cfd16ae9..de3d495f2 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -27,6 +27,7 @@
 
 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
+from tensordict import TensorDict
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
 from torch.distributed._shard.sharded_tensor import TensorProperties
@@ -94,6 +95,7 @@
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
 from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.tensor_dict import maybe_td_to_kjt
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -656,9 +658,7 @@ def __init__(
         self._inverse_indices_permute_indices: Optional[torch.Tensor] = None
         # to support mean pooling callback hook
         self._has_mean_pooling_callback: bool = (
-            True
-            if PoolingType.MEAN.value in self._pooling_type_to_rs_features
-            else False
+            PoolingType.MEAN.value in self._pooling_type_to_rs_features
         )
         self._dim_per_key: Optional[torch.Tensor] = None
         self._kjt_key_indices: Dict[str, int] = {}
@@ -1189,8 +1189,16 @@ def _create_inverse_indices_permute_indices(
 
     # pyre-ignore [14]
     def input_dist(
-        self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
+        self,
+        ctx: EmbeddingBagCollectionContext,
+        features: Union[KeyedJaggedTensor, TensorDict],
     ) -> Awaitable[Awaitable[KJTList]]:
+        if isinstance(features, TensorDict):
+            feature_keys = list(features.keys())  # pyre-ignore[6]
+            if len(self._features_order) > 0:
+                feature_keys = [feature_keys[i] for i in self._features_order]
+                self._has_features_permute = False  # feature_keys are in order
+            features = maybe_td_to_kjt(features, feature_keys)  # pyre-ignore[6]
         ctx.variable_batch_per_feature = features.variable_stride_per_key()
         ctx.inverse_indices = features.inverse_indices_or_none()
         if self._has_uninitialized_input_dist:
diff --git a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
index e8dc5eccb..fdb900fe0 100644
--- a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
+++ b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
@@ -160,7 +160,7 @@ def main(
 
     tables = [
         EmbeddingBagConfig(
-            num_embeddings=(i + 1) * 1000,
+            num_embeddings=max(i + 1, 100) * 1000,
             embedding_dim=dim_emb,
             name="table_" + str(i),
             feature_names=["feature_" + str(i)],
@@ -169,7 +169,7 @@ def main(
     ]
     weighted_tables = [
         EmbeddingBagConfig(
-            num_embeddings=(i + 1) * 1000,
+            num_embeddings=max(i + 1, 100) * 1000,
             embedding_dim=dim_emb,
             name="weighted_table_" + str(i),
             feature_names=["weighted_feature_" + str(i)],
diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py
index 307d66639..4ade3df2f 100644
--- a/torchrec/modules/embedding_modules.py
+++ b/torchrec/modules/embedding_modules.py
@@ -19,6 +19,7 @@
     pooling_type_to_str,
 )
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.tensor_dict import maybe_td_to_kjt
 
 
 @torch.fx.wrap
@@ -229,6 +230,7 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
             KeyedTensor
         """
         flat_feature_names: List[str] = []
+        features = maybe_td_to_kjt(features, None)
        for names in self._feature_names:
             flat_feature_names.extend(names)
         inverse_indices = reorder_inverse_indices(
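
Below is a minimal usage sketch (not part of the patch) of the unsharded input path this diff enables. It assumes, beyond what the diff shows, that `maybe_td_to_kjt` expects each TensorDict entry to be a jagged `torch.nested` tensor and that tensordict accepts such values; the table name, feature name, and sizes are illustrative only.

```python
import torch
from tensordict import TensorDict
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

# Illustrative single-table collection; names and sizes are not from the diff.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            num_embeddings=1000,
            embedding_dim=8,
            name="table_0",
            feature_names=["feature_0"],
        )
    ],
)

# Two-sample batch: sample 0 has one id, sample 1 has two, expressed as a
# jagged nested tensor (values + offsets) -- the assumed layout noted above.
td = TensorDict(
    {
        "feature_0": torch.nested.nested_tensor_from_jagged(
            values=torch.tensor([1, 2, 3]),
            offsets=torch.tensor([0, 1, 3]),
        )
    },
    batch_size=[2],
)

# With this patch, forward() converts the TensorDict to a KeyedJaggedTensor
# via maybe_td_to_kjt(features, None) and reuses the existing pooled lookup.
pooled = ebc(td)  # KeyedTensor; pooled["feature_0"] has shape (2, 8)
```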