Skip to content

Commit

Permalink
Add torch.nn.functional.cosine_similarity
Browse files Browse the repository at this point in the history
  • Loading branch information
jimlinntu committed Oct 23, 2024
1 parent 5a4d7b4 commit a596978
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
"nn.functional.conv_transpose2d",
"nn.functional.conv_transpose3d",
"nn.functional.cosine_embedding_loss",
"nn.functional.cosine_similarity",
"nn.functional.ctc_loss",
"nn.functional.dropout2d",
"nn.functional.dropout3d",
Expand Down
23 changes: 23 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jtorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,29 @@ def scaled_dot_product_attention(

return _sdpa_reference(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa)

@register_function(torch.nn.functional.cosine_similarity)
def cosine_similarity(x1, x2, dim=1, eps=1e-8):
  """Cosine similarity between `x1` and `x2` along axis `dim`.

  Mirrors torch.nn.functional.cosine_similarity: each input's vector norm
  is clamped below by `eps` before dividing, so zero vectors do not yield
  NaN. Inputs are broadcast against each other before the reduction.
  """
  if len(x1.shape) == 0 and len(x2.shape) == 0:
    # Scalar inputs: the only reducible axis is 0.
    assert dim == 0
    dot = x1 * x2
    norm_a = jnp.maximum(jnp.linalg.vector_norm(x1), eps)
    norm_b = jnp.maximum(jnp.linalg.vector_norm(x2), eps)
  else:
    a, b = jnp.broadcast_arrays(x1, x2)
    dot = jnp.vecdot(a, b, axis=dim)
    norm_a = jnp.maximum(jnp.linalg.vector_norm(a, axis=dim), eps)
    norm_b = jnp.maximum(jnp.linalg.vector_norm(b, axis=dim), eps)
  return dot / (norm_a * norm_b)


@register_function(torch.Tensor.__getitem__)
def getitem(self, indexes):
if isinstance(indexes, list) and isinstance(indexes[0], int):
Expand Down

0 comments on commit a596978

Please sign in to comment.