Fix deepseek calling - refactor to use base_llm_http_handler (#8266)
* refactor(deepseek/): move deepseek to base llm http handler

Fixes #8128 (comment)

* fix(gpt_transformation.py): support stream parsing for gpt-like calls

* test(test_deepseek_completion.py): add async streaming test

* fix(gpt_transformation.py): fix import

* fix(gpt_transformation.py): return full api base and content type
krrishdholakia authored Feb 5, 2025
1 parent 51b9a02 commit 3c813b3
Showing 5 changed files with 194 additions and 6 deletions.
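
For reference, the user-facing call is unchanged by this refactor; only the transport path moves to base_llm_http_handler. A minimal sketch (not part of the diff) of a streaming Deepseek call through LiteLLM, assuming DEEPSEEK_API_KEY is set in the environment:

import litellm

# After this change the call below is routed through base_llm_http_handler,
# and streaming chunks are parsed by the new handler in gpt_transformation.py.
response = litellm.completion(
    model="deepseek/deepseek-chat",
    messages=[{"role": "user", "content": "Hello, world!"}],
    stream=True,
)
for chunk in response:
    print(chunk)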
1 change: 1 addition & 0 deletions litellm/litellm_core_utils/get_llm_provider_logic.py
@@ -490,6 +490,7 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
or get_secret("DEEPSEEK_API_BASE")
or "https://api.deepseek.com/beta"
) # type: ignore

dynamic_api_key = api_key or get_secret_str("DEEPSEEK_API_KEY")
elif custom_llm_provider == "fireworks_ai":
# fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1
106 changes: 102 additions & 4 deletions litellm/llms/openai/chat/gpt_transformation.py
@@ -2,16 +2,27 @@
Support for gpt model family
"""

from typing import TYPE_CHECKING, Any, List, Optional, Union, cast
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Iterator,
List,
Optional,
Union,
cast,
)

import httpx

import litellm
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
from litellm.types.utils import ModelResponse, ModelResponseStream
from litellm.utils import convert_to_model_response_object

from ..common_utils import OpenAIError

@@ -210,7 +221,36 @@ def transform_response(
Returns:
dict: The transformed response.
"""
raise NotImplementedError

## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=raw_response.text,
additional_args={"complete_input_dict": request_data},
)

## RESPONSE OBJECT
try:
completion_response = raw_response.json()
except Exception as e:
response_headers = getattr(raw_response, "headers", None)
raise OpenAIError(
message="Unable to get json response - {}, Original Response: {}".format(
str(e), raw_response.text
),
status_code=raw_response.status_code,
headers=response_headers,
)
raw_response_headers = dict(raw_response.headers)
final_response_obj = convert_to_model_response_object(
response_object=completion_response,
model_response_object=model_response,
hidden_params={"headers": raw_response_headers},
_response_headers=raw_response_headers,
)

return cast(ModelResponse, final_response_obj)

def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
@@ -221,6 +261,30 @@ def get_error_class(
headers=cast(httpx.Headers, headers),
)

def get_complete_url(
self,
api_base: str,
model: str,
optional_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
Get the complete URL for the API call.
Returns:
str: The complete URL for the API call.
"""
endpoint = "chat/completions"

# Remove trailing slash from api_base if present
api_base = api_base.rstrip("/")

# Check if endpoint is already in the api_base
if endpoint in api_base:
return api_base

return f"{api_base}/{endpoint}"

def validate_environment(
self,
headers: dict,
@@ -230,7 +294,14 @@ def validate_environment(
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
raise NotImplementedError
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"

# Ensure Content-Type is set to application/json
if "content-type" not in headers and "Content-Type" not in headers:
headers["Content-Type"] = "application/json"

return headers

def get_models(
self, api_key: Optional[str] = None, api_base: Optional[str] = None
@@ -272,3 +343,30 @@ def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
or get_secret_str("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)

def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
sync_stream: bool,
json_mode: Optional[bool] = False,
) -> Any:
return OpenAIChatCompletionStreamingHandler(
streaming_response=streaming_response,
sync_stream=sync_stream,
json_mode=json_mode,
)


class OpenAIChatCompletionStreamingHandler(BaseModelResponseIterator):

def chunk_parser(self, chunk: dict) -> ModelResponseStream:
try:
return ModelResponseStream(
id=chunk["id"],
object="chat.completion.chunk",
created=chunk["created"],
model=chunk["model"],
choices=chunk["choices"],
)
except Exception as e:
raise e
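
The two helpers added above are small enough to restate on their own. A sketch (not part of the diff) of the URL-building rule and of the OpenAI-compatible chunk shape that chunk_parser reads; build_complete_url is a hypothetical standalone helper and the chunk values are illustrative:

def build_complete_url(api_base: str, endpoint: str = "chat/completions") -> str:
    # Same logic as get_complete_url above: strip a trailing slash, then append
    # the endpoint unless the caller already passed a full endpoint URL.
    api_base = api_base.rstrip("/")
    if endpoint in api_base:
        return api_base
    return f"{api_base}/{endpoint}"

assert build_complete_url("https://api.deepseek.com/beta/") == "https://api.deepseek.com/beta/chat/completions"
assert build_complete_url("https://api.deepseek.com/v1/chat/completions") == "https://api.deepseek.com/v1/chat/completions"

# An OpenAI-style streaming chunk; chunk_parser reads exactly these keys and
# wraps them in a ModelResponseStream.
example_chunk = {
    "id": "chatcmpl-123",
    "object": "chat.completion.chunk",
    "created": 1738713600,
    "model": "deepseek-chat",
    "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": None}],
}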
31 changes: 30 additions & 1 deletion litellm/main.py
@@ -1370,6 +1370,36 @@ def completion( # type: ignore # noqa: PLR0915
"api_base": api_base,
},
)
elif custom_llm_provider == "deepseek":
## COMPLETION CALL
try:
response = base_llm_http_handler.completion(
model=model,
messages=messages,
headers=headers,
model_response=model_response,
api_key=api_key,
api_base=api_base,
acompletion=acompletion,
logging_obj=logging,
optional_params=optional_params,
litellm_params=litellm_params,
timeout=timeout, # type: ignore
client=client,
custom_llm_provider=custom_llm_provider,
encoding=encoding,
stream=stream,
)
except Exception as e:
## LOGGING - log the original exception returned
logging.post_call(
input=messages,
api_key=api_key,
original_response=str(e),
additional_args={"headers": headers},
)
raise e

elif custom_llm_provider == "azure_ai":
api_base = (
api_base # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
@@ -1611,7 +1641,6 @@ def completion( # type: ignore # noqa: PLR0915
or custom_llm_provider == "cerebras"
or custom_llm_provider == "sambanova"
or custom_llm_provider == "volcengine"
or custom_llm_provider == "deepseek"
or custom_llm_provider == "anyscale"
or custom_llm_provider == "mistral"
or custom_llm_provider == "openai"
11 changes: 11 additions & 0 deletions litellm/model_prices_and_context_window_backup.json
@@ -8904,5 +8904,16 @@
"supports_function_calling": true,
"mode": "chat",
"supports_tool_choice": true
},
"hyperbolic/deepseek-v3": {
"max_tokens": 20480,
"max_input_tokens": 131072,
"max_output_tokens": 20480,
"litellm_provider": "openai",
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000025,
"mode": "chat",
"supports_function_calling": true,
"supports_response_schema": true
}
}
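
The new pricing entry feeds LiteLLM's cost tracking. A hedged sketch of how it would be consumed, assuming litellm.cost_per_token resolves the "hyperbolic/deepseek-v3" key directly from this map:

import litellm

# 1,000 tokens at 0.00000025 USD/token should come out to 0.00025 USD for both
# prompt and completion at the rates added above.
prompt_cost, completion_cost = litellm.cost_per_token(
    model="hyperbolic/deepseek-v3",
    prompt_tokens=1000,
    completion_tokens=1000,
)
print(prompt_cost, completion_cost)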
51 changes: 50 additions & 1 deletion tests/llm_translation/test_deepseek_completion.py
@@ -7,7 +7,7 @@
class TestDeepSeekChatCompletion(BaseLLMChatTest):
def get_base_completion_call_args(self) -> dict:
return {
"model": "deepseek/deepseek-chat",
"model": "deepseek/deepseek-reasoner",
}

def test_tool_call_no_arguments(self, tool_call_no_arguments):
@@ -21,3 +21,52 @@ def test_multilingual_requests(self):
Todo: if litellm.modify_params is True, ensure it's a valid UTF-8 sequence
"""
pass


@pytest.mark.parametrize("stream", [True, False])
def test_deepseek_mock_completion(stream):
"""
The Deepseek API is hanging. Mock the call to a fake endpoint so we can confirm our integration is working.
"""
import litellm
from litellm import completion

litellm._turn_on_debug()

response = completion(
model="deepseek/deepseek-reasoner",
messages=[{"role": "user", "content": "Hello, world!"}],
api_base="https://exampleopenaiendpoint-production.up.railway.app/v1/chat/completions",
stream=stream,
)
print(f"response: {response}")
if stream:
for chunk in response:
print(chunk)
else:
assert response is not None


@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.asyncio
async def test_deepseek_mock_async_completion(stream):
"""
The Deepseek API is hanging. Mock the call to a fake endpoint so we can confirm our integration is working.
"""
import litellm
from litellm import completion, acompletion

litellm._turn_on_debug()

response = await acompletion(
model="deepseek/deepseek-reasoner",
messages=[{"role": "user", "content": "Hello, world!"}],
api_base="https://exampleopenaiendpoint-production.up.railway.app/v1/chat/completions",
stream=stream,
)
print(f"response: {response}")
if stream:
async for chunk in response:
print(chunk)
else:
assert response is not None
