
feat: add nvidia embedding implementation for new signature, task_type, output_dimention, text_truncation #1213

Open · wants to merge 7 commits into base: main
Conversation

@mattf (Contributor) commented Feb 21, 2025

What does this PR do?

Updates the nvidia inference provider's embedding implementation to use the new signature.

Adds support for the task_type, output_dimensions, and text_truncation parameters.

Test Plan

LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/inference/test_embedding.py --embedding-model baai/bge-m3

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Feb 21, 2025
@mattf (Contributor, Author) commented Feb 21, 2025

this is blocked by meta-llama/llama-stack-client-python#162

@mattf (Contributor, Author) commented Feb 21, 2025

cc @raspawar @cdgamarose-nv

@mattf mattf changed the title add nvidia embedding implementation for new signature, task_type, output_dimention, text_truncation feat: add nvidia embedding implementation for new signature, task_type, output_dimention, text_truncation Feb 21, 2025
@raspawar (Contributor) commented:
LGTM, thanks @mattf for looking into this. I see that #162 is closed, so this should be OK to merge.

Comment on lines 151 to 158:

```python
if text_truncation is not None:
    text_truncation_options = {
        TextTruncation.none: "NONE",
        TextTruncation.end: "END",
        TextTruncation.start: "START",
    }
    if text_truncation not in text_truncation_options:
        raise ValueError(f"Invalid text_truncation: {text_truncation}")
```
Reviewer (Contributor) commented:
We can move this validation here [1], so it applies to all providers.

[1]

@mattf (Contributor, Author) replied:

Turns out FastAPI will type-validate before invoking the routers, so I can remove this defensive code.
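A minimal sketch of what the simplified provider code could look like once the defensive check is dropped. The enum here is a hypothetical stand-in for llama-stack's `TextTruncation`, and `to_nvidia_truncate` is an illustrative helper name, not the PR's actual code:

```python
from enum import Enum
from typing import Optional


# Hypothetical stand-in for llama-stack's TextTruncation enum.
class TextTruncation(Enum):
    none = "none"
    start = "start"
    end = "end"


# Mapping from the stack-level enum to the strings the NVIDIA API expects,
# as shown in the reviewed snippet.
TEXT_TRUNCATION_OPTIONS = {
    TextTruncation.none: "NONE",
    TextTruncation.end: "END",
    TextTruncation.start: "START",
}


def to_nvidia_truncate(text_truncation: Optional[TextTruncation]) -> Optional[str]:
    # FastAPI has already validated the enum at the request boundary,
    # so a direct dictionary lookup suffices; no ValueError branch needed.
    if text_truncation is None:
        return None
    return TEXT_TRUNCATION_OPTIONS[text_truncation]
```

The point of the simplification: invalid values are rejected as 422 responses by FastAPI's request validation before the provider is ever invoked.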

Comment on lines 164 to 170:

```python
if task_type is not None:
    task_type_options = {
        EmbeddingTaskType.document: "passage",
        EmbeddingTaskType.query: "query",
    }
    if task_type not in task_type_options:
        raise ValueError(f"Invalid task_type: {task_type}")
```
Reviewer (Contributor) commented:
same here

@mattf (Contributor, Author) replied:

Also removing; I will rely on FastAPI request validation.
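The same simplification applies to the task-type mapping. Again a sketch: the enum is a hypothetical stand-in for llama-stack's `EmbeddingTaskType`, and `to_nvidia_input_type` is an illustrative name for the helper (the NVIDIA embedding API distinguishes "passage" and "query" inputs, per the reviewed snippet):

```python
from enum import Enum
from typing import Optional


# Hypothetical stand-in for llama-stack's EmbeddingTaskType enum.
class EmbeddingTaskType(Enum):
    document = "document"
    query = "query"


# Stack-level task types mapped to the NVIDIA API's input-type strings,
# as shown in the reviewed snippet.
TASK_TYPE_OPTIONS = {
    EmbeddingTaskType.document: "passage",
    EmbeddingTaskType.query: "query",
}


def to_nvidia_input_type(task_type: Optional[EmbeddingTaskType]) -> Optional[str]:
    # FastAPI has already validated task_type against the enum,
    # so the defensive "Invalid task_type" branch can be removed.
    if task_type is None:
        return None
    return TASK_TYPE_OPTIONS[task_type]
```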

```python
)


@pytest.mark.xfail(reason="Only valid for model supporting dimension reduction")
```
Reviewer (Contributor) commented:
Is there a model we can use for testing this?

@mattf (Contributor, Author) replied Feb 25, 2025:
Edit: you can use baai/bge-m3. I've also updated the test instructions.

Oops, yes: nvidia/llama-3.2-nv-embedqa-1b-v2, see https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html

@mattf mattf requested a review from ehhuang February 25, 2025 12:18
Labels: CLA Signed (this label is managed by the Meta Open Source bot)
4 participants