-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for any LLM by using LiteLLM
- Loading branch information
1 parent
c7aa9a8
commit 9c11a73
Showing
9 changed files
with
2,193 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,67 @@ | ||
import os | ||
|
||
from openai import AsyncAzureOpenAI | ||
# Importing litellm is super slow, so we are using lazy import for now. | ||
# See https://github.com/BerriAI/litellm/issues/7605. | ||
# | ||
# import litellm | ||
from pydantic import BaseModel, Field | ||
|
||
|
||
# class ModelResponse(litellm.ModelResponse): | ||
# pass | ||
|
||
|
||
class ModelClient(BaseModel):
    """Connection settings for a chat model, with calls routed through LiteLLM.

    ``litellm`` is imported lazily inside the methods that need it, because
    importing it at module load time is slow
    (see https://github.com/BerriAI/litellm/issues/7605).
    """

    provider: str = Field("", description="The model provider.")
    model: str = Field(..., description="The model name.")
    api_base: str = Field("", description="The API base URL.")
    api_version: str = Field("", description="The API version.")
    api_key: str = Field("", description="The API key.")

    @property
    def llm_provider(self) -> str:
        """Return the provider name for this client.

        An explicitly configured ``provider`` takes precedence; otherwise the
        provider is resolved by ``litellm.get_llm_provider`` from the model
        name and, when set, the API base URL.
        """
        if self.provider:
            return self.provider

        import litellm  # Lazy import: see the class docstring.

        # The call returns a 4-tuple; only its second element (the provider)
        # is needed here. An empty ``api_base`` is passed as ``None``.
        _, provider, _, _ = litellm.get_llm_provider(
            self.model,
            api_base=self.api_base or None,
        )
        return provider

    async def acompletion(
        self,
        messages: list[dict],
        model: str = "",
        stream: bool = False,
        temperature: float = 0.1,
        tools: list | None = None,
        tool_choice: str | None = None,
        **kwargs,
    ):  # -> ModelResponse:
        """Request a chat completion via ``litellm.acompletion``.

        Args:
            messages: The chat messages to send.
            model: Model name for this call; falls back to ``self.model``
                when empty.
            stream: Forwarded to LiteLLM as ``stream``.
            temperature: Forwarded to LiteLLM as ``temperature``.
            tools: Forwarded to LiteLLM as ``tools``.
            tool_choice: Forwarded to LiteLLM as ``tool_choice``.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            Whatever ``litellm.acompletion`` returns for the request.
        """
        import litellm  # Lazy import: see the class docstring.

        model = model or self.model
        # NOTE(review): unlike ``llm_provider``, empty connection settings are
        # forwarded as "" rather than None -- confirm LiteLLM treats "" as unset.
        response = await litellm.acompletion(
            model=model,
            messages=messages,
            stream=stream,
            temperature=temperature,
            tools=tools,
            tool_choice=tool_choice,
            api_base=self.api_base,
            api_version=self.api_version,
            api_key=self.api_key,
            **kwargs,
        )
        return response
|
||
|
||
# Module-level default client, configured from the AZURE_* environment
# variables; each setting falls back to an empty string when unset.
default_model_client = ModelClient(
    model=os.getenv("AZURE_MODEL", ""),
    api_base=os.getenv("AZURE_API_BASE", ""),
    api_version=os.getenv("AZURE_API_VERSION", ""),
    api_key=os.getenv("AZURE_API_KEY", ""),
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.