From 120308b044bbee111a8eb7a2129b515fef049309 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 20 Jan 2025 22:29:31 -0800 Subject: [PATCH] test: fix flaky test --- litellm/llms/huggingface/chat/handler.py | 1 + tests/local_testing/test_embedding.py | 8 ++++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/litellm/llms/huggingface/chat/handler.py b/litellm/llms/huggingface/chat/handler.py index e9b40be6a7f8..2b65e5b7dad6 100644 --- a/litellm/llms/huggingface/chat/handler.py +++ b/litellm/llms/huggingface/chat/handler.py @@ -432,6 +432,7 @@ def _transform_input( embed_url: str, ) -> dict: data: Dict = {} + ## TRANSFORMATION ## if "sentence-transformers" in model: if len(input) == 0: diff --git a/tests/local_testing/test_embedding.py b/tests/local_testing/test_embedding.py index 6bb1e9553295..4d5cc58d8a51 100644 --- a/tests/local_testing/test_embedding.py +++ b/tests/local_testing/test_embedding.py @@ -642,12 +642,16 @@ def tgi_mock_post(*args, **kwargs): from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler -@pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.asyncio -async def test_hf_embedding_sentence_sim(sync_mode): +@patch("litellm.llms.huggingface.chat.handler.async_get_hf_task_embedding_for_model") +@pytest.mark.parametrize("sync_mode", [True, False]) +async def test_hf_embedding_sentence_sim( + mock_get_hf_task_embedding_for_model, sync_mode # Add this parameter +): try: # huggingface/microsoft/codebert-base # huggingface/facebook/bart-large + mock_get_hf_task_embedding_for_model.return_value = "sentence-similarity" if sync_mode is True: client = HTTPHandler(concurrent_limit=1) else: