feat: add nvidia embedding implementation for new signature, task_type, output_dimention, text_truncation #1213
base: main
Conversation
this is blocked by meta-llama/llama-stack-client-python#162
    if text_truncation is not None:
        text_truncation_options = {
            TextTruncation.none: "NONE",
            TextTruncation.end: "END",
            TextTruncation.start: "START",
        }
        if text_truncation not in text_truncation_options:
            raise ValueError(f"Invalid text_truncation: {text_truncation}")
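For context, a minimal runnable sketch of how a mapping like this could feed the provider's request payload. The local enum and the `truncation_param` helper are illustrative stand-ins, not the PR's actual code, and the `truncate` field name is an assumption about the NVIDIA API:

```python
from enum import Enum

# Local stand-in for llama_stack's TextTruncation, for illustration only.
class TextTruncation(Enum):
    none = "none"
    start = "start"
    end = "end"

def truncation_param(text_truncation):
    """Map the Llama Stack enum to the string the NVIDIA API expects.

    Returns a dict that can be merged into the request payload;
    an empty dict means the parameter is simply omitted.
    """
    if text_truncation is None:
        return {}
    options = {
        TextTruncation.none: "NONE",
        TextTruncation.end: "END",
        TextTruncation.start: "START",
    }
    if text_truncation not in options:
        raise ValueError(f"Invalid text_truncation: {text_truncation}")
    return {"truncate": options[text_truncation]}
```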
We can move this validation here [1], so it applies to all providers.
[1]
    async def embeddings(
Turns out FastAPI will type-validate before invoking the routers, so I can remove this defensive code.
    if task_type is not None:
        task_type_options = {
            EmbeddingTaskType.document: "passage",
            EmbeddingTaskType.query: "query",
        }
        if task_type not in task_type_options:
            raise ValueError(f"Invalid task_type: {task_type}")
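A similar runnable sketch for the task-type mapping. The local enum and `task_type_param` helper are illustrative stand-ins; `input_type` as the request field name is an assumption about the NVIDIA embedding API, not something stated in this diff:

```python
from enum import Enum

# Local stand-in for llama_stack's EmbeddingTaskType, for illustration only.
class EmbeddingTaskType(Enum):
    query = "query"
    document = "document"

def task_type_param(task_type):
    """Map the Llama Stack task type to the NVIDIA API's input_type value."""
    if task_type is None:
        return {}
    options = {
        EmbeddingTaskType.document: "passage",
        EmbeddingTaskType.query: "query",
    }
    if task_type not in options:
        raise ValueError(f"Invalid task_type: {task_type}")
    return {"input_type": options[task_type]}
```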
same here
Also removing this one; will rely on FastAPI request validation.
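As a rough illustration of why the defensive checks are redundant: FastAPI (via pydantic) coerces incoming string parameters through the Enum constructor before the route handler runs, so an invalid value is rejected (as a 422 response) and never reaches the provider. The enum below is a local stand-in for the real type:

```python
from enum import Enum

# Local stand-in for llama_stack's TextTruncation enum.
class TextTruncation(Enum):
    none = "none"
    start = "start"
    end = "end"

# A valid value coerces cleanly...
member = TextTruncation("end")
print(member)

# ...while an invalid one raises ValueError, which is what
# pydantic turns into a validation error before the handler runs.
try:
    TextTruncation("middle")
except ValueError as err:
    print(err)
```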
    @pytest.mark.xfail(reason="Only valid for model supporting dimension reduction")
is there a model we can use for testing this?
Oops, yes: nvidia/llama-3.2-nv-embedqa-1b-v2, see https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
Edit: you can use baai/bge-m3. I've also updated the test instructions.
What does this PR do?
Updates the nvidia inference provider's embedding implementation to use the new signature.
Adds support for the task_type, output_dimensions, and text_truncation parameters.
Test Plan
LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/inference/test_embedding.py --embedding-model baai/bge-m3