From 3c813b3a878e678a4fb1e085a9b8adad27275f63 Mon Sep 17 00:00:00 2001
From: Krish Dholakia
Date: Tue, 4 Feb 2025 22:30:00 -0800
Subject: [PATCH] Fix deepseek calling - refactor to use base_llm_http_handler
 (#8266)

* refactor(deepseek/): move deepseek to base llm http handler

Fixes https://github.com/BerriAI/litellm/issues/8128#issuecomment-2635430457

* fix(gpt_transformation.py): support stream parsing for gpt-like calls

* test(test_deepseek_completion.py): add async streaming test

* fix(gpt_transformation.py): fix import

* fix(gpt_transformation.py): return full api base and content type
---
 .../get_llm_provider_logic.py                  |   1 +
 litellm/llms/openai/chat/gpt_transformation.py | 106 +++++++++++++++++-
 litellm/main.py                                |  31 ++++-
 ...odel_prices_and_context_window_backup.json  |  11 ++
 .../test_deepseek_completion.py                |  51 ++++++++-
 5 files changed, 194 insertions(+), 6 deletions(-)

diff --git a/litellm/litellm_core_utils/get_llm_provider_logic.py b/litellm/litellm_core_utils/get_llm_provider_logic.py
index 302865629a14..a64e7dd700d0 100644
--- a/litellm/litellm_core_utils/get_llm_provider_logic.py
+++ b/litellm/litellm_core_utils/get_llm_provider_logic.py
@@ -490,6 +490,7 @@ def _get_openai_compatible_provider_info(  # noqa: PLR0915
             or get_secret("DEEPSEEK_API_BASE")
             or "https://api.deepseek.com/beta"
         )  # type: ignore
+        dynamic_api_key = api_key or get_secret_str("DEEPSEEK_API_KEY")
 
     elif custom_llm_provider == "fireworks_ai":
         # fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1
diff --git a/litellm/llms/openai/chat/gpt_transformation.py b/litellm/llms/openai/chat/gpt_transformation.py
index 6fa43cccbfeb..98c3254da4a0 100644
--- a/litellm/llms/openai/chat/gpt_transformation.py
+++ b/litellm/llms/openai/chat/gpt_transformation.py
@@ -2,16 +2,27 @@
 Support for gpt model family
 """
 
-from typing import TYPE_CHECKING, Any, List, Optional, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncIterator,
+    Iterator,
+    List,
+    Optional,
+    Union,
+    cast,
+)
 
 import httpx
 
 import litellm
+from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
 from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
 from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import AllMessageValues
-from litellm.types.utils import ModelResponse
+from litellm.types.utils import ModelResponse, ModelResponseStream
+from litellm.utils import convert_to_model_response_object
 
 from ..common_utils import OpenAIError
 
@@ -210,7 +221,36 @@ def transform_response(
         Returns:
             dict: The transformed response.
""" - raise NotImplementedError + + ## LOGGING + logging_obj.post_call( + input=messages, + api_key=api_key, + original_response=raw_response.text, + additional_args={"complete_input_dict": request_data}, + ) + + ## RESPONSE OBJECT + try: + completion_response = raw_response.json() + except Exception as e: + response_headers = getattr(raw_response, "headers", None) + raise OpenAIError( + message="Unable to get json response - {}, Original Response: {}".format( + str(e), raw_response.text + ), + status_code=raw_response.status_code, + headers=response_headers, + ) + raw_response_headers = dict(raw_response.headers) + final_response_obj = convert_to_model_response_object( + response_object=completion_response, + model_response_object=model_response, + hidden_params={"headers": raw_response_headers}, + _response_headers=raw_response_headers, + ) + + return cast(ModelResponse, final_response_obj) def get_error_class( self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] @@ -221,6 +261,30 @@ def get_error_class( headers=cast(httpx.Headers, headers), ) + def get_complete_url( + self, + api_base: str, + model: str, + optional_params: dict, + stream: Optional[bool] = None, + ) -> str: + """ + Get the complete URL for the API call. + + Returns: + str: The complete URL for the API call. + """ + endpoint = "chat/completions" + + # Remove trailing slash from api_base if present + api_base = api_base.rstrip("/") + + # Check if endpoint is already in the api_base + if endpoint in api_base: + return api_base + + return f"{api_base}/{endpoint}" + def validate_environment( self, headers: dict, @@ -230,7 +294,14 @@ def validate_environment( api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: - raise NotImplementedError + if api_key is not None: + headers["Authorization"] = f"Bearer {api_key}" + + # Ensure Content-Type is set to application/json + if "content-type" not in headers and "Content-Type" not in headers: + headers["Content-Type"] = "application/json" + + return headers def get_models( self, api_key: Optional[str] = None, api_base: Optional[str] = None @@ -272,3 +343,30 @@ def get_api_base(api_base: Optional[str] = None) -> Optional[str]: or get_secret_str("OPENAI_API_BASE") or "https://api.openai.com/v1" ) + + def get_model_response_iterator( + self, + streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse], + sync_stream: bool, + json_mode: Optional[bool] = False, + ) -> Any: + return OpenAIChatCompletionStreamingHandler( + streaming_response=streaming_response, + sync_stream=sync_stream, + json_mode=json_mode, + ) + + +class OpenAIChatCompletionStreamingHandler(BaseModelResponseIterator): + + def chunk_parser(self, chunk: dict) -> ModelResponseStream: + try: + return ModelResponseStream( + id=chunk["id"], + object="chat.completion.chunk", + created=chunk["created"], + model=chunk["model"], + choices=chunk["choices"], + ) + except Exception as e: + raise e diff --git a/litellm/main.py b/litellm/main.py index cc71d3133bd8..a6171ec9effb 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1370,6 +1370,36 @@ def completion( # type: ignore # noqa: PLR0915 "api_base": api_base, }, ) + elif custom_llm_provider == "deepseek": + ## COMPLETION CALL + try: + response = base_llm_http_handler.completion( + model=model, + messages=messages, + headers=headers, + model_response=model_response, + api_key=api_key, + api_base=api_base, + acompletion=acompletion, + logging_obj=logging, + optional_params=optional_params, + 
+                    litellm_params=litellm_params,
+                    timeout=timeout,  # type: ignore
+                    client=client,
+                    custom_llm_provider=custom_llm_provider,
+                    encoding=encoding,
+                    stream=stream,
+                )
+            except Exception as e:
+                ## LOGGING - log the original exception returned
+                logging.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=str(e),
+                    additional_args={"headers": headers},
+                )
+                raise e
+
         elif custom_llm_provider == "azure_ai":
             api_base = (
                 api_base  # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
@@ -1611,7 +1641,6 @@
             or custom_llm_provider == "cerebras"
             or custom_llm_provider == "sambanova"
             or custom_llm_provider == "volcengine"
-            or custom_llm_provider == "deepseek"
             or custom_llm_provider == "anyscale"
             or custom_llm_provider == "mistral"
             or custom_llm_provider == "openai"
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index 987ef948a57c..03eca86e960b 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -8904,5 +8904,16 @@
         "supports_function_calling": true,
         "mode": "chat",
         "supports_tool_choice": true
+    },
+    "hyperbolic/deepseek-v3": {
+        "max_tokens": 20480,
+        "max_input_tokens": 131072,
+        "max_output_tokens": 20480,
+        "litellm_provider": "openai",
+        "input_cost_per_token": 0.00000025,
+        "output_cost_per_token": 0.00000025,
+        "mode": "chat",
+        "supports_function_calling": true,
+        "supports_response_schema": true
     }
 }
