# Use LiteLLM in place of ChatOpenAI #84

Status: Open. Wants to merge 19 commits into base `main`.
**README.md**: 34 changes (21 additions, 13 deletions)

@@ -4,19 +4,13 @@ Wandbot is a question-answering bot designed specifically for Weights & Biases [

## What's New

-### wandbot v1.2.0
+### wandbot v1.2.1

-This release introduces a number of exciting updates and improvements:

-- **Parallel LLM Calls**: Replaced llama-index with LCEL, enabling parallel LLM calls for increased efficiency.
-- **ChromaDB Integration**: Transitioned from FAISS to ChromaDB to leverage metadata filtering and speed.
-- **Query Enhancer Optimization**: Improved the query enhancer to operate with a single LLM call.
-- **Modular RAG Pipeline**: Split the RAG pipeline into three distinct modules: query enhancement, retrieval, and response synthesis, for improved clarity and maintenance.
-- **Parent Document Retrieval**: Introduced parent document retrieval functionality within the retrieval module to enhance contextuality.
-- **Sub-query Answering**: Added sub-query answering capabilities in the response synthesis module to handle complex queries more effectively.
-- **API Restructuring**: Redesigned the API into separate routers for retrieval, database, and chat operations.

-These updates are part of our ongoing commitment to improve performance and usability.
+Key updates:
+- Model-agnostic fallback system with LiteLLM
+- Support for OpenAI, Anthropic, and Google models
+- Improved error handling and automatic model fallback
+- Provider/model format (e.g., "openai/gpt-4-0125-preview")

## Evaluation
English
@@ -39,7 +33,7 @@ Japanese
- It features periodic data ingestion and report generation, contributing to the bot's continuous improvement. You can view the latest data ingestion report [here](https://wandb.ai/wandbot/wandbot-dev/reportlist).
- The bot is integrated with Discord and Slack, facilitating seamless integration with these popular collaboration platforms.
- Performance monitoring and continuous improvement are made possible through logging and analysis with Weights & Biases Tables. Visit the workspace for more details [here](https://wandb.ai/wandbot/wandbot_public).
-- Wandbot has a fallback mechanism for model selection, which is used when GPT-4 fails to generate a response.
+- Model-agnostic fallback system with support for OpenAI, Anthropic, and Google models.
- The bot's performance is evaluated using a mix of metrics, including retrieval accuracy, string similarity, and the correctness of model-generated responses.
- Curious about the custom system prompt used by the bot? You can view the full prompt [here](data/prompts/chat_prompt.json).

@@ -60,6 +54,20 @@ poetry install --all-extras

## Usage

+### Model Configuration
+
+Models are specified using provider/model format. Example:
+```python
+model_config = {
+    "model_name": "openai/gpt-4-0125-preview",  # Required
+    "temperature": 0.1,                          # Required
+    "fallback_models": [                         # Optional
+        "anthropic/claude-3-haiku",
+        "gemini/gemini-pro"
+    ]
+}
+```
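For illustration only (not part of this diff): with a config like the one above, the fallback list can be exercised through plain LiteLLM calls. This is a minimal sketch assuming `litellm` is installed and API keys for each provider are set; `complete_with_fallbacks` is a hypothetical helper, not a wandbot API:

```python
import litellm

def complete_with_fallbacks(model_config: dict, messages: list) -> str:
    """Hypothetical helper: try the primary model, then each fallback in order."""
    models = [model_config["model_name"]] + model_config.get("fallback_models", [])
    last_error = None
    for model in models:
        try:
            response = litellm.completion(
                model=model,  # "provider/model" string routes to the right provider
                messages=messages,
                temperature=model_config["temperature"],
            )
            return response.choices[0].message.content
        except Exception as err:  # provider error: fall through to the next model
            last_error = err
    raise RuntimeError(f"All models failed; last error: {last_error}")
```

Because every entry keeps the provider/model format, switching or reordering providers changes only the config, not the calling code.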

### Data Ingestion

The data ingestion module pulls code and markdown from the Weights & Biases repositories [docodile](https://github.com/wandb/docodile) and [examples](https://github.com/wandb/examples) and ingests them into vector stores for the retrieval-augmented generation pipeline.
**pyproject.toml**: 14 changes (13 additions, 1 deletion)

@@ -10,7 +10,7 @@ repository = "https://github.com/wandb/wandbot"
include = ["src/**/*", "LICENSE", "README.md"]

[tool.poetry.dependencies]
-python = ">=3.10.0,<=3.12.4"
+python = ">=3.10.0,<=3.12.8"
numpy = "^1.26.1"
pandas = "^2.1.2"
pydantic-settings = "^2.0.3"
@@ -27,6 +27,8 @@ tree-sitter-languages = "^1.7.1"
markdownify = "^0.11.6"
uvicorn = "^0.24.0"
openai = "^1.3.2"
+google-generativeai = ">=0.8.3"
+anthropic = "^0.18.1"
weave = "^0.50.12"
colorlog = "^6.8.0"
litellm = "^1.15.1"
@@ -53,6 +55,16 @@ ragas = "^0.1.7"
dataclasses-json = "^0.6.4"
llama-index = "^0.10.30"


+[tool.poetry.group.dev.dependencies]
+pytest = "^8.3.4"
+
+[tool.pytest.ini_options]
+filterwarnings = [
+    'ignore:.*Type google._upb._message.*uses PyType_Spec.*:DeprecationWarning',
+    'ignore:.*custom tp_new.*:DeprecationWarning'
+]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
**src/wandbot/chat/chat.py**: 106 changes (58 additions, 48 deletions)

@@ -36,7 +36,7 @@
from wandbot.retriever import VectorStore
from wandbot.utils import Timer, get_logger

-from openai import OpenAI
+from wandbot.rag.utils import ChatModel

logger = get_logger(__name__)

@@ -87,6 +87,9 @@ def _get_answer(

        return result

+    # Translation model
+    ja_to_en_model: ChatModel = ChatModel(max_retries=2)

    @weave.op()
    def _translate_ja_to_en(self, text: str) -> str:
        """
@@ -98,30 +101,35 @@ def _translate_ja_to_en(self, text: str) -> str:
        Returns:
            The translated text in English.
        """
-        client = OpenAI()
-        response = client.chat.completions.create(
-            model="gpt-4o-2024-08-06",
-            messages=[
-                {
-                    "role": "system",
-                    "content": f"You are a professional translator. \n\n\
-Translate the user's question about Weights & Biases into English according to the specified rules. \n\
-Rule of translation. \n\
-- Maintain the original nuance\n\
-- Keep code unchanged.\n\
-- Only return the English translation without any additional explanation"
-                },
-                {
-                    "role": "user",
-                    "content": text
-                }
-            ],
-            temperature=0,
-            max_tokens=1000,
-            top_p=1
-        )
+        # Configure model
+        self.ja_to_en_model = {
+            "model_name": "openai/gpt-4o-2024-08-06",
+            "temperature": 0
+        }
+
+        # Call model
+        response = self.ja_to_en_model([
+            {
+                "role": "system",
+                "content": """You are a professional translator.
+
+Translate the user's question about Weights & Biases into English according to the specified rules.
+Rule of translation:
+- Maintain the original nuance
+- Keep code unchanged.
+- Only return the English translation without any additional explanation"""
+            },
+            {
+                "role": "user",
+                "content": text
+            }
+        ], max_tokens=1000)

        return response.choices[0].message.content

+    # Translation model
+    en_to_ja_model: ChatModel = ChatModel(max_retries=2)

    @weave.op()
    def _translate_en_to_ja(self, text: str) -> str:
        """
@@ -133,31 +141,33 @@ def _translate_en_to_ja(self, text: str) -> str:
        Returns:
            The translated text in Japanese.
        """
-        client = OpenAI()
-        response = client.chat.completions.create(
-            model="gpt-4o-2024-08-06",
-            messages=[
-                {
-                    "role": "system",
-                    "content": f"You are a professional translator. \n\n\
-Translate the user's text into Japanese according to the specified rules. \n\
-Rule of translation. \n\
-- Maintain the original nuance\n\
-- Use 'run' in English where appropriate, as it's a term used in Wandb.\n\
-- Translate the terms 'reference artifacts' and 'lineage' into Katakana. \n\
-- Include specific terms in English or Katakana where appropriate\n\
-- Keep code unchanged.\n\
-- Only return the Japanese translation without any additional explanation"
-                },
-                {
-                    "role": "user",
-                    "content": text
-                }
-            ],
-            temperature=0,
-            max_tokens=1000,
-            top_p=1
-        )
+        # Configure model
+        self.en_to_ja_model = {
+            "model_name": "openai/gpt-4o-2024-08-06",
+            "temperature": 0
+        }
+
+        # Call model
+        response = self.en_to_ja_model([
+            {
+                "role": "system",
+                "content": """You are a professional translator.
+
+Translate the user's text into Japanese according to the specified rules.
+Rule of translation:
+- Maintain the original nuance
+- Use 'run' in English where appropriate, as it's a term used in Wandb.
+- Translate the terms 'reference artifacts' and 'lineage' into Katakana.
+- Include specific terms in English or Katakana where appropriate
+- Keep code unchanged.
+- Only return the Japanese translation without any additional explanation"""
+            },
+            {
+                "role": "user",
+                "content": text
+            }
+        ], max_tokens=1000)

        return response.choices[0].message.content

    @weave.op()
**src/wandbot/chat/config.py**: 4 changes (2 additions, 2 deletions)

@@ -36,7 +36,7 @@ class ChatConfig(BaseSettings):
    english_reranker_model: str = "rerank-english-v2.0"
    multilingual_reranker_model: str = "rerank-multilingual-v2.0"
    # Response synthesis settings
-    response_synthesizer_model: str = "gpt-4-0125-preview"
+    response_synthesizer_model: str = "openai/gpt-4-0125-preview"  # Format: provider/model_name
    response_synthesizer_temperature: float = 0.1
-    response_synthesizer_fallback_model: str = "gpt-4-0125-preview"
+    response_synthesizer_fallback_model: str = "openai/gpt-4-0125-preview"  # Format: provider/model_name
    response_synthesizer_fallback_temperature: float = 0.1
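As background on the provider/model strings above: LiteLLM routes a call based on the prefix before the slash, so the same completion call works across providers. A minimal sketch, assuming the relevant API key environment variables are set:

```python
import litellm

# The "openai/" prefix tells LiteLLM which provider to route to;
# swapping in "anthropic/..." or "gemini/..." changes only this string.
response = litellm.completion(
    model="openai/gpt-4-0125-preview",
    messages=[{"role": "user", "content": "What is Weights & Biases?"}],
    temperature=0.1,
)
print(response.choices[0].message.content)
```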
**src/wandbot/rag/response_synthesis.py**: 4 changes (2 additions, 2 deletions)

@@ -5,7 +5,7 @@
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnableLambda, RunnableParallel
-from langchain_openai import ChatOpenAI


from wandbot.rag.utils import ChatModel, combine_documents, create_query_str

@@ -139,7 +139,7 @@ def chain(self) -> Runnable:
        self._chain = base_chain.with_fallbacks([fallback_chain])
        return self._chain

-    def _load_chain(self, model: ChatOpenAI) -> Runnable:
+    def _load_chain(self, model: ChatModel) -> Runnable:
        response_synthesis_chain = (
            RunnableLambda(
                lambda x: {
**src/wandbot/rag/utils.py**: 65 changes (58 additions, 7 deletions)

@@ -2,13 +2,15 @@

from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate, format_document
-from langchain_openai import ChatOpenAI
+import litellm

from wandbot.retriever.web_search import YouSearchResults
from wandbot.utils import clean_document_content


class ChatModel:
+    """Chat model descriptor that wraps LiteLLM for provider-agnostic interface."""
+
    def __init__(self, max_retries: int = 2):
        self.max_retries = max_retries

@@ -21,12 +23,61 @@ def __get__(self, obj, obj_type=None):
        return value

    def __set__(self, obj, value):
-        model = ChatOpenAI(
-            model_name=value["model_name"],
-            temperature=value["temperature"],
-            max_retries=self.max_retries,
-        )
-        setattr(obj, self.private_name, model)
"""Configure LiteLLM with the given model settings.

Args:
value: Dictionary containing:
- model_name: Name of the model to use (e.g., "openai/gpt-4")
- temperature: Sampling temperature between 0 and 1
- fallback_models: Optional list of fallback models
"""
if not 0 <= value["temperature"] <= 1:
raise ValueError("Temperature must be between 0 and 1")

# Configure LiteLLM
litellm.drop_params = True # Remove unsupported params
litellm.set_verbose = False
litellm.success_callback = []
litellm.failure_callback = []

# Configure fallbacks
litellm.model_fallbacks = {} # Reset fallbacks
litellm.fallbacks = False # Reset fallbacks flag
if value.get("fallback_models"):
litellm.model_fallbacks = {
value["model_name"]: value["fallback_models"]
}
litellm.fallbacks = True

# Create completion function
def completion_fn(messages, **kwargs):
try:
response = litellm.completion(
model=value["model_name"],
messages=messages,
temperature=value["temperature"],
num_retries=self.max_retries,
**kwargs
)
return response
except Exception as e:
# Return error response
return type("Response", (), {
"choices": [
type("Choice", (), {
"message": type("Message", (), {
"content": ""
})()
})()
],
"error": {
"type": type(e).__name__,
"message": str(e)
},
"model": value["model_name"]
})()

setattr(obj, self.private_name, completion_fn)
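To make the descriptor flow above concrete: assigning a config dict triggers `__set__`, which validates the settings and stores a completion function, while reading the attribute returns that function via `__get__`. A minimal usage sketch with a hypothetical owner class (not from this PR):

```python
class Translator:
    # Class-level descriptor; instances configure it by plain assignment.
    model: ChatModel = ChatModel(max_retries=2)

    def ask(self, question: str) -> str:
        # Assignment runs ChatModel.__set__, which builds the completion function.
        self.model = {"model_name": "openai/gpt-4o-2024-08-06", "temperature": 0}
        # Attribute access runs ChatModel.__get__ and returns that function.
        response = self.model(
            [{"role": "user", "content": question}],
            max_tokens=1000,
        )
        return response.choices[0].message.content
```

This mirrors how chat.py declares `ja_to_en_model` and `en_to_ja_model` at class level and reconfigures them inside each translation method.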


DEFAULT_QUESTION_PROMPT = PromptTemplate.from_template(