fix(litellm): conflicts with openai when overriding
- Removed the original_oai_create and original_oai_create_async variables,
  since we no longer override OpenAI's methods
- Modified override() and undo_override() to handle only LiteLLM's methods
- Updated _override_completion() and _override_async_completion() to store
  and patch only LiteLLM's methods

This way, when both providers are used:

- OpenAIProvider handles overriding OpenAI's completion methods
- LiteLLMProvider handles overriding only LiteLLM's completion methods
- No more conflicts between the two providers (a sketch of the resulting
  pattern follows below)
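As a rough sketch of the pattern this commit settles on (simplified, hypothetical class names, not the actual provider implementations), each provider saves and restores only the callable it owns, so the two overrides compose in either order:

# Simplified sketch: disjoint patch targets mean no cross-provider conflicts.
import litellm
from openai.resources.chat import completions


class LiteLLMOverrideSketch:
    # Owns litellm.completion only; never touches OpenAI's classes.
    def __init__(self):
        self.original_create = None

    def override(self):
        self.original_create = litellm.completion

        def patched_function(*args, **kwargs):
            # ...AgentOps event tracking would wrap this call...
            return self.original_create(*args, **kwargs)

        litellm.completion = patched_function

    def undo_override(self):
        if self.original_create is not None:
            litellm.completion = self.original_create


class OpenAIOverrideSketch:
    # Owns Completions.create only; never touches litellm's module functions.
    def __init__(self):
        self.original_create = None

    def override(self):
        self.original_create = completions.Completions.create

        def patched_function(*args, **kwargs):
            return self.original_create(*args, **kwargs)

        completions.Completions.create = patched_function

    def undo_override(self):
        if self.original_create is not None:
            completions.Completions.create = self.original_create

Before this change, LiteLLMProvider also saved and restored Completions.create and AsyncCompletions.create, so whichever provider ran undo_override() last could clobber the other's patch; with disjoint targets that hazard disappears.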
teocns committed Oct 31, 2024
1 parent 4ce6f80 commit 5e08b9a
Showing 2 changed files with 176 additions and 21 deletions.
27 changes: 6 additions & 21 deletions agentops/llms/litellm.py
@@ -25,21 +25,12 @@ def override(self):
         self._override_completion()
 
     def undo_override(self):
-        if (
-            self.original_create is not None
-            and self.original_create_async is not None
-            and self.original_oai_create is not None
-            and self.original_oai_create_async is not None
-        ):
+        if self.original_create is not None and self.original_create_async is not None:
             import litellm
-            from openai.resources.chat import completions
 
             litellm.acompletion = self.original_create_async
             litellm.completion = self.original_create
 
-            completions.Completions.create = self.original_oai_create
-            completions.AsyncCompletions.create = self.original_oai_create_async
-
     def handle_response(
         self, response, kwargs, init_timestamp, session: Optional[Session] = None
     ) -> dict:
@@ -171,13 +162,10 @@ async def async_generator():
 
     def _override_completion(self):
         import litellm
-        from openai.types.chat import (
-            ChatCompletion,
-        )  # Note: litellm calls all LLM APIs using the OpenAI format
-        from openai.resources.chat import completions
+        from openai.types.chat import ChatCompletion
 
+        # Only store and override litellm's completion method
         self.original_create = litellm.completion
-        self.original_oai_create = completions.Completions.create
 
         def patched_function(*args, **kwargs):
             init_timestamp = get_ISO_time()
@@ -207,13 +195,10 @@ def patched_function(*args, **kwargs):
 
     def _override_async_completion(self):
         import litellm
-        from openai.types.chat import (
-            ChatCompletion,
-        )  # Note: litellm calls all LLM APIs using the OpenAI format
-        from openai.resources.chat import completions
+        from openai.types.chat import ChatCompletion
 
+        # Only store and override litellm's async completion method
         self.original_create_async = litellm.acompletion
-        self.original_oai_create_async = completions.AsyncCompletions.create
 
         async def patched_function(*args, **kwargs):
             init_timestamp = get_ISO_time()
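The patched_function bodies are collapsed in the hunks above. As a purely illustrative sketch of the wrapper shape such a patch typically takes (get_ISO_time and handle_response mirror identifiers visible in the diff, but this body is hypothetical, not the committed implementation):

# Hypothetical sketch of a tracking wrapper around litellm.completion.
from datetime import datetime, timezone

import litellm


def get_ISO_time():
    # Stand-in for the agentops helper of the same name.
    return datetime.now(timezone.utc).isoformat()


def install_litellm_patch(provider):
    # Store the original on the provider, as the diff above does.
    provider.original_create = litellm.completion

    def patched_function(*args, **kwargs):
        init_timestamp = get_ISO_time()
        session = kwargs.pop("session", None)  # assumed optional kwarg
        result = provider.original_create(*args, **kwargs)
        # Route the response through the provider's event handling.
        return provider.handle_response(result, kwargs, init_timestamp, session=session)

    litellm.completion = patched_function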
170 changes: 170 additions & 0 deletions tests/test_providers.py
@@ -0,0 +1,170 @@
import unittest
from unittest.mock import MagicMock, patch
import litellm
import openai
from openai.resources.chat.completions import Completions, AsyncCompletions
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice, CompletionUsage

from agentops.llms.openai import OpenAiProvider
from agentops.llms.litellm import LiteLLMProvider


class TestProviders(unittest.TestCase):
    def setUp(self):
        # Create mock clients
        self.mock_openai_client = MagicMock()
        self.mock_litellm_client = MagicMock()

        # Store original methods before any overrides
        self.original_litellm_completion = litellm.completion
        self.original_litellm_acompletion = litellm.acompletion

        # Test parameters
        self.test_messages = [{"role": "user", "content": "test"}]
        self.test_params = {
            "messages": self.test_messages,
            "model": "gpt-3.5-turbo",
            "temperature": 0.7,
            "max_tokens": 100
        }

        # Create a proper ChatCompletion mock response
        message = ChatCompletionMessage(
            role="assistant",
            content="test response"
        )

        choice = Choice(
            index=0,
            message=message,
            finish_reason="stop"
        )

        usage = CompletionUsage(
            prompt_tokens=10,
            completion_tokens=20,
            total_tokens=30
        )

        self.mock_response = ChatCompletion(
            id="test_id",
            model="gpt-3.5-turbo",
            object="chat.completion",
            choices=[choice],
            usage=usage,
            created=1234567890
        )

    def tearDown(self):
        # Restore original methods after each test
        litellm.completion = self.original_litellm_completion
        litellm.acompletion = self.original_litellm_acompletion

    @patch('openai.resources.chat.completions.Completions.create')
    def test_provider_override_independence(self, mock_openai_create):
        """Test that OpenAI and LiteLLM providers don't interfere with each other's method overrides"""

        # Initialize both providers
        openai_provider = OpenAiProvider(self.mock_openai_client)
        litellm_provider = LiteLLMProvider(self.mock_litellm_client)

        # Set up mock returns
        mock_openai_create.return_value = self.mock_response

        # Create a MagicMock for litellm completion
        mock_litellm_completion = MagicMock(return_value=self.mock_response)

        try:
            # Store original and set mock
            original_litellm_completion = litellm.completion
            litellm.completion = mock_litellm_completion

            # Override both providers
            openai_provider.override()
            litellm_provider.override()

            # Test OpenAI completion
            Completions.create(**self.test_params)
            self.assertTrue(
                mock_openai_create.called,
                "OpenAI's create method should be called"
            )

            # Test LiteLLM completion
            litellm.completion(**self.test_params)
            self.assertTrue(
                mock_litellm_completion.called,
                "LiteLLM's completion method should be called"
            )

        finally:
            # Restore litellm's completion function
            litellm.completion = original_litellm_completion

            # Undo overrides
            openai_provider.undo_override()
            litellm_provider.undo_override()

    @patch('openai.resources.chat.completions.Completions.create')
    def test_provider_override_order_independence(self, mock_openai_create):
        """Test that the order of provider overrides doesn't matter"""

        # Set up mock returns
        mock_openai_create.return_value = self.mock_response

        # Create a MagicMock for litellm completion
        mock_litellm_completion = MagicMock(return_value=self.mock_response)

        try:
            # Store original and set mock
            original_litellm_completion = litellm.completion
            litellm.completion = mock_litellm_completion

            # Test overriding OpenAI first, then LiteLLM
            openai_provider = OpenAiProvider(self.mock_openai_client)
            litellm_provider = LiteLLMProvider(self.mock_litellm_client)

            openai_provider.override()
            first_openai_create = Completions.create
            litellm_provider.override()

            # Test both providers work independently
            Completions.create(**self.test_params)
            litellm.completion(**self.test_params)

            # Verify methods weren't affected by each other
            self.assertIs(Completions.create, first_openai_create)

            # Cleanup first test
            litellm_provider.undo_override()
            openai_provider.undo_override()

            # Reset the mock for the second test
            mock_litellm_completion.reset_mock()

            # Now test overriding LiteLLM first, then OpenAI
            litellm_provider = LiteLLMProvider(self.mock_litellm_client)
            openai_provider = OpenAiProvider(self.mock_openai_client)

            litellm_provider.override()
            first_litellm_method = litellm.completion
            openai_provider.override()

            # Test both providers work independently
            Completions.create(**self.test_params)
            litellm.completion(**self.test_params)

            # Verify methods weren't affected by each other
            self.assertIs(litellm.completion, first_litellm_method)

        finally:
            # Restore litellm's completion function
            litellm.completion = original_litellm_completion

            # Cleanup
            openai_provider.undo_override()
            litellm_provider.undo_override()


if __name__ == '__main__':
    unittest.main()
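The suite can be run from the repository root with python -m unittest tests.test_providers, or programmatically via a standard unittest invocation such as:

# Equivalent programmatic runner for the new test module.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName("tests.test_providers")
unittest.TextTestRunner(verbosity=2).run(suite)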
