Add tool calling for dspy.LM #2023

Closed
28 changes: 18 additions & 10 deletions dspy/adapters/base.py
@@ -2,6 +2,7 @@
 
 from dspy.utils.callback import with_callbacks
 
+
 class Adapter(ABC):
     def __init__(self, callbacks=None):
         self.callbacks = callbacks or []
@@ -22,19 +23,26 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs):
 
         try:
             for output in outputs:
-                output_logprobs = None
-
                 if isinstance(output, dict):
-                    output, output_logprobs = output["text"], output["logprobs"]
+                    output_text = output["text"]
+                else:
+                    output_text = output
 
-                value = self.parse(signature, output)
+                if output_text:
+                    # Output text, e.g., response["choices"][0]["message"]["content"], can be None when
+                    # tool calls are used.
+                    value = self.parse(signature, output_text)
+                    if not set(value.keys()) == set(signature.output_fields.keys()):
+                        raise ValueError(f"Expected {signature.output_fields.keys()} but got {value.keys()}")
+                else:
+                    value = {}
 
-                assert set(value.keys()) == set(signature.output_fields.keys()), \
-                    f"Expected {signature.output_fields.keys()} but got {value.keys()}"
+                if isinstance(output, dict) and "logprobs" in output:
+                    value["logprobs"] = output["logprobs"]
 
-                if output_logprobs is not None:
-                    value["logprobs"] = output_logprobs
+                if isinstance(output, dict) and "tool_calls" in output:
+                    value["tool_calls"] = output["tool_calls"]
 
                 values.append(value)
 
         return values
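To make the adapter's new contract concrete, a small sketch of how one output entry is normalized (illustrative values only; assumes a "question -> answer" signature):

# Plain-text completion: self.parse fills the signature's output fields.
output = "...model text..."
value = {"answer": "Paris"}  # parsed from the text

# Tool-calling completion: output["text"] can be None, so parsing is skipped
# and the raw tool calls are passed through on the prediction instead.
output = {"text": None, "tool_calls": [{"function": {"name": "get_weather"}}]}
value = {"tool_calls": output["tool_calls"]}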
19 changes: 15 additions & 4 deletions dspy/adapters/json_adapter.py
@@ -58,10 +58,21 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs):
         values = []
 
         for output in outputs:
-            value = self.parse(signature, output)
-            assert set(value.keys()) == set(
-                signature.output_fields.keys()
-            ), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
+            if isinstance(output, dict):
+                output_text = output["text"]
+            else:
+                output_text = output
+
+            if output_text:
+                # Output text, e.g., response["choices"][0]["message"]["content"], can be None when tool calls are used.
+                value = self.parse(signature, output_text)
+                if not set(value.keys()) == set(signature.output_fields.keys()):
+                    raise ValueError(f"Expected {signature.output_fields.keys()} but got {value.keys()}")
+            else:
+                value = {}
+
+            if isinstance(output, dict) and "tool_calls" in output:
+                value["tool_calls"] = output["tool_calls"]
             values.append(value)
 
         return values
38 changes: 29 additions & 9 deletions dspy/clients/lm.py
@@ -86,7 +86,13 @@ def __init__(
         ), "OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`"
 
     @with_callbacks
-    def __call__(self, prompt=None, messages=None, **kwargs):
+    def __call__(self, prompt=None, messages=None, tools=None, **kwargs):
+        if tools is not None and not litellm.supports_function_calling(self.model):
Collaborator:

Thanks @chenmoneygithub! Right now, we have a general solution that works well.

I'm not sure we should replace it with a special solution that doesn't work for many model providers.

Collaborator (Author):

Yes, that's a good point. Yuki from the MLflow team reached out asking whether we can use the standard tool calling supported by litellm so that tool calls can be traced. That makes sense: right now we mix the tool calls into message.content, and with that approach it's hard to trace them.

My plan is this:

  • For OpenAI/Anthropic/Databricks providers, which have standard function calling, we use the standard protocol, which is identical across these providers.
  • For other models, e.g., locally hosted models, we keep our current logic.

I see two downsides:

  1. When we change the logic in dspy.ReAct, there might be a performance change (drop or increase, I'm not sure), so we need to be cautious.
  2. dspy.ReAct will have two branches, one for LLMs that support tool calling and one for LLMs that don't, so the code will be slightly more complex: https://github.com/stanfordnlp/dspy/blob/main/dspy/predict/react.py#L85-L96.

And two benefits:

  1. We can enable clean tool-call tracing, which is useful for debugging complex agents.
  2. I expect tool calling to be more robust with OpenAI/Anthropic if we use their protocol.

Please let me know your thoughts!
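For illustration only, a minimal sketch of the two-branch dispatch described in this thread (hypothetical helper, not part of this PR; it assumes only litellm.supports_function_calling, which is used elsewhere in this diff):

import litellm

def dispatch_tool_calling(lm, messages, tools):
    # Hypothetical dispatch: prefer the provider's native function calling
    # (OpenAI/Anthropic/Databricks) when litellm reports support for it.
    if litellm.supports_function_calling(lm.model):
        return lm(messages=messages, tools=tools)
    # Fallback for models without native support: keep the current text-based
    # logic, where tool descriptions are rendered into the prompt itself.
    return lm(messages=messages)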

+            raise ValueError(
+                f"The model {self.model} does not support function calling, but tools were provided. Please use a "
+                "model that supports function calling."
+            )

         # Build the request.
         cache = kwargs.pop("cache", self.cache)
         messages = messages or [{"role": "user", "content": prompt}]
@@ -99,19 +105,33 @@ def __call__(self, prompt=None, messages=None, **kwargs):
             completion = cached_litellm_text_completion if cache else litellm_text_completion
 
         response = completion(
-            request=dict(model=self.model, messages=messages, **kwargs),
+            request=dict(model=self.model, messages=messages, tools=tools, **kwargs),
             num_retries=self.num_retries,
         )

-        if kwargs.get("logprobs"):
-            outputs = [
-                {
-                    "text": c.message.content if hasattr(c, "message") else c["text"],
-                    "logprobs": c.logprobs if hasattr(c, "logprobs") else c["logprobs"],
-                }
-                for c in response["choices"]
-            ]
-        else:
-            outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
+        output_text = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
+        output_logprobs = None
+        output_tool_calls = None
+        if kwargs.get("logprobs"):
+            output_logprobs = [c.logprobs if hasattr(c, "logprobs") else c["logprobs"] for c in response["choices"]]
+        if tools:
+            output_tool_calls = [
+                c.message.tool_calls if hasattr(c, "message") else c["tool_calls"] for c in response["choices"]
+            ]
+
+        outputs = []
+        if output_logprobs is None and output_tool_calls is None:
+            # If neither logprobs nor tool_calls are requested, return the text only, as a list of strings
+            # instead of a list of dicts, for backward compatibility.
+            outputs = output_text
+        else:
+            for i, text in enumerate(output_text):
+                output = {"text": text}
+                if kwargs.get("logprobs"):
+                    output["logprobs"] = output_logprobs[i]
+                if tools:
+                    output["tool_calls"] = output_tool_calls[i]
+                outputs.append(output)
 
         # Logging, with removed api key & where `cost` is None on cache hit.
         kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")}
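To make the new return shape concrete, a hedged usage sketch (assumes OPENAI_API_KEY is set; the get_weather schema is illustrative and mirrors the tests added below):

import dspy

lm = dspy.LM(model="openai/gpt-4o-mini")
tools = [{"type": "function", "function": {"name": "get_weather",
          "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}}}]

outputs = lm("What's the weather in Paris?", tools=tools)
# With tools (or logprobs) in play, each entry is a dict rather than a plain
# string, e.g. {"text": None, "tool_calls": [...]}.
print(outputs[0]["tool_calls"])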
3 changes: 3 additions & 0 deletions dspy/predict/predict.py
@@ -106,6 +106,9 @@ def forward(self, **kwargs):
         missing = [k for k in signature.input_fields if k not in kwargs]
         print(f"WARNING: Not all input fields were provided to module. Present: {present}. Missing: {missing}.")
 
+        if "tools" in kwargs:
+            config["tools"] = kwargs.pop("tools")
+
         import dspy
         adapter = dspy.settings.adapter or dspy.ChatAdapter()
         completions = adapter(lm, lm_kwargs=config, signature=signature, demos=demos, inputs=kwargs)
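In effect, a caller can now pass `tools` to a Predict call and have it forwarded to the LM through its config rather than treated as a signature input. A minimal sketch (assumes an OpenAI key; mirrors the test added below):

import dspy

tools = [{"type": "function", "function": {"name": "get_weather",
          "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}}}]

with dspy.context(lm=dspy.LM("openai/gpt-4o-mini")):
    predict = dspy.Predict("question -> answer")
    # "tools" is popped from kwargs above and sent to the LM as part of config.
    outputs = predict(question="What's the weather in Paris?", tools=tools)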
48 changes: 48 additions & 0 deletions dspy/predict/predict_with_tools.py
@@ -0,0 +1,48 @@
from dspy.predict.predict import Predict


class PredictWithTools(Predict):
    """Predict with tool-calling support.

    This class is a thin wrapper around the Predict class that adds explicit tool-calling support.
    """

    def __init__(self, signature, callbacks=None, tools=None, tool_choice="auto", **configs):
        """Initialize the PredictWithTools class.

        Args:
            signature (str): The signature of `PredictWithTools`.
            callbacks (list): The callbacks that are called before and after the Predict call.
            tools (list): The list of tools to use. For more details, please refer to
                https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools
            tool_choice (str): The tool choice strategy, defaults to "auto". For more details, please refer to
                https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
            **configs: Additional configurations.
        """
        super().__init__(signature, callbacks, **configs)
        self.tools = tools
        self.tool_choice = tool_choice

    def __call__(self, tools=None, tool_choice=None, **kwargs):
        """Call the PredictWithTools class.

        This method performs prediction using the configured language model, with the ability
        to invoke tools if specified. It can either return tool-calling instructions or
        direct prediction results.

        Args:
            tools (list, optional): List of available tools for the LLM to choose from.
                If provided, overrides the `tools` set during initialization.
            tool_choice (str, optional): Strategy for tool selection.
                If provided, overrides the `tool_choice` set during initialization.
            **kwargs: Additional arguments passed to the underlying Predict class.

        Returns:
            dspy.Prediction: One of two types:
                - Tool-calling scenario: contains a single field `tool_calls` with a list
                  of tool invocations and their arguments.
                - Direct prediction: contains fields as specified by `self.signature`.
        """
        kwargs["tools"] = tools or self.tools
        # Fall back to the instance-level strategy unless a per-call override is given.
        kwargs["tool_choice"] = tool_choice or self.tool_choice

        return super().__call__(**kwargs)
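A hedged usage sketch for the new class (assumes OPENAI_API_KEY is set; the get_weather schema is illustrative and mirrors the test file below):

import dspy
from dspy.predict.predict_with_tools import PredictWithTools

tools = [{"type": "function", "function": {"name": "get_weather",
          "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}}}]

with dspy.context(lm=dspy.LM("openai/gpt-4o-mini")):
    predict = PredictWithTools("question -> answer", tools=tools)
    prediction = predict(question="What's the weather in Paris?")
    if "tool_calls" in prediction:
        # The LM chose to call a tool: execute it and feed the result back.
        print(prediction.tool_calls)
    else:
        print(prediction.answer)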
34 changes: 34 additions & 0 deletions tests/clients/test_lm.py
@@ -3,6 +3,7 @@
 import litellm
 import pydantic
 import pytest
+import os

import dspy
from tests.test_utils.server import litellm_test_server, read_litellm_test_server_request_logs
@@ -156,3 +157,36 @@ def test_lm_text_calls_are_retried_for_expected_failures(

    request_logs = read_litellm_test_server_request_logs(server_log_file_path)
    assert len(request_logs) == expected_num_retries + 1  # 1 initial request + retries


def test_tools_rejected_for_non_function_models(litellm_test_server):
    api_base, server_log_file_path = litellm_test_server

    with mock.patch("dspy.clients.lm.litellm.supports_function_calling", return_value=False):
        lm = dspy.LM(
            model="openai/dspy-test-model",
            api_base=api_base,
            api_key="fakekey",
            model_type="chat",
        )
        with pytest.raises(ValueError):
            lm("query", tools=[{"type": "function", "function": {}}])


@pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OpenAI API key is not set")
def test_lm_tool_calls_are_returned():
    openai_lm = dspy.LM(model="openai/gpt-4o-mini")
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                },
            },
        }
    ]
    outputs = openai_lm("what's the weather in Paris?", tools=tools)
    assert "tool_calls" in outputs[0]
20 changes: 20 additions & 0 deletions tests/predict/test_predict.py
@@ -9,6 +9,7 @@
 import dspy
 from dspy import Predict, Signature
 from dspy.utils.dummies import DummyLM
+import os


def test_initialization_with_string_signature():
@@ -344,3 +345,22 @@ def test_load_state_chaining():
    new_instance = Predict("question -> answer").load_state(state)
    assert new_instance is not None
    assert new_instance.demos == original.demos

@pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OpenAI API key is not set")
def test_predict_tool_calls_are_returned():
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                },
            },
        }
    ]
    with dspy.context(lm=dspy.LM("openai/gpt-4o-mini")):
        predict = Predict("question -> answer")
        outputs = predict(question="what's the weather in Paris?", tools=tools)
        assert "tool_calls" in outputs
26 changes: 26 additions & 0 deletions tests/predict/test_predict_with_tools.py
@@ -0,0 +1,26 @@
import os

import pytest

import dspy
from dspy.predict.predict_with_tools import PredictWithTools


@pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OpenAI API key is not set")
def test_basic_predict_with_tools():
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                },
            },
        }
    ]
    with dspy.context(lm=dspy.LM("openai/gpt-4o-mini")):
        predict = PredictWithTools("question -> answer", tools=tools)
        outputs = predict(question="what's the weather in Paris?", tools=tools)
        assert "tool_calls" in outputs