From 6a605242687dc9a2399cc7064ce0f9a031eac30e Mon Sep 17 00:00:00 2001
From: chenmoneygithub
Date: Wed, 8 Jan 2025 09:56:07 +0800
Subject: [PATCH] Add tool calling for dspy.LM

---
 dspy/clients/lm.py       | 12 ++++++++++--
 tests/clients/test_lm.py | 32 ++++++++++++++++++++++++++++++++
 2 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index e4e5bbcee..b2272b832 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -86,7 +86,13 @@ def __init__(
         ), "OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`"
 
     @with_callbacks
-    def __call__(self, prompt=None, messages=None, **kwargs):
+    def __call__(self, prompt=None, messages=None, tools=None, **kwargs):
+        if tools is not None and not litellm.supports_function_calling(self.model):
+            raise ValueError(
+                f"The model {self.model} does not support function calling, but tools were provided. Please use a "
+                "model that supports function calling."
+            )
+
         # Build the request.
         cache = kwargs.pop("cache", self.cache)
         messages = messages or [{"role": "user", "content": prompt}]
@@ -99,7 +105,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         completion = cached_litellm_text_completion if cache else litellm_text_completion
 
         response = completion(
-            request=dict(model=self.model, messages=messages, **kwargs),
+            request=dict(model=self.model, messages=messages, tools=tools, **kwargs),
             num_retries=self.num_retries,
         )
         if kwargs.get("logprobs"):
@@ -110,6 +116,8 @@ def __call__(self, prompt=None, messages=None, **kwargs):
                 }
                 for c in response["choices"]
             ]
+        elif tools:
+            outputs = [{"text": c.message.content, "tool_calls": c.message.tool_calls} for c in response["choices"]]
         else:
             outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
 
diff --git a/tests/clients/test_lm.py b/tests/clients/test_lm.py
index 77a942727..1564bbf8c 100644
--- a/tests/clients/test_lm.py
+++ b/tests/clients/test_lm.py
@@ -156,3 +156,35 @@ def test_lm_text_calls_are_retried_for_expected_failures(
 
     request_logs = read_litellm_test_server_request_logs(server_log_file_path)
     assert len(request_logs) == expected_num_retries + 1  # 1 initial request + 1 retries
+
+
+def test_tools_rejected_for_non_function_models():
+    with pytest.raises(ValueError):
+        lm = dspy.LM(model="palm/chat-bison")
+        lm("query", tools=[{"type": "function", "function": {}}])
+
+
+def test_lm_tool_calls_are_returned(litellm_test_server):
+    api_base, server_log_file_path = litellm_test_server
+
+    openai_lm = dspy.LM(
+        model="openai/dspy-test-model",
+        api_base=api_base,
+        api_key="fakekey",
+        model_type="chat",
+        cache=False,
+    )
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_weather",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"location": {"type": "string"}},
+                },
+            },
+        }
+    ]
+    outputs = openai_lm("Query", tools=tools)
+    assert "tool_calls" in outputs[0]
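
Usage sketch: the snippet below exercises the new `tools` argument end to
end. The model name and prompt are placeholders for illustration, not part
of the patch; any litellm model for which litellm.supports_function_calling()
returns True should behave the same way, and the tool schema mirrors the
OpenAI function-calling format used in the test above.

    import dspy

    # Placeholder model chosen for illustration; swap in any model that
    # litellm reports as supporting function calling, otherwise __call__
    # raises the ValueError added in this patch.
    lm = dspy.LM(model="openai/gpt-4o", model_type="chat", cache=False)

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                },
            },
        }
    ]

    outputs = lm("What is the weather in Tokyo?", tools=tools)

    # With tools supplied, the new elif branch returns dicts rather than
    # bare strings: each entry carries the generated text plus any tool
    # calls the model proposed (tool_calls may be None if the model chose
    # to answer directly in text).
    print(outputs[0]["text"])
    print(outputs[0]["tool_calls"])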