Add support for any LLM by using LiteLLM

RussellLuo committed Jan 17, 2025
1 parent c7aa9a8 commit 9c11a73

Showing 9 changed files with 2,193 additions and 30 deletions.
10 changes: 4 additions & 6 deletions coagent/agents/aswarm/core.py
@@ -5,10 +5,10 @@
 from typing import Any, AsyncIterator, List, Callable, Union

 # Package/library imports
-from openai import OpenAI
 from openai.types.chat import ChatCompletionChunk
 from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
 from coagent.agents.messages import ChatMessage
+from coagent.agents.model_client import ModelClient
 from coagent.core.agent import is_async_iterator
 from coagent.core.util import get_func_args, pretty_trace_tool_call
@@ -33,10 +33,8 @@


 class Swarm:
-    def __init__(self, client=None):
-        if not client:
-            client = OpenAI()
-        self.client = client
+    def __init__(self, client: ModelClient):
+        self.client: ModelClient = client

     async def get_chat_completion(
         self,
@@ -86,7 +84,7 @@ async def get_chat_completion(
             p.pop("refusal", None)

         try:
-            response = await self.client.chat.completions.create(**create_params)
+            response = await self.client.acompletion(**create_params)
             async for chunk in response:
                 yield chunk
         except Exception as exc:
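With this change, Swarm no longer constructs an implicit OpenAI() client; callers inject a ModelClient, and completions flow through its acompletion method. A minimal sketch of the new wiring (the model string and key below are placeholders):

from coagent.agents.aswarm.core import Swarm
from coagent.agents.model_client import ModelClient

# Hypothetical values; any model string LiteLLM understands should work.
client = ModelClient(model="openai/gpt-4o-mini", api_key="sk-...")

# A ModelClient is now required; there is no default OpenAI() fallback.
swarm = Swarm(client)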
4 changes: 4 additions & 0 deletions coagent/agents/aswarm/util.py
@@ -21,6 +21,10 @@ def debug_print(debug: bool, *args: str) -> None:
 def merge_fields(target, source):
     for key, value in source.items():
         if isinstance(value, str):
+            # A dirty workaround to avoid containing duplicate "function" in
+            # the `type` field. (e.g. "functionfunction")
+            if key == "type" and target[key] == "function":
+                continue
             target[key] += value
         elif value is not None and isinstance(value, dict):
             merge_fields(target[key], value)
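The guard exists because merge_fields concatenates string fields while merging streamed tool-call deltas, and every chunk repeats type: "function"; without it the field accumulated into "functionfunction". A small reproduction against the patched function (the delta dicts are illustrative):

from coagent.agents.aswarm.util import merge_fields

# Two illustrative streamed tool-call deltas, each repeating `"type": "function"`.
target = {"type": "function", "function": {"name": "get_", "arguments": ""}}
delta = {"type": "function", "function": {"name": "weather", "arguments": ""}}

merge_fields(target, delta)
assert target["type"] == "function"  # the guard skips the duplicate "function"
assert target["function"]["name"] == "get_weather"  # other strings still accumulate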
2 changes: 1 addition & 1 deletion coagent/agents/chat_agent.py
@@ -206,7 +206,7 @@ def __init__(
             if getattr(meth, "is_tool", False):
                 tools.append(meth)

-        self._swarm_client = Swarm(self.client.azure_client)
+        self._swarm_client = Swarm(self.client)

         self._swarm_agent = SwarmAgent(
             name=self.name,
2 changes: 1 addition & 1 deletion coagent/agents/dynamic_triage.py
@@ -48,7 +48,7 @@ def __init__(
         self._inclusive: bool = inclusive
         self._client: ModelClient = client

-        self._swarm_client = Swarm(self.client.azure_client)
+        self._swarm_client = Swarm(self.client)

         self._sub_agents: dict[str, Schema] = {}
         self._swarm_agent: SwarmAgent | None = None
61 changes: 53 additions & 8 deletions coagent/agents/model_client.py
@@ -1,22 +1,67 @@
 import os

-from openai import AsyncAzureOpenAI
+# Importing litellm is super slow, so we are using lazy import for now.
+# See https://github.com/BerriAI/litellm/issues/7605.
+#
+# import litellm
 from pydantic import BaseModel, Field


+# class ModelResponse(litellm.ModelResponse):
+#     pass
+
+
 class ModelClient(BaseModel):
-    model: str = Field(os.getenv("AZURE_MODEL", ""), description="The model name.")
+    provider: str = Field("", description="The model provider.")
+    model: str = Field(..., description="The model name.")
     api_base: str = Field("", description="The API base URL.")
     api_version: str = Field("", description="The API version.")
     api_key: str = Field("", description="The API key.")

     @property
-    def azure_client(self) -> AsyncAzureOpenAI:
-        return AsyncAzureOpenAI(
-            azure_endpoint=self.api_base or os.getenv("AZURE_API_BASE"),
-            api_version=self.api_version or os.getenv("AZURE_API_VERSION"),
-            api_key=self.api_key or os.getenv("AZURE_API_KEY"),
+    def llm_provider(self) -> str:
+        if self.provider:
+            return self.provider
+
+        import litellm
+
+        _, provider, _, _ = litellm.get_llm_provider(
+            self.model,
+            api_base=self.api_base or None,
         )
+        return provider
+
+    async def acompletion(
+        self,
+        messages: list[dict],
+        model: str = "",
+        stream: bool = False,
+        temperature: float = 0.1,
+        tools: list | None = None,
+        tool_choice: str | None = None,
+        **kwargs,
+    ):  # -> ModelResponse:
+        import litellm
+
+        model = model or self.model
+        response = await litellm.acompletion(
+            model=model,
+            messages=messages,
+            stream=stream,
+            temperature=temperature,
+            tools=tools,
+            tool_choice=tool_choice,
+            api_base=self.api_base,
+            api_version=self.api_version,
+            api_key=self.api_key,
+            **kwargs,
+        )
+        return response


-default_model_client = ModelClient()
+default_model_client = ModelClient(
+    model=os.getenv("AZURE_MODEL", ""),
+    api_base=os.getenv("AZURE_API_BASE", ""),
+    api_version=os.getenv("AZURE_API_VERSION", ""),
+    api_key=os.getenv("AZURE_API_KEY", ""),
+)
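Since acompletion delegates to litellm.acompletion, a ModelClient can now target any backend LiteLLM supports through its "<provider>/<model>" model-string convention. A sketch under that assumption (model strings, endpoint, and key below are placeholders):

import asyncio

from coagent.agents.model_client import ModelClient

# Hypothetical configurations for two different backends.
openai_client = ModelClient(model="openai/gpt-4o-mini", api_key="sk-...")
ollama_client = ModelClient(model="ollama/llama3", api_base="http://localhost:11434")

# The provider is resolved lazily via litellm.get_llm_provider.
print(ollama_client.llm_provider)

async def main():
    response = await openai_client.acompletion(
        messages=[{"role": "user", "content": "Hello!"}],
    )
    # LiteLLM returns an OpenAI-style response object.
    print(response.choices[0].message.content)

asyncio.run(main())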
6 changes: 2 additions & 4 deletions coagent/agents/util.py
@@ -11,8 +11,7 @@
 async def chat(
     messages: list[ChatMessage], client: ModelClient = default_model_client
 ) -> ChatMessage:
-    response = await client.azure_client.chat.completions.create(
-        model=default_model_client.model,
+    response = await client.acompletion(
         messages=[m.model_dump() for m in messages],
     )
     msg = response.choices[0].message
@@ -25,8 +24,7 @@ async def chat(
 async def chat_stream(
     messages: list[ChatMessage], client: ModelClient = default_model_client
 ) -> AsyncIterator[ChatMessage]:
-    response = await default_model_client.azure_client.chat.completions.create(
-        model=client.model,
+    response = await client.acompletion(
         messages=[m.model_dump() for m in messages],
         stream=True,
     )
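Both helpers now share one path into LiteLLM via the client's acompletion; note that chat_stream previously called default_model_client regardless of the client argument, which this change also fixes. A usage sketch, assuming ChatMessage accepts role and content fields and that the AZURE_* environment variables behind default_model_client are set (or a ModelClient is passed explicitly):

import asyncio

from coagent.agents.messages import ChatMessage
from coagent.agents.util import chat, chat_stream

async def main():
    messages = [ChatMessage(role="user", content="Hi there")]

    # Non-streaming: returns a single ChatMessage built from the response.
    reply = await chat(messages)
    print(reply.content)

    # Streaming: yields ChatMessage deltas as chunks arrive.
    async for delta in chat_stream(messages):
        print(delta.content, end="", flush=True)

asyncio.run(main())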