diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 6bcb14847..08cc321b4 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2689,6 +2689,19 @@ def rsub(self, inputs, input_types):
     def embedding(self, inputs, input_types):
+        """Emit ``_op.embedding`` over ``inputs[0]`` (weight) and ``inputs[1]`` (indices)."""
         weight = inputs[0]
         indices = inputs[1]
-        return _op.embedding(weight, indices.astype("int32"), axis=0)
+        # Indices that are already int32 need no cast. int64 indices are left
+        # untouched on purpose: forge has no int64 support, so they become int32
+        # while lowering from torch to forge anyway. Casting them here as well
+        # would leave a redundant int32 -> int32 cast op behind, which also
+        # exposes a few tt-mlir bugs:
+        # https://github.com/tenstorrent/tt-mlir/issues/1215
+        indices_dtype = _infer_type(indices).checked_type.dtype
+        if indices_dtype not in ("int32", "int64"):
+            # NOTE(review): the "{}" placeholder assumes a brace-style logger
+            # (e.g. loguru); confirm against the module's logger import.
+            logger.warning("Casting input indices of embedding op from {} to int32", indices_dtype)
+            indices = tvm.relay.cast(indices, "int32")
+        return _op.embedding(weight, indices, axis=0)
 
     def embedding_bag(self, inputs, input_types):