diff --git a/tests/llm_translation/test_deepseek_completion.py b/tests/llm_translation/test_deepseek_completion.py
index 000c02d671b7..a07bf3ffe88b 100644
--- a/tests/llm_translation/test_deepseek_completion.py
+++ b/tests/llm_translation/test_deepseek_completion.py
@@ -7,7 +7,7 @@ class TestDeepSeekChatCompletion(BaseLLMChatTest):
 
     def get_base_completion_call_args(self) -> dict:
         return {
-            "model": "deepseek/deepseek-chat",
+            "model": "deepseek/deepseek-reasoner",
         }
 
     def test_tool_call_no_arguments(self, tool_call_no_arguments):
@@ -21,3 +21,52 @@ def test_multilingual_requests(self):
         Todo: if litellm.modify_params is True ensure it's a valid utf-8 sequence
         """
         pass
+
+
+@pytest.mark.parametrize("stream", [True, False])
+def test_deepseek_mock_completion(stream):
+    """
+    Deepseek API is hanging. Mock the call to a fake endpoint so we can confirm our integration is working.
+    """
+    import litellm
+    from litellm import completion
+
+    litellm._turn_on_debug()
+
+    response = completion(
+        model="deepseek/deepseek-reasoner",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        api_base="https://exampleopenaiendpoint-production.up.railway.app/v1/chat/completions",
+        stream=stream,
+    )
+    print(f"response: {response}")
+    if stream:
+        for chunk in response:
+            print(chunk)
+    else:
+        assert response is not None
+
+
+@pytest.mark.parametrize("stream", [True, False])
+@pytest.mark.asyncio
+async def test_deepseek_mock_async_completion(stream):
+    """
+    Deepseek API is hanging. Mock the call to a fake endpoint so we can confirm our integration is working.
+    """
+    import litellm
+    from litellm import completion, acompletion
+
+    litellm._turn_on_debug()
+
+    response = await acompletion(
+        model="deepseek/deepseek-reasoner",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        api_base="https://exampleopenaiendpoint-production.up.railway.app/v1/chat/completions",
+        stream=stream,
+    )
+    print(f"response: {response}")
+    if stream:
+        async for chunk in response:
+            print(chunk)
+    else:
+        assert response is not None