Add normalize_embeddings Argument to SentenceTransformer for Simplified Embedding Normalization #3064

AIMacGyver opened this issue on Nov 17, 2024
Currently, embedding normalization in SentenceTransformer can be achieved in two ways:

  1. Adding a Normalize module to the model pipeline
  2. Manually normalizing embeddings post-encode

Both approaches work but can add complexity and may not align seamlessly with production deployment workflows.
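
For reference, a minimal sketch of both current approaches (the checkpoint name below is just an example):

import torch.nn.functional as F
from sentence_transformers import SentenceTransformer, models

# Approach 1: bake a Normalize module into the model pipeline
word_embedding = models.Transformer("sentence-transformers/all-MiniLM-L6-v2")
pooling = models.Pooling(word_embedding.get_word_embedding_dimension())
model = SentenceTransformer(modules=[word_embedding, pooling, models.Normalize()])

# Approach 2: normalize manually after encoding
plain_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embeddings = plain_model.encode(["This is a test sentence."], convert_to_tensor=True)
embeddings = F.normalize(embeddings, p=2, dim=-1)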

Feature Request:
Add normalize_embeddings as an argument to SentenceTransformer.__init__ that would be passed through to encode methods, similar to how truncate_dim works. This would provide a cleaner, built-in way to control normalization behavior.
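
To illustrate the intent, here is a rough sketch written as a thin wrapper (the class name is hypothetical, and it simply forwards a stored default to the existing encode-level flag, not how the library would necessarily implement it):

from sentence_transformers import SentenceTransformer


class DefaultNormalizeSentenceTransformer(SentenceTransformer):
    # Hypothetical illustration of the proposed init-level default
    def __init__(self, *args, normalize_embeddings: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        self._default_normalize = normalize_embeddings

    def encode(self, *args, **kwargs):
        # Apply the stored default only when the caller does not set the flag explicitly
        kwargs.setdefault("normalize_embeddings", self._default_normalize)
        return super().encode(*args, **kwargs)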

Current Workarounds:
To get normalized embeddings today, we need to either:

  1. Subclass SentenceTransformer to add normalization
  2. Apply normalization post-encode
  3. Always include a Normalize module

Example Subclass Workaround

import numpy as np
import torch
import torch.nn.functional as F

from torch import Tensor
from sentence_transformers import SentenceTransformer


class NormalizedSentenceTransformer(SentenceTransformer):
    def encode(self, *args, **kwargs):
        embeddings = super().encode(*args, **kwargs)
        if isinstance(embeddings, np.ndarray):
            # Round-trip through torch to L2-normalize, then convert back to numpy
            embeddings = torch.tensor(embeddings, dtype=torch.float32)
            embeddings = F.normalize(embeddings, p=2, dim=-1)
            return embeddings.numpy()
        elif isinstance(embeddings, list):
            # encode may return a list of 1-D tensors, so normalize along the last dim
            return [F.normalize(embedding, p=2, dim=-1) for embedding in embeddings]
        elif isinstance(embeddings, Tensor):
            return F.normalize(embeddings, p=2, dim=-1)
        else:
            raise ValueError(f"Unsupported type for embeddings: {type(embeddings)}")
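
Usage is then the same as the stock class (checkpoint name is just an example):

model = NormalizedSentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embeddings = model.encode(["This is a test sentence."])  # already L2-normalized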

Questions for Discussion:

  1. How should this interact with models that already include a Normalize module (a minimal detection sketch follows this list)? Should we:

    • Skip additional normalization if a Normalize module is present?
    • Allow override via the normalize_embeddings parameter?
    • Raise a warning to avoid redundancy?
  2. Should this behavior be configurable in the config_sentence_transformers.json file, like other model options?
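
For question 1, one possible starting point for detecting an existing Normalize module (a sketch only; the checkpoint name is just an example):

from sentence_transformers import SentenceTransformer, models

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # example checkpoint

# SentenceTransformer subclasses nn.Sequential, so the pipeline modules can be iterated;
# has_normalize is True if the loaded pipeline already contains a Normalize module
has_normalize = any(isinstance(module, models.Normalize) for module in model)

With a check like this, the new flag could skip, override, or warn, depending on which behavior is chosen.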

Use Case:
When deploying embedding models to production serving endpoints (e.g., Databricks via MLflow), having normalization as a built-in configurable parameter would:

  1. Eliminate the need to wrap or subclass models.
  2. Reduce the risk of accidentally feeding unnormalized embeddings into similarity or clustering tasks, where it can silently degrade results.
  3. Align with production-friendly design by providing a clean and intuitive API for normalization.

Example Usage

from sentence_transformers import SentenceTransformer

# Proposed API: normalize embeddings for every encode call
model = SentenceTransformer(
    model_name_or_path="jinaai/jina-embeddings-v2-base-code",
    trust_remote_code=True,
    normalize_embeddings=True
)
embeddings = model.encode(["This is a test sentence."])
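
For comparison, and assuming I'm reading the current API correctly, the closest existing behavior is the per-call flag on encode, which has to be repeated at every call site:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("jinaai/jina-embeddings-v2-base-code", trust_remote_code=True)

# Today the flag must be passed on every encode call
embeddings = model.encode(["This is a test sentence."], normalize_embeddings=True)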

If this feature makes sense, I'd be happy to contribute a PR and incorporate any guidance or feedback on the request.
