
Commit

(Fixes) OpenAI Streaming Token Counting + Fixes usage tracking when `litellm.turn_off_message_logging=True` (#8156)

* working streaming usage tracking

* fix test_async_chat_openai_stream_options

* fix await asyncio.sleep(1)

* test_async_chat_azure

* fix s3 logging

* fix get_stream_options

* fix get_stream_options

* fix streaming handler

* test_stream_token_counting_with_redaction

* fix codeql concern
ishaan-jaff authored Jan 31, 2025
1 parent 9f0f2b3 commit 2cf0daa
Showing 8 changed files with 267 additions and 93 deletions.
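
Before the per-file diffs, a hedged sketch of the user-facing behaviour this commit targets — it is not part of the change set. It assumes an OPENAI_API_KEY in the environment, a placeholder model name, and that with the new default stream_options={"include_usage": True} the final streamed chunk carries a usage block.

import litellm

# Streamed OpenAI calls now request usage by default, so token counts
# show up at the end of the stream without passing stream_options yourself.
response = litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
)

last_chunk = None
for chunk in response:
    last_chunk = chunk

# Expect prompt/completion/total token counts on the final chunk.
print(getattr(last_chunk, "usage", None))
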
71 changes: 40 additions & 31 deletions litellm/litellm_core_utils/litellm_logging.py
@@ -1029,21 +1029,13 @@ def success_handler( # noqa: PLR0915
] = None
if "complete_streaming_response" in self.model_call_details:
return # break out of this.
if self.stream and (
isinstance(result, litellm.ModelResponse)
or isinstance(result, TextCompletionResponse)
or isinstance(result, ModelResponseStream)
):
complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse]
] = _assemble_complete_response_from_streaming_chunks(
result=result,
start_time=start_time,
end_time=end_time,
request_kwargs=self.model_call_details,
streaming_chunks=self.sync_streaming_chunks,
is_async=False,
)
complete_streaming_response = self._get_assembled_streaming_response(
result=result,
start_time=start_time,
end_time=end_time,
is_async=False,
streaming_chunks=self.sync_streaming_chunks,
)
if complete_streaming_response is not None:
verbose_logger.debug(
"Logging Details LiteLLM-Success Call streaming complete"
@@ -1542,22 +1534,13 @@ async def async_success_handler( # noqa: PLR0915
return # break out of this.
complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse]
] = None
if self.stream is True and (
isinstance(result, litellm.ModelResponse)
or isinstance(result, litellm.ModelResponseStream)
or isinstance(result, TextCompletionResponse)
):
complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse]
] = _assemble_complete_response_from_streaming_chunks(
result=result,
start_time=start_time,
end_time=end_time,
request_kwargs=self.model_call_details,
streaming_chunks=self.streaming_chunks,
is_async=True,
)
] = self._get_assembled_streaming_response(
result=result,
start_time=start_time,
end_time=end_time,
is_async=True,
streaming_chunks=self.streaming_chunks,
)

if complete_streaming_response is not None:
print_verbose("Async success callbacks: Got a complete streaming response")
@@ -2259,6 +2242,32 @@ def _remove_internal_custom_logger_callbacks(self, callbacks: List) -> List:
_new_callbacks.append(_c)
return _new_callbacks

def _get_assembled_streaming_response(
self,
result: Union[ModelResponse, TextCompletionResponse, ModelResponseStream, Any],
start_time: datetime.datetime,
end_time: datetime.datetime,
is_async: bool,
streaming_chunks: List[Any],
) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
if isinstance(result, ModelResponse):
return result
elif isinstance(result, TextCompletionResponse):
return result
elif isinstance(result, ModelResponseStream):
complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse]
] = _assemble_complete_response_from_streaming_chunks(
result=result,
start_time=start_time,
end_time=end_time,
request_kwargs=self.model_call_details,
streaming_chunks=streaming_chunks,
is_async=is_async,
)
return complete_streaming_response
return None


def set_callbacks(callback_list, function_id=None): # noqa: PLR0915
"""
68 changes: 18 additions & 50 deletions litellm/litellm_core_utils/streaming_handler.py
@@ -5,7 +5,6 @@
import time
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional, cast

import httpx
@@ -14,6 +13,7 @@
import litellm
from litellm import verbose_logger
from litellm.litellm_core_utils.redact_messages import LiteLLMLoggingObject
from litellm.litellm_core_utils.thread_pool_executor import executor
from litellm.types.utils import Delta
from litellm.types.utils import GenericStreamingChunk as GChunk
from litellm.types.utils import (
@@ -29,11 +29,6 @@
from .llm_response_utils.get_api_base import get_api_base
from .rules import Rules

MAX_THREADS = 100

# Create a ThreadPoolExecutor
executor = ThreadPoolExecutor(max_workers=MAX_THREADS)


def is_async_iterable(obj: Any) -> bool:
"""
@@ -1568,21 +1563,6 @@ async def __anext__(self): # noqa: PLR0915
)
if processed_chunk is None:
continue
## LOGGING
executor.submit(
self.logging_obj.success_handler,
result=processed_chunk,
start_time=None,
end_time=None,
cache_hit=cache_hit,
)

asyncio.create_task(
self.logging_obj.async_success_handler(
processed_chunk, cache_hit=cache_hit
)
)

if self.logging_obj._llm_caching_handler is not None:
asyncio.create_task(
@@ -1634,16 +1614,6 @@ async def __anext__(self): # noqa: PLR0915
)
if processed_chunk is None:
continue
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,
args=(processed_chunk, None, None, cache_hit),
).start() # log processed_chunk
asyncio.create_task(
self.logging_obj.async_success_handler(
processed_chunk, cache_hit=cache_hit
)
)

choice = processed_chunk.choices[0]
if isinstance(choice, StreamingChoices):
@@ -1671,33 +1641,31 @@ async def __anext__(self): # noqa: PLR0915
"usage",
getattr(complete_streaming_response, "usage"),
)
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,
args=(response, None, None, cache_hit),
).start() # log response
if self.sent_stream_usage is False and self.send_stream_usage is True:
self.sent_stream_usage = True
return response

asyncio.create_task(
self.logging_obj.async_success_handler(
response, cache_hit=cache_hit
complete_streaming_response,
cache_hit=cache_hit,
start_time=None,
end_time=None,
)
)
if self.sent_stream_usage is False and self.send_stream_usage is True:
self.sent_stream_usage = True
return response

executor.submit(
self.logging_obj.success_handler,
complete_streaming_response,
cache_hit=cache_hit,
start_time=None,
end_time=None,
)

raise StopAsyncIteration # Re-raise StopIteration
else:
self.sent_last_chunk = True
processed_chunk = self.finish_reason_handler()
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,
args=(processed_chunk, None, None, cache_hit),
).start() # log response
asyncio.create_task(
self.logging_obj.async_success_handler(
processed_chunk, cache_hit=cache_hit
)
)
return processed_chunk
except httpx.TimeoutException as e: # if httpx read timeout error occurs
traceback_exception = traceback.format_exc()
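
The net effect of the handler changes above: the per-chunk success_handler / async_success_handler calls are gone, and logging fires once with the complete assembled response — async via asyncio.create_task, sync via the shared executor. Below is a minimal, hypothetical sketch of that fire-and-forget pattern; handler names and the dict payload are placeholders, not litellm's internals.

import asyncio
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=4)


def sync_success_handler(response, cache_hit, start_time=None, end_time=None):
    print("sync handler saw usage:", response.get("usage"))


async def async_success_handler(response, cache_hit, start_time=None, end_time=None):
    print("async handler saw usage:", response.get("usage"))


async def finish_stream(complete_streaming_response, cache_hit=False):
    # Log once, with the fully assembled response, instead of once per chunk:
    # the async callback as a task, the sync callback on the thread pool,
    # so neither blocks the caller that is still iterating the stream.
    asyncio.create_task(
        async_success_handler(
            complete_streaming_response,
            cache_hit=cache_hit,
            start_time=None,
            end_time=None,
        )
    )
    executor.submit(
        sync_success_handler,
        complete_streaming_response,
        cache_hit=cache_hit,
        start_time=None,
        end_time=None,
    )
    await asyncio.sleep(0.1)  # demo only: give the scheduled task a chance to run


asyncio.run(finish_stream({"usage": {"total_tokens": 7}}))
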
5 changes: 5 additions & 0 deletions litellm/litellm_core_utils/thread_pool_executor.py
@@ -0,0 +1,5 @@
from concurrent.futures import ThreadPoolExecutor

MAX_THREADS = 100
# Create a ThreadPoolExecutor
executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
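
The new module is deliberately tiny: it gives streaming_handler.py and utils.py one shared pool instead of each module creating its own copy (the duplicate definitions are removed elsewhere in this commit). A minimal usage sketch — the import path is the one added here, the submitted function is a placeholder:

from litellm.litellm_core_utils.thread_pool_executor import executor


def log_success(payload):
    # Placeholder for a synchronous success callback.
    print("logging:", payload)


# Fire-and-forget on the shared pool; the caller is not blocked.
future = executor.submit(log_success, {"total_tokens": 7})
future.result()  # demo only — real callers don't wait on the future
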
25 changes: 21 additions & 4 deletions litellm/llms/openai/openai.py
@@ -14,6 +14,7 @@
Union,
cast,
)
from urllib.parse import urlparse

import httpx
import openai
@@ -833,8 +834,9 @@ def streaming(
stream_options: Optional[dict] = None,
):
data["stream"] = True
if stream_options is not None:
data["stream_options"] = stream_options
data.update(
self.get_stream_options(stream_options=stream_options, api_base=api_base)
)

openai_client: OpenAI = self._get_openai_client( # type: ignore
is_async=False,
@@ -893,8 +895,9 @@ async def async_streaming(
):
response = None
data["stream"] = True
if stream_options is not None:
data["stream_options"] = stream_options
data.update(
self.get_stream_options(stream_options=stream_options, api_base=api_base)
)
for _ in range(2):
try:
openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore
@@ -977,6 +980,20 @@ async def async_streaming(
status_code=500, message=f"{str(e)}", headers=error_headers
)

def get_stream_options(
self, stream_options: Optional[dict], api_base: Optional[str]
) -> dict:
"""
Pass `stream_options` to the data dict for OpenAI requests
"""
if stream_options is not None:
return {"stream_options": stream_options}
else:
# by default litellm will include usage for openai endpoints
if api_base is None or urlparse(api_base).hostname == "api.openai.com":
return {"stream_options": {"include_usage": True}}
return {}

# Embedding
@track_llm_api_timing()
async def make_openai_embedding_request(
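
The behaviour of get_stream_options is easiest to read off its inputs and outputs: explicit stream_options pass through unchanged, and when none are supplied, usage reporting is enabled by default only when api_base is unset or points at api.openai.com. The sketch below mirrors the method added above, lifted out of the class so it can be run directly:

from typing import Optional
from urllib.parse import urlparse


def get_stream_options(stream_options: Optional[dict], api_base: Optional[str]) -> dict:
    # Same decision logic as the new OpenAI handler method (simplified, standalone).
    if stream_options is not None:
        return {"stream_options": stream_options}
    if api_base is None or urlparse(api_base).hostname == "api.openai.com":
        # Default: ask OpenAI to include token usage in the final stream chunk.
        return {"stream_options": {"include_usage": True}}
    return {}


print(get_stream_options(None, None))
# {'stream_options': {'include_usage': True}}
print(get_stream_options(None, "https://api.openai.com/v1"))
# {'stream_options': {'include_usage': True}}
print(get_stream_options(None, "https://my-proxy.internal/v1"))
# {}
print(get_stream_options({"include_usage": False}, None))
# {'stream_options': {'include_usage': False}}
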
6 changes: 1 addition & 5 deletions litellm/utils.py
@@ -166,7 +166,6 @@
# Convert to str (if necessary)
claude_json_str = json.dumps(json_data)
import importlib.metadata
from concurrent.futures import ThreadPoolExecutor
from typing import (
TYPE_CHECKING,
Any,
@@ -185,6 +184,7 @@

from openai import OpenAIError as OriginalError

from litellm.litellm_core_utils.thread_pool_executor import executor
from litellm.llms.base_llm.audio_transcription.transformation import (
BaseAudioTranscriptionConfig,
)
@@ -235,10 +235,6 @@

####### ENVIRONMENT VARIABLES ####################
# Adjust to your specific application needs / system capabilities.
MAX_THREADS = 100

# Create a ThreadPoolExecutor
executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
sentry_sdk_instance = None
capture_exception = None
add_breadcrumb = None
20 changes: 20 additions & 0 deletions tests/local_testing/test_custom_callback_input.py
@@ -418,6 +418,8 @@ async def test_async_chat_openai_stream():
)
async for chunk in response:
continue

await asyncio.sleep(1)
## test failure callback
try:
response = await litellm.acompletion(
@@ -428,6 +430,7 @@ async def test_async_chat_openai_stream():
)
async for chunk in response:
continue
await asyncio.sleep(1)
except Exception:
pass
time.sleep(1)
@@ -499,6 +502,8 @@ async def test_async_chat_azure_stream():
)
async for chunk in response:
continue

await asyncio.sleep(1)
# test failure callback
try:
response = await litellm.acompletion(
@@ -509,6 +514,7 @@ async def test_async_chat_azure_stream():
)
async for chunk in response:
continue
await asyncio.sleep(1)
except Exception:
pass
await asyncio.sleep(1)
@@ -540,6 +546,8 @@ async def test_async_chat_openai_stream_options():

async for chunk in response:
continue

await asyncio.sleep(1)
print("mock client args list=", mock_client.await_args_list)
mock_client.assert_awaited_once()
except Exception as e:
@@ -607,6 +615,8 @@ async def test_async_chat_bedrock_stream():
async for chunk in response:
print(f"chunk: {chunk}")
continue

await asyncio.sleep(1)
## test failure callback
try:
response = await litellm.acompletion(
@@ -617,6 +627,8 @@ async def test_async_chat_bedrock_stream():
)
async for chunk in response:
continue

await asyncio.sleep(1)
except Exception:
pass
await asyncio.sleep(1)
@@ -770,6 +782,8 @@ async def test_async_text_completion_bedrock():
async for chunk in response:
print(f"chunk: {chunk}")
continue

await asyncio.sleep(1)
## test failure callback
try:
response = await litellm.atext_completion(
@@ -780,6 +794,8 @@ async def test_async_text_completion_bedrock():
)
async for chunk in response:
continue

await asyncio.sleep(1)
except Exception:
pass
time.sleep(1)
@@ -809,6 +825,8 @@ async def test_async_text_completion_openai_stream():
async for chunk in response:
print(f"chunk: {chunk}")
continue

await asyncio.sleep(1)
## test failure callback
try:
response = await litellm.atext_completion(
@@ -819,6 +837,8 @@ async def test_async_text_completion_openai_stream():
)
async for chunk in response:
continue

await asyncio.sleep(1)
except Exception:
pass
time.sleep(1)
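
The only test-side change is the added await asyncio.sleep(1) after each stream is consumed: success callbacks are now scheduled asynchronously on the final chunk, so an assertion made immediately afterwards can race them. Below is a hedged sketch of the pattern the tests exercise, assuming the documented custom-callback signature (kwargs, completion_response, start_time, end_time); the model name is a placeholder, and turn_off_message_logging=True mirrors the redaction case covered by test_stream_token_counting_with_redaction.

import asyncio
import litellm

litellm.turn_off_message_logging = True  # redact message content in logs
usage_seen = []


def record_usage(kwargs, completion_response, start_time, end_time):
    usage_seen.append(getattr(completion_response, "usage", None))


litellm.success_callback = [record_usage]


async def test_stream_usage_is_tracked():
    response = await litellm.acompletion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Say hi"}],
        stream=True,
    )
    async for _ in response:
        continue
    # Give the asynchronously scheduled success callback time to run
    # before asserting on what it recorded.
    await asyncio.sleep(1)
    assert usage_seen and usage_seen[0] is not None


asyncio.run(test_stream_usage_is_tracked())
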
