diff --git a/docs/my-website/docs/proxy/prompt_management.md b/docs/my-website/docs/proxy/prompt_management.md index 4b19409217b1..328a73b8ebec 100644 --- a/docs/my-website/docs/proxy/prompt_management.md +++ b/docs/my-website/docs/proxy/prompt_management.md @@ -1,57 +1,182 @@ import Image from '@theme/IdealImage'; +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; # Prompt Management -LiteLLM supports using [Langfuse](https://langfuse.com/docs/prompts/get-started) for prompt management on the proxy. +Run experiments or change the specific model (e.g. from gpt-4o to gpt4o-mini finetune) from your prompt management tool (e.g. Langfuse) instead of making changes in the application. + +Supported Integrations: +- [Langfuse](https://langfuse.com/docs/prompts/get-started) ## Quick Start -1. Add Langfuse as a 'callback' in your config.yaml + + + + + +```python +import os +import litellm + +os.environ["LANGFUSE_PUBLIC_KEY"] = "public_key" # [OPTIONAL] set here or in `.completion` +os.environ["LANGFUSE_SECRET_KEY"] = "secret_key" # [OPTIONAL] set here or in `.completion` + +litellm.set_verbose = True # see raw request to provider + +resp = litellm.completion( + model="langfuse/gpt-3.5-turbo", + prompt_id="test-chat-prompt", + prompt_variables={"user_message": "this is used"}, # [OPTIONAL] + messages=[{"role": "user", "content": ""}], +) +``` + + + + + + +1. Setup config.yaml ```yaml model_list: - model_name: gpt-3.5-turbo litellm_params: - model: azure/chatgpt-v-2 - api_key: os.environ/AZURE_API_KEY - api_base: os.environ/AZURE_API_BASE - -litellm_settings: - callbacks: ["langfuse"] # 👈 KEY CHANGE + model: langfuse/gpt-3.5-turbo + prompt_id: "" + api_key: os.environ/OPENAI_API_KEY ``` 2. Start the proxy ```bash -litellm-proxy --config config.yaml +litellm --config config.yaml --detailed_debug ``` 3. Test it! + + + ```bash curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \ -H 'Content-Type: application/json' \ -H 'Authorization: Bearer sk-1234' \ -d '{ - "model": "gpt-4", + "model": "gpt-3.5-turbo", "messages": [ { "role": "user", "content": "THIS WILL BE IGNORED" } ], - "metadata": { - "langfuse_prompt_id": "value", - "langfuse_prompt_variables": { # [OPTIONAL] - "key": "value" - } + "prompt_variables": { + "key": "this is used" } }' ``` + + + +```python +import openai +client = openai.OpenAI( + api_key="anything", + base_url="http://0.0.0.0:4000" +) + +# request sent to model set on litellm proxy, `litellm --model` +response = client.chat.completions.create( + model="gpt-3.5-turbo", + messages = [ + { + "role": "user", + "content": "this is a test request, write a short poem" + } + ], + extra_body={ + "prompt_variables": { # [OPTIONAL] + "key": "this is used" + } + } +) + +print(response) +``` + + + + + + + + +**Expected Logs:** + +``` +POST Request Sent from LiteLLM: +curl -X POST \ +https://api.openai.com/v1/ \ +-d '{'model': 'gpt-3.5-turbo', 'messages': }' +``` + +## How to set model + +### Set the model on LiteLLM -## What is 'langfuse_prompt_id'? +You can do `langfuse/` -- `langfuse_prompt_id`: The ID of the prompt that will be used for the request. + + + +```python +litellm.completion( + model="langfuse/gpt-3.5-turbo", # or `langfuse/anthropic/claude-3-5-sonnet` + ... 
+) +``` + + + + +```yaml +model_list: + - model_name: gpt-3.5-turbo + litellm_params: + model: langfuse/gpt-3.5-turbo # OR langfuse/anthropic/claude-3-5-sonnet + prompt_id: + api_key: os.environ/OPENAI_API_KEY +``` + + + + +### Set the model in Langfuse + +If the model is specified in the Langfuse config, it will be used. + + + +```yaml +model_list: + - model_name: gpt-3.5-turbo + litellm_params: + model: azure/chatgpt-v-2 + api_key: os.environ/AZURE_API_KEY + api_base: os.environ/AZURE_API_BASE +``` + +## What is 'prompt_variables'? + +- `prompt_variables`: A dictionary of variables that will be used to replace parts of the prompt. + + + +## What is 'prompt_id'? + +- `prompt_id`: The ID of the prompt that will be used for the request. @@ -59,25 +184,29 @@ curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \ ### `/chat/completions` messages -The message will be added to the start of the prompt. +The `messages` field sent in by the client is ignored. -- if the Langfuse prompt is a list, it will be added to the start of the messages list (assuming it's an OpenAI compatible message). +The Langfuse prompt will replace the `messages` field. -- if the Langfuse prompt is a string, it will be added as a system message. +To replace parts of the prompt, use the `prompt_variables` field. [See how prompt variables are used](https://github.com/BerriAI/litellm/blob/017f83d038f85f93202a083cf334de3544a3af01/litellm/integrations/langfuse/langfuse_prompt_management.py#L127) -```python -if isinstance(compiled_prompt, list): - data["messages"] = compiled_prompt + data["messages"] -else: - data["messages"] = [ - {"role": "system", "content": compiled_prompt} - ] + data["messages"] -``` +If the Langfuse prompt is a string, it will be sent as a user message (not all providers support system messages). -### `/completions` messages +If the Langfuse prompt is a list, it will be sent as is (Langfuse chat prompts are OpenAI compatible). -The message will be added to the start of the prompt. 
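+
+For illustration, here is a simplified sketch of how the compiled Langfuse prompt is turned into the `messages` sent to the provider (the helper name below is illustrative, not part of the LiteLLM API):
+
+```python
+from typing import List, Union
+
+
+def map_compiled_prompt_to_messages(compiled_prompt: Union[str, list]) -> List[dict]:
+    """Sketch: how a compiled Langfuse prompt becomes the `messages` sent to the provider."""
+    if isinstance(compiled_prompt, list):
+        # Langfuse chat prompts are already OpenAI-compatible message dicts - sent as is
+        return compiled_prompt
+    # text prompts are sent as a single user message
+    return [{"role": "user", "content": compiled_prompt}]
+```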
+## Architectural Overview -```python -data["prompt"] = compiled_prompt + "\n" + data["prompt"] -``` \ No newline at end of file + + +## API Reference + +These are the params you can pass to the `litellm.completion` function in SDK and `litellm_params` in config.yaml + +``` +prompt_id: str # required +prompt_variables: Optional[dict] # optional +langfuse_public_key: Optional[str] # optional +langfuse_secret: Optional[str] # optional +langfuse_secret_key: Optional[str] # optional +langfuse_host: Optional[str] # optional +``` diff --git a/docs/my-website/docusaurus.config.js b/docs/my-website/docusaurus.config.js index 73d500b14c4b..1dcb0613789a 100644 --- a/docs/my-website/docusaurus.config.js +++ b/docs/my-website/docusaurus.config.js @@ -130,15 +130,7 @@ const config = { href: 'https://discord.com/invite/wuPM9dRgDw', label: 'Discord', position: 'right', - }, - { - type: 'html', - position: 'right', - value: - ` - I'm Confused - ` - }, + } ], }, footer: { diff --git a/docs/my-website/img/langfuse_prompt_management_model_config.png b/docs/my-website/img/langfuse_prompt_management_model_config.png new file mode 100644 index 000000000000..d611ab3941c5 Binary files /dev/null and b/docs/my-website/img/langfuse_prompt_management_model_config.png differ diff --git a/docs/my-website/img/prompt_management_architecture_doc.png b/docs/my-website/img/prompt_management_architecture_doc.png new file mode 100644 index 000000000000..2040cb7fa3e9 Binary files /dev/null and b/docs/my-website/img/prompt_management_architecture_doc.png differ diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 9372f2c32ced..34701fa32471 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -135,7 +135,6 @@ const sidebars = { "oidc" ] }, - "proxy/prompt_management", "proxy/caching", "proxy/call_hooks", "proxy/rules", @@ -228,6 +227,7 @@ const sidebars = { "completion/batching", "completion/mock_requests", "completion/reliable_completions", + 'tutorials/litellm_proxy_aporia', ] }, @@ -309,8 +309,29 @@ const sidebars = { label: "LangChain, LlamaIndex, Instructor Integration", items: ["langchain/langchain", "tutorials/instructor"], }, + { + type: "category", + label: "Tutorials", + items: [ + + 'tutorials/azure_openai', + 'tutorials/instructor', + "tutorials/gradio_integration", + "tutorials/huggingface_codellama", + "tutorials/huggingface_tutorial", + "tutorials/TogetherAI_liteLLM", + "tutorials/finetuned_chat_gpt", + "tutorials/text_completion", + "tutorials/first_playground", + "tutorials/model_fallbacks", + ], + }, ], }, + { + type: "doc", + id: "proxy/prompt_management" + }, { type: "category", label: "Load Testing", @@ -362,23 +383,7 @@ const sidebars = { "observability/opik_integration", ], }, - { - type: "category", - label: "Tutorials", - items: [ - 'tutorials/litellm_proxy_aporia', - 'tutorials/azure_openai', - 'tutorials/instructor', - "tutorials/gradio_integration", - "tutorials/huggingface_codellama", - "tutorials/huggingface_tutorial", - "tutorials/TogetherAI_liteLLM", - "tutorials/finetuned_chat_gpt", - "tutorials/text_completion", - "tutorials/first_playground", - "tutorials/model_fallbacks", - ], - }, + { type: "category", label: "Extras", diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index dac95324864a..6045244c4d94 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -14,6 +14,7 @@ EmbeddingResponse, ImageResponse, ModelResponse, + StandardCallbackDynamicParams, 
StandardLoggingPayload, ) @@ -60,6 +61,26 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): pass + #### PROMPT MANAGEMENT HOOKS #### + + def get_chat_completion_prompt( + self, + model: str, + messages: List[AllMessageValues], + non_default_params: dict, + headers: dict, + prompt_id: str, + prompt_variables: Optional[dict], + dynamic_callback_params: StandardCallbackDynamicParams, + ) -> Tuple[str, List[AllMessageValues], dict]: + """ + Returns: + - model: str - the model to use (can be pulled from prompt management tool) + - messages: List[AllMessageValues] - the messages to use (can be pulled from prompt management tool) + - non_default_params: dict - update with any optional params (e.g. temperature, max_tokens, etc.) to use (can be pulled from prompt management tool) + """ + return model, messages, non_default_params + #### PRE-CALL CHECKS - router/proxy only #### """ Allows usage-based-routing-v2 to run pre-call rpm checks within the picked deployment's semaphore (concurrency-safe tpm/rpm checks). diff --git a/litellm/integrations/langfuse/langfuse_prompt_management.py b/litellm/integrations/langfuse/langfuse_prompt_management.py index 8f5ee02fda93..cb41d4aa526c 100644 --- a/litellm/integrations/langfuse/langfuse_prompt_management.py +++ b/litellm/integrations/langfuse/langfuse_prompt_management.py @@ -3,16 +3,93 @@ """ import os -import traceback -from typing import Literal, Optional, Union +from functools import lru_cache +from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union, cast from packaging.version import Version +from typing_extensions import TypeAlias from litellm._logging import verbose_proxy_logger from litellm.caching.dual_cache import DualCache from litellm.integrations.custom_logger import CustomLogger from litellm.proxy._types import UserAPIKeyAuth -from litellm.secret_managers.main import str_to_bool +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import StandardCallbackDynamicParams + +if TYPE_CHECKING: + from langfuse import Langfuse + from langfuse.client import ChatPromptClient, TextPromptClient + + LangfuseClass: TypeAlias = Langfuse + + PROMPT_CLIENT = Union[TextPromptClient, ChatPromptClient] +else: + PROMPT_CLIENT = Any + LangfuseClass = Any + + +@lru_cache(maxsize=10) +def langfuse_client_init( + langfuse_public_key=None, + langfuse_secret=None, + langfuse_host=None, + flush_interval=1, +) -> LangfuseClass: + """ + Initialize Langfuse client with caching to prevent multiple initializations. + + Args: + langfuse_public_key (str, optional): Public key for Langfuse. Defaults to None. + langfuse_secret (str, optional): Secret key for Langfuse. Defaults to None. + langfuse_host (str, optional): Host URL for Langfuse. Defaults to None. + flush_interval (int, optional): Flush interval in seconds. Defaults to 1. 
+ + Returns: + Langfuse: Initialized Langfuse client instance + + Raises: + Exception: If langfuse package is not installed + """ + try: + import langfuse + from langfuse import Langfuse + except Exception as e: + raise Exception( + f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\n\033[0m" + ) + + # Instance variables + secret_key = langfuse_secret or os.getenv("LANGFUSE_SECRET_KEY") + public_key = langfuse_public_key or os.getenv("LANGFUSE_PUBLIC_KEY") + langfuse_host = langfuse_host or os.getenv( + "LANGFUSE_HOST", "https://cloud.langfuse.com" + ) + + if not ( + langfuse_host.startswith("http://") or langfuse_host.startswith("https://") + ): + # add http:// if unset, assume communicating over private network - e.g. render + langfuse_host = "http://" + langfuse_host + + langfuse_release = os.getenv("LANGFUSE_RELEASE") + langfuse_debug = os.getenv("LANGFUSE_DEBUG") + langfuse_flush_interval = os.getenv("LANGFUSE_FLUSH_INTERVAL") or flush_interval + + parameters = { + "public_key": public_key, + "secret_key": secret_key, + "host": langfuse_host, + "release": langfuse_release, + "debug": langfuse_debug, + "flush_interval": langfuse_flush_interval, # flush interval in seconds + } + + if Version(langfuse.version.__version__) >= Version("2.6.0"): + parameters["sdk_integration"] = "litellm" + + client = Langfuse(**parameters) + + return client class LangfusePromptManagement(CustomLogger): @@ -23,102 +100,42 @@ def __init__( langfuse_host=None, flush_interval=1, ): - try: - import langfuse - from langfuse import Langfuse - except Exception as e: - raise Exception( - f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\n{traceback.format_exc()}\033[0m" - ) - # Instance variables - self.secret_key = langfuse_secret or os.getenv("LANGFUSE_SECRET_KEY") - self.public_key = langfuse_public_key or os.getenv("LANGFUSE_PUBLIC_KEY") - self.langfuse_host = langfuse_host or os.getenv( - "LANGFUSE_HOST", "https://cloud.langfuse.com" - ) - if not ( - self.langfuse_host.startswith("http://") - or self.langfuse_host.startswith("https://") - ): - # add http:// if unset, assume communicating over private network - e.g. 
render - self.langfuse_host = "http://" + self.langfuse_host - self.langfuse_release = os.getenv("LANGFUSE_RELEASE") - self.langfuse_debug = os.getenv("LANGFUSE_DEBUG") - self.langfuse_flush_interval = ( - os.getenv("LANGFUSE_FLUSH_INTERVAL") or flush_interval + self.Langfuse = langfuse_client_init( + langfuse_public_key=langfuse_public_key, + langfuse_secret=langfuse_secret, + langfuse_host=langfuse_host, + flush_interval=flush_interval, ) - parameters = { - "public_key": self.public_key, - "secret_key": self.secret_key, - "host": self.langfuse_host, - "release": self.langfuse_release, - "debug": self.langfuse_debug, - "flush_interval": self.langfuse_flush_interval, # flush interval in seconds - } - - if Version(langfuse.version.__version__) >= Version("2.6.0"): - parameters["sdk_integration"] = "litellm" - - self.Langfuse = Langfuse(**parameters) - - # set the current langfuse project id in the environ - # this is used by Alerting to link to the correct project - try: - project_id = self.Langfuse.client.projects.get().data[0].id - os.environ["LANGFUSE_PROJECT_ID"] = project_id - except Exception: - project_id = None - - if os.getenv("UPSTREAM_LANGFUSE_SECRET_KEY") is not None: - upstream_langfuse_debug = ( - str_to_bool(self.upstream_langfuse_debug) - if self.upstream_langfuse_debug is not None - else None - ) - self.upstream_langfuse_secret_key = os.getenv( - "UPSTREAM_LANGFUSE_SECRET_KEY" - ) - self.upstream_langfuse_public_key = os.getenv( - "UPSTREAM_LANGFUSE_PUBLIC_KEY" - ) - self.upstream_langfuse_host = os.getenv("UPSTREAM_LANGFUSE_HOST") - self.upstream_langfuse_release = os.getenv("UPSTREAM_LANGFUSE_RELEASE") - self.upstream_langfuse_debug = os.getenv("UPSTREAM_LANGFUSE_DEBUG") - self.upstream_langfuse = Langfuse( - public_key=self.upstream_langfuse_public_key, - secret_key=self.upstream_langfuse_secret_key, - host=self.upstream_langfuse_host, - release=self.upstream_langfuse_release, - debug=( - upstream_langfuse_debug - if upstream_langfuse_debug is not None - else False - ), - ) - else: - self.upstream_langfuse = None + def _get_prompt_from_id( + self, langfuse_prompt_id: str, langfuse_client: LangfuseClass + ) -> PROMPT_CLIENT: + return langfuse_client.get_prompt(langfuse_prompt_id) def _compile_prompt( self, - metadata: dict, + langfuse_prompt_client: PROMPT_CLIENT, + langfuse_prompt_variables: Optional[dict], call_type: Union[Literal["completion"], Literal["text_completion"]], ) -> Optional[Union[str, list]]: compiled_prompt: Optional[Union[str, list]] = None - if isinstance(metadata, dict): - langfuse_prompt_id = metadata.get("langfuse_prompt_id") - langfuse_prompt_variables = metadata.get("langfuse_prompt_variables") or {} - if ( - langfuse_prompt_id - and isinstance(langfuse_prompt_id, str) - and isinstance(langfuse_prompt_variables, dict) - ): - langfuse_prompt = self.Langfuse.get_prompt(langfuse_prompt_id) - compiled_prompt = langfuse_prompt.compile(**langfuse_prompt_variables) + if langfuse_prompt_variables is None: + langfuse_prompt_variables = {} + + compiled_prompt = langfuse_prompt_client.compile(**langfuse_prompt_variables) return compiled_prompt + def _get_model_from_prompt( + self, langfuse_prompt_client: PROMPT_CLIENT, model: str + ) -> str: + config = langfuse_prompt_client.config + if "model" in config: + return config["model"] + else: + return model.replace("langfuse/", "") + async def async_pre_call_hook( self, user_api_key_dict: UserAPIKeyAuth, @@ -137,9 +154,29 @@ async def async_pre_call_hook( ) -> Union[Exception, str, dict, None]: metadata = 
data.get("metadata") or {} + + if isinstance(metadata, dict): + langfuse_prompt_id = cast(Optional[str], metadata.get("langfuse_prompt_id")) + + langfuse_prompt_variables = cast( + Optional[dict], metadata.get("langfuse_prompt_variables") or {} + ) + else: + return None + + if langfuse_prompt_id is None: + return None + + prompt_client = self._get_prompt_from_id( + langfuse_prompt_id=langfuse_prompt_id, langfuse_client=self.Langfuse + ) compiled_prompt: Optional[Union[str, list]] = None if call_type == "completion" or call_type == "text_completion": - compiled_prompt = self._compile_prompt(metadata, call_type) + compiled_prompt = self._compile_prompt( + langfuse_prompt_client=prompt_client, + langfuse_prompt_variables=langfuse_prompt_variables, + call_type=call_type, + ) if compiled_prompt is None: return await super().async_pre_call_hook( user_api_key_dict, cache, data, call_type @@ -161,3 +198,53 @@ async def async_pre_call_hook( return await super().async_pre_call_hook( user_api_key_dict, cache, data, call_type ) + + def get_chat_completion_prompt( + self, + model: str, + messages: List[AllMessageValues], + non_default_params: dict, + headers: dict, + prompt_id: str, + prompt_variables: Optional[dict], + dynamic_callback_params: StandardCallbackDynamicParams, + ) -> Tuple[ + str, + List[AllMessageValues], + dict, + ]: + if prompt_id is None: + raise ValueError( + "Langfuse prompt id is required. Pass in as parameter 'langfuse_prompt_id'" + ) + langfuse_client = langfuse_client_init( + langfuse_public_key=dynamic_callback_params.get("langfuse_public_key"), + langfuse_secret=dynamic_callback_params.get("langfuse_secret"), + langfuse_host=dynamic_callback_params.get("langfuse_host"), + ) + langfuse_prompt_client = self._get_prompt_from_id( + langfuse_prompt_id=prompt_id, langfuse_client=langfuse_client + ) + + ## SET PROMPT + compiled_prompt = self._compile_prompt( + langfuse_prompt_client=langfuse_prompt_client, + langfuse_prompt_variables=prompt_variables, + call_type="completion", + ) + + if compiled_prompt is None: + raise ValueError(f"Langfuse prompt not found. Prompt id={prompt_id}") + if isinstance(compiled_prompt, list): + messages = compiled_prompt + elif isinstance(compiled_prompt, str): + messages = [{"role": "user", "content": compiled_prompt}] + else: + raise ValueError( + f"Langfuse prompt is not a list or string. 
Prompt id={prompt_id}, compiled_prompt type={type(compiled_prompt)}" + ) + + ## SET MODEL + model = self._get_model_from_prompt(langfuse_prompt_client, model) + + return model, messages, non_default_params diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 725ba5e8902d..a2fe21a680c2 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -34,7 +34,7 @@ redact_message_input_output_from_custom_logger, redact_message_input_output_from_logging, ) -from litellm.types.llms.openai import HttpxBinaryResponseContent +from litellm.types.llms.openai import AllMessageValues, HttpxBinaryResponseContent from litellm.types.rerank import RerankResponse from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS from litellm.types.utils import ( @@ -425,6 +425,40 @@ def update_environment_variables( if "custom_llm_provider" in self.model_call_details: self.custom_llm_provider = self.model_call_details["custom_llm_provider"] + def get_chat_completion_prompt( + self, + model: str, + messages: List[AllMessageValues], + non_default_params: dict, + headers: dict, + prompt_id: str, + prompt_variables: Optional[dict], + ) -> Tuple[str, List[AllMessageValues], dict]: + for ( + custom_logger_compatible_callback + ) in litellm._known_custom_logger_compatible_callbacks: + if model.startswith(custom_logger_compatible_callback): + custom_logger = _init_custom_logger_compatible_class( + logging_integration=custom_logger_compatible_callback, + internal_usage_cache=None, + llm_router=None, + ) + if custom_logger is None: + continue + model, messages, non_default_params = ( + custom_logger.get_chat_completion_prompt( + model=model, + messages=messages, + non_default_params=non_default_params, + headers=headers, + prompt_id=prompt_id, + prompt_variables=prompt_variables, + dynamic_callback_params=self.standard_callback_dynamic_params, + ) + ) + + return model, messages, non_default_params + def _pre_call(self, input, api_key, model=None, additional_args={}): """ Common helper function across the sync + async pre-call function diff --git a/litellm/llms/vertex_ai/vertex_ai_partner_models/main.py b/litellm/llms/vertex_ai/vertex_ai_partner_models/main.py index 344f66682f19..b8ddeb03c92b 100644 --- a/litellm/llms/vertex_ai/vertex_ai_partner_models/main.py +++ b/litellm/llms/vertex_ai/vertex_ai_partner_models/main.py @@ -96,10 +96,10 @@ def completion( from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import ( VertexLLM, ) - except Exception: + except Exception as e: raise VertexAIError( status_code=400, - message="""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`""", + message=f"""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`. 
Got error: {e}""", ) if not ( diff --git a/litellm/llms/vertex_ai/vertex_model_garden/main.py b/litellm/llms/vertex_ai/vertex_model_garden/main.py index 20ee38e97916..7b54d4e34b9c 100644 --- a/litellm/llms/vertex_ai/vertex_model_garden/main.py +++ b/litellm/llms/vertex_ai/vertex_model_garden/main.py @@ -75,11 +75,11 @@ def completion( from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import ( VertexLLM, ) - except Exception: + except Exception as e: raise VertexAIError( status_code=400, - message="""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`""", + message=f"""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`. Got error: {e}""", ) if not ( diff --git a/litellm/main.py b/litellm/main.py index 0b3288accdfa..804a747aa827 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -846,6 +846,9 @@ def completion( # type: ignore # noqa: PLR0915 client = kwargs.get("client", None) ### Admin Controls ### no_log = kwargs.get("no-log", False) + ### PROMPT MANAGEMENT ### + prompt_id = cast(Optional[str], kwargs.get("prompt_id", None)) + prompt_variables = cast(Optional[dict], kwargs.get("prompt_variables", None)) ### COPY MESSAGES ### - related issue https://github.com/BerriAI/litellm/discussions/4489 messages = get_completion_messages( messages=messages, @@ -894,11 +897,26 @@ def completion( # type: ignore # noqa: PLR0915 ] default_params = openai_params + all_litellm_params + litellm_params = {} # used to prevent unbound var errors non_default_params = { k: v for k, v in kwargs.items() if k not in default_params } # model-specific params - pass them straight to the model/provider + ## PROMPT MANAGEMENT HOOKS ## + + if isinstance(litellm_logging_obj, LiteLLMLoggingObj) and prompt_id is not None: + model, messages, optional_params = ( + litellm_logging_obj.get_chat_completion_prompt( + model=model, + messages=messages, + non_default_params=non_default_params, + headers=headers, + prompt_id=prompt_id, + prompt_variables=prompt_variables, + ) + ) + try: if base_url is not None: api_base = base_url @@ -1002,9 +1020,6 @@ def completion( # type: ignore # noqa: PLR0915 and supports_system_message is False ): messages = map_system_message_pt(messages=messages) - model_api_key = get_api_key( - llm_provider=custom_llm_provider, dynamic_api_key=api_key - ) # get the api key from the environment if required for the model if dynamic_api_key is not None: api_key = dynamic_api_key diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 487bc64bfe02..254395ea51e0 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1361,85 +1361,6 @@ class ResponseFormatChunk(TypedDict, total=False): response_schema: dict -all_litellm_params = [ - "metadata", - "litellm_trace_id", - "tags", - "acompletion", - "aimg_generation", - "atext_completion", - "text_completion", - "caching", - "mock_response", - "api_key", - "api_version", - "api_base", - "force_timeout", - "logger_fn", - "verbose", - "custom_llm_provider", - "litellm_logging_obj", - "litellm_call_id", - "use_client", - "id", - "fallbacks", - "azure", - "headers", - "model_list", - "num_retries", - "context_window_fallback_dict", - "retry_policy", - "retry_strategy", - "roles", - "final_prompt_value", - "bos_token", - "eos_token", - "request_timeout", - "complete_response", - "self", - "client", - "rpm", - "tpm", - "max_parallel_requests", - "input_cost_per_token", - "output_cost_per_token", - "input_cost_per_second", - "output_cost_per_second", - 
"hf_model_name", - "model_info", - "proxy_server_request", - "preset_cache_key", - "caching_groups", - "ttl", - "cache", - "no-log", - "base_model", - "stream_timeout", - "supports_system_message", - "region_name", - "allowed_model_region", - "model_config", - "fastest_response", - "cooldown_time", - "cache_key", - "max_retries", - "azure_ad_token_provider", - "tenant_id", - "client_id", - "client_secret", - "user_continue_message", - "configurable_clientside_auth_params", - "weight", - "ensure_alternating_roles", - "assistant_continue_message", - "user_continue_message", - "fallback_depth", - "max_fallbacks", - "max_budget", - "budget_duration", -] - - class LoggedLiteLLMParams(TypedDict, total=False): force_timeout: Optional[float] custom_llm_provider: Optional[str] @@ -1646,6 +1567,87 @@ class StandardCallbackDynamicParams(TypedDict, total=False): turn_off_message_logging: Optional[bool] # when true will not log messages +all_litellm_params = [ + "metadata", + "litellm_trace_id", + "tags", + "acompletion", + "aimg_generation", + "atext_completion", + "text_completion", + "caching", + "mock_response", + "api_key", + "api_version", + "prompt_id", + "prompt_variables", + "api_base", + "force_timeout", + "logger_fn", + "verbose", + "custom_llm_provider", + "litellm_logging_obj", + "litellm_call_id", + "use_client", + "id", + "fallbacks", + "azure", + "headers", + "model_list", + "num_retries", + "context_window_fallback_dict", + "retry_policy", + "retry_strategy", + "roles", + "final_prompt_value", + "bos_token", + "eos_token", + "request_timeout", + "complete_response", + "self", + "client", + "rpm", + "tpm", + "max_parallel_requests", + "input_cost_per_token", + "output_cost_per_token", + "input_cost_per_second", + "output_cost_per_second", + "hf_model_name", + "model_info", + "proxy_server_request", + "preset_cache_key", + "caching_groups", + "ttl", + "cache", + "no-log", + "base_model", + "stream_timeout", + "supports_system_message", + "region_name", + "allowed_model_region", + "model_config", + "fastest_response", + "cooldown_time", + "cache_key", + "max_retries", + "azure_ad_token_provider", + "tenant_id", + "client_id", + "client_secret", + "user_continue_message", + "configurable_clientside_auth_params", + "weight", + "ensure_alternating_roles", + "assistant_continue_message", + "user_continue_message", + "fallback_depth", + "max_fallbacks", + "max_budget", + "budget_duration", +] + list(StandardCallbackDynamicParams.__annotations__.keys()) + + class KeyGenerationConfig(TypedDict, total=False): required_params: List[ str diff --git a/litellm/utils.py b/litellm/utils.py index 7fbd586d9812..3038598f0b30 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2442,6 +2442,23 @@ def _remove_strict_from_schema(schema): return schema +def _remove_unsupported_params( + non_default_params: dict, supported_openai_params: Optional[List[str]] +) -> dict: + """ + Remove unsupported params from non_default_params + """ + remove_keys = [] + if supported_openai_params is None: + return {} # no supported params, so no optional openai params to send + for param in non_default_params.keys(): + if param not in supported_openai_params: + remove_keys.append(param) + for key in remove_keys: + non_default_params.pop(key, None) + return non_default_params + + def get_optional_params( # noqa: PLR0915 # use the openai defaults # https://platform.openai.com/docs/api-reference/chat/create @@ -2688,11 +2705,13 @@ def _check_valid_arg(supported_params): # Always keeps this in elif code blocks else: 
unsupported_params[k] = non_default_params[k] + if unsupported_params: if litellm.drop_params is True or ( drop_params is not None and drop_params is True ): - pass + for k in unsupported_params.keys(): + non_default_params.pop(k, None) else: raise UnsupportedParamsError( status_code=500, diff --git a/tests/llm_translation/test_bedrock_completion.py b/tests/llm_translation/test_bedrock_completion.py index 5aa23e2dfa23..5efb280746f3 100644 --- a/tests/llm_translation/test_bedrock_completion.py +++ b/tests/llm_translation/test_bedrock_completion.py @@ -2224,13 +2224,19 @@ def mock_transform(*args, **kwargs): captured_data = result return result - with patch('litellm.AmazonConverseConfig._transform_request', side_effect=mock_transform): + with patch( + "litellm.AmazonConverseConfig._transform_request", side_effect=mock_transform + ): litellm.completion(**data) # Assert that additionalRequestParameters exists and contains topK - assert 'additionalModelRequestFields' in captured_data - assert 'inferenceConfig' in captured_data['additionalModelRequestFields'] - assert captured_data['additionalModelRequestFields']['inferenceConfig']['topK'] == 10 + assert "additionalModelRequestFields" in captured_data + assert "inferenceConfig" in captured_data["additionalModelRequestFields"] + assert ( + captured_data["additionalModelRequestFields"]["inferenceConfig"]["topK"] + == 10 + ) + def test_bedrock_empty_content_real_call(): completion( @@ -2261,3 +2267,61 @@ def test_bedrock_process_empty_text_blocks(): } modified_message = process_empty_text_blocks(**message) assert modified_message["content"][0]["text"] == "Please continue." + + +def test_nova_optional_params_tool_choice(): + litellm.drop_params = True + litellm.set_verbose = True + litellm.completion( + messages=[ + {"role": "user", "content": "A WWII competitive game for 4-8 players"} + ], + model="bedrock/us.amazon.nova-pro-v1:0", + temperature=0.3, + tools=[ + { + "type": "function", + "function": { + "name": "GameDefinition", + "description": "Correctly extracted `GameDefinition` with all the required parameters with correct types", + "parameters": { + "$defs": { + "TurnDurationEnum": { + "enum": ["action", "encounter", "battle", "operation"], + "title": "TurnDurationEnum", + "type": "string", + } + }, + "properties": { + "id": { + "anyOf": [{"type": "integer"}, {"type": "null"}], + "default": None, + "title": "Id", + }, + "prompt": {"title": "Prompt", "type": "string"}, + "name": {"title": "Name", "type": "string"}, + "description": {"title": "Description", "type": "string"}, + "competitve": {"title": "Competitve", "type": "boolean"}, + "players_min": {"title": "Players Min", "type": "integer"}, + "players_max": {"title": "Players Max", "type": "integer"}, + "turn_duration": { + "$ref": "#/$defs/TurnDurationEnum", + "description": "how long the passing of a turn should represent for a game at this scale", + }, + }, + "required": [ + "competitve", + "description", + "name", + "players_max", + "players_min", + "prompt", + "turn_duration", + ], + "type": "object", + }, + }, + } + ], + tool_choice={"type": "function", "function": {"name": "GameDefinition"}}, + ) diff --git a/tests/local_testing/test_amazing_vertex_completion.py b/tests/local_testing/test_amazing_vertex_completion.py index 8c01a6889db3..bb7b38addc91 100644 --- a/tests/local_testing/test_amazing_vertex_completion.py +++ b/tests/local_testing/test_amazing_vertex_completion.py @@ -60,7 +60,7 @@ "gemini-1.5-flash-exp-0827", "gemini-pro-flash", "gemini-1.5-flash-exp-0827", - 
"gemini-2.0-flash-exp" + "gemini-2.0-flash-exp", ] @@ -308,7 +308,7 @@ async def test_vertex_ai_anthropic_async(): # ) @pytest.mark.asyncio @pytest.mark.flaky(retries=3, delay=1) -async def test_vertex_ai_anthropic_async_streaming(): +async def test_aaavertex_ai_anthropic_async_streaming(): # load_vertex_ai_credentials() try: litellm.set_verbose = True diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py index b7d7473885fb..5902047b70ca 100644 --- a/tests/local_testing/test_completion.py +++ b/tests/local_testing/test_completion.py @@ -4513,3 +4513,23 @@ def test_openai_hallucinated_tool_call_util(function_name, expect_modification): else: assert len(response) == 1 assert response[0].function.name == function_name + + +def test_langfuse_completion(monkeypatch): + monkeypatch.setenv( + "LANGFUSE_PUBLIC_KEY", "pk-lf-b3db7e8e-c2f6-4fc7-825c-a541a8fbe003" + ) + monkeypatch.setenv( + "LANGFUSE_SECRET_KEY", "sk-lf-b11ef3a8-361c-4445-9652-12318b8596e4" + ) + monkeypatch.setenv("LANGFUSE_HOST", "https://us.cloud.langfuse.com") + litellm.set_verbose = True + resp = litellm.completion( + model="langfuse/gpt-3.5-turbo", + langfuse_public_key=os.getenv("LANGFUSE_PUBLIC_KEY"), + langfuse_secret_key=os.getenv("LANGFUSE_SECRET_KEY"), + langfuse_host="https://us.cloud.langfuse.com", + prompt_id="test-chat-prompt", + prompt_variables={"user_message": "this is used"}, + messages=[{"role": "user", "content": "this is ignored"}], + ) diff --git a/tests/local_testing/test_cost_calc.py b/tests/local_testing/test_cost_calc.py index 1831c2a45e0a..1cfc7609109a 100644 --- a/tests/local_testing/test_cost_calc.py +++ b/tests/local_testing/test_cost_calc.py @@ -92,7 +92,7 @@ def test_run(model: str): print("Non-stream cost : NONE") print(f"Non-stream cost : {completion_cost(response) * 100:.4f} (response)") - response = router.completion(**kwargs, stream=True) # type: ignore + response = router.completion(**kwargs, stream=True, stream_options={"include_usage": True}) # type: ignore response = stream_chunk_builder(list(response), messages=kwargs["messages"]) # type: ignore output = response.choices[0].message.content.replace("\n", "") # type: ignore streaming_cost_calc = completion_cost(response) * 100