diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py
index d9b72f2ff..b5a96485a 100644
--- a/dspy/adapters/base.py
+++ b/dspy/adapters/base.py
@@ -2,6 +2,7 @@
 
 from dspy.utils.callback import with_callbacks
 
+
 class Adapter(ABC):
     def __init__(self, callbacks=None):
         self.callbacks = callbacks or []
@@ -22,19 +23,26 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs):
 
         try:
             for output in outputs:
-                output_logprobs = None
-                if isinstance(output, dict):
-                    output, output_logprobs = output["text"], output["logprobs"]
-
-                value = self.parse(signature, output)
-
-                assert set(value.keys()) == set(signature.output_fields.keys()), \
-                    f"Expected {signature.output_fields.keys()} but got {value.keys()}"
-
-                if output_logprobs is not None:
-                    value["logprobs"] = output_logprobs
-
+                if isinstance(output, dict):
+                    output_text = output["text"]
+                else:
+                    output_text = output
+
+                if output_text:
+                    # The output text, e.g., response["choices"][0]["message"]["content"], can be None when
+                    # tool calls are used.
+                    value = self.parse(signature, output_text)
+                    if not set(value.keys()) == set(signature.output_fields.keys()):
+                        raise ValueError(f"Expected {signature.output_fields.keys()} but got {value.keys()}")
+                else:
+                    value = {}
+
+                if isinstance(output, dict) and "logprobs" in output:
+                    value["logprobs"] = output["logprobs"]
+
+                if isinstance(output, dict) and "tool_calls" in output:
+                    value["tool_calls"] = output["tool_calls"]
+
                 values.append(value)
 
             return values
diff --git a/dspy/adapters/json_adapter.py b/dspy/adapters/json_adapter.py
index 7eed644f9..fcc90b85a 100644
--- a/dspy/adapters/json_adapter.py
+++ b/dspy/adapters/json_adapter.py
@@ -58,10 +58,21 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs):
         values = []
 
         for output in outputs:
-            value = self.parse(signature, output)
-            assert set(value.keys()) == set(
-                signature.output_fields.keys()
-            ), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
+            if isinstance(output, dict):
+                output_text = output["text"]
+            else:
+                output_text = output
+
+            if output_text:
+                # The output text, e.g., response["choices"][0]["message"]["content"], can be None when tool calls are used.
+                value = self.parse(signature, output_text)
+                if not set(value.keys()) == set(signature.output_fields.keys()):
+                    raise ValueError(f"Expected {signature.output_fields.keys()} but got {value.keys()}")
+            else:
+                value = {}
+
+            if isinstance(output, dict) and "tool_calls" in output:
+                value["tool_calls"] = output["tool_calls"]
             values.append(value)
 
         return values
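Reviewer note: with this change both adapters accept three output shapes from the LM: a plain string, a dict with `text` and `logprobs`, or a dict with `text=None` and `tool_calls`. A minimal, self-contained sketch of that contract, using a toy `parse()` helper and illustrative field names rather than the real adapter internals:

```python
def parse(signature_fields, text):
    # Toy parser: one "field: value" pair per line.
    return dict(line.split(": ", 1) for line in text.splitlines())

signature_fields = {"answer"}

outputs = [
    "answer: Paris",  # plain string (backward-compatible path)
    {"text": "answer: Paris", "logprobs": {"Paris": -0.02}},  # logprobs requested
    {"text": None, "tool_calls": [{"function": {"name": "get_weather"}}]},  # tool call, no text
]

values = []
for output in outputs:
    output_text = output["text"] if isinstance(output, dict) else output
    if output_text:
        value = parse(signature_fields, output_text)
        if set(value) != signature_fields:
            raise ValueError(f"Expected {signature_fields} but got {set(value)}")
    else:
        value = {}  # nothing to parse when the model only emitted tool calls
    if isinstance(output, dict) and "logprobs" in output:
        value["logprobs"] = output["logprobs"]
    if isinstance(output, dict) and "tool_calls" in output:
        value["tool_calls"] = output["tool_calls"]
    values.append(value)

print(values)
# [{'answer': 'Paris'},
#  {'answer': 'Paris', 'logprobs': {'Paris': -0.02}},
#  {'tool_calls': [{'function': {'name': 'get_weather'}}]}]
```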
diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index e4e5bbcee..4154bf751 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -86,7 +86,13 @@ def __init__(
         ), "OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`"
 
     @with_callbacks
-    def __call__(self, prompt=None, messages=None, **kwargs):
+    def __call__(self, prompt=None, messages=None, tools=None, **kwargs):
+        if tools is not None and not litellm.supports_function_calling(self.model):
+            raise ValueError(
+                f"The model {self.model} does not support function calling, but tools were provided. Please use a "
+                "model that supports function calling."
+            )
+
         # Build the request.
         cache = kwargs.pop("cache", self.cache)
         messages = messages or [{"role": "user", "content": prompt}]
@@ -99,19 +105,33 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         completion = cached_litellm_text_completion if cache else litellm_text_completion
 
         response = completion(
-            request=dict(model=self.model, messages=messages, **kwargs),
+            request=dict(model=self.model, messages=messages, tools=tools, **kwargs),
             num_retries=self.num_retries,
         )
+
+        output_text = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
+        output_logprobs = None
+        output_tool_calls = None
         if kwargs.get("logprobs"):
-            outputs = [
-                {
-                    "text": c.message.content if hasattr(c, "message") else c["text"],
-                    "logprobs": c.logprobs if hasattr(c, "logprobs") else c["logprobs"],
-                }
-                for c in response["choices"]
-            ]
+            output_logprobs = [c.logprobs if hasattr(c, "logprobs") else c["logprobs"] for c in response["choices"]]
+        if tools:
+            output_tool_calls = [
+                c.message.tool_calls if hasattr(c, "message") else c["tool_calls"] for c in response["choices"]
+            ]
+
+        outputs = []
+        if output_logprobs is None and output_tool_calls is None:
+            # If neither logprobs nor tool_calls are present, return the outputs as a list of plain strings
+            # rather than dicts, for backward compatibility.
+            outputs = output_text
         else:
-            outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
+            for i, text in enumerate(output_text):
+                output = {"text": text}
+                if kwargs.get("logprobs"):
+                    output["logprobs"] = output_logprobs[i]
+                if tools:
+                    output["tool_calls"] = output_tool_calls[i]
+                outputs.append(output)
 
         # Logging, with removed api key & where `cost` is None on cache hit.
         kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")}
diff --git a/dspy/predict/predict.py b/dspy/predict/predict.py
index 0e15137b3..97dfacd78 100644
--- a/dspy/predict/predict.py
+++ b/dspy/predict/predict.py
@@ -106,7 +106,13 @@ def forward(self, **kwargs):
             missing = [k for k in signature.input_fields if k not in kwargs]
             print(f"WARNING: Not all input fields were provided to module. Present: {present}. Missing: {missing}.")
 
+        if "tools" in kwargs:
+            config["tools"] = kwargs.pop("tools")
+        if "tool_choice" in kwargs:
+            # Pop tool_choice as well so it reaches the LM call instead of leaking into the signature inputs.
+            config["tool_choice"] = kwargs.pop("tool_choice")
+
         import dspy
 
         adapter = dspy.settings.adapter or dspy.ChatAdapter()
         completions = adapter(lm, lm_kwargs=config, signature=signature, demos=demos, inputs=kwargs)
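Reviewer note: after the `lm.py` change, outputs switch from plain strings to dicts whenever `tools` or `logprobs` are in play. A sketch of the expected call, assuming `OPENAI_API_KEY` is set and a function-calling-capable model (the tool schema mirrors the one used in the tests below):

```python
import dspy

# Sketch only: assumes OPENAI_API_KEY is set and the model supports function calling.
lm = dspy.LM(model="openai/gpt-4o-mini")
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
            },
        },
    }
]

outputs = lm("What's the weather in Paris?", tools=tools)
# With tools (or logprobs), each element is a dict such as
# {"text": None, "tool_calls": [...]}; without them it stays a plain string.
print(outputs[0]["tool_calls"])
```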
+ """ + super().__init__(signature, callbacks, **configs) + self.tools = tools + self.tool_choice = tool_choice + + def __call__(self, tools=None, tool_choice="auto", **kwargs): + """Call the PredictWithTools class. + + This method performs prediction using the configured language model, with the ability + to invoke tools if specified. It can either return tool calling instructions or + direct prediction results. + + Args: + tools (list, optional): List of available tools for the LLM to choose from. + If provided, overrides the `tools` set during initialization. + tool_choice (str, optional): Strategy for tool selection, defaults to "auto". + If provided, overrides the `tool_choice` set during initialization. + **kwargs: Additional arguments passed to the underlying Predict class. + + Returns: + dspy.Prediction: One of two types: + - Tool calling scenario: Contains a single field 'tool_calls' with a list + of tool invocations and their arguments. + - Direct prediction: Contains fields as specified by self.signature. + """ + kwargs["tools"] = tools or self.tools + kwargs["tool_choice"] = tool_choice or self.tool_choice + + return super().__call__(**kwargs) diff --git a/tests/clients/test_lm.py b/tests/clients/test_lm.py index 77a942727..8b2c96803 100644 --- a/tests/clients/test_lm.py +++ b/tests/clients/test_lm.py @@ -3,6 +3,7 @@ import litellm import pydantic import pytest +import os import dspy from tests.test_utils.server import litellm_test_server, read_litellm_test_server_request_logs @@ -156,3 +157,36 @@ def test_lm_text_calls_are_retried_for_expected_failures( request_logs = read_litellm_test_server_request_logs(server_log_file_path) assert len(request_logs) == expected_num_retries + 1 # 1 initial request + 1 retries + + +def test_tools_rejected_for_non_function_models(litellm_test_server): + api_base, server_log_file_path = litellm_test_server + + with mock.patch("dspy.clients.lm.litellm.supports_function_calling", return_value=False): + lm = dspy.LM( + model="openai/dspy-test-model", + api_base=api_base, + api_key="fakekey", + model_type="chat", + ) + with pytest.raises(ValueError): + lm("query", tools=[{"type": "function", "function": {}}]) + + +@pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OpenAI API key is not set") +def test_lm_tool_calls_are_returned(): + openai_lm = dspy.LM(model="openai/gpt-4o-mini") + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + }, + } + ] + outputs = openai_lm("what's the weather in Paris?", tools=tools) + assert "tool_calls" in outputs[0] diff --git a/tests/predict/test_predict.py b/tests/predict/test_predict.py index 828507db9..c1bf9f68a 100644 --- a/tests/predict/test_predict.py +++ b/tests/predict/test_predict.py @@ -9,6 +9,7 @@ import dspy from dspy import Predict, Signature from dspy.utils.dummies import DummyLM +import os def test_initialization_with_string_signature(): @@ -344,3 +345,22 @@ def test_load_state_chaining(): new_instance = Predict("question -> answer").load_state(state) assert new_instance is not None assert new_instance.demos == original.demos + +@pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OpenAI API key is not set") +def test_predict_tool_calls_are_returned(): + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + }, + } + ] + with 
diff --git a/tests/predict/test_predict.py b/tests/predict/test_predict.py
index 828507db9..c1bf9f68a 100644
--- a/tests/predict/test_predict.py
+++ b/tests/predict/test_predict.py
@@ -9,6 +9,7 @@
 import dspy
 from dspy import Predict, Signature
 from dspy.utils.dummies import DummyLM
+import os
 
 
 def test_initialization_with_string_signature():
@@ -344,3 +345,23 @@ def test_load_state_chaining():
     new_instance = Predict("question -> answer").load_state(state)
     assert new_instance is not None
     assert new_instance.demos == original.demos
+
+
+@pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OpenAI API key is not set")
+def test_predict_tool_calls_are_returned():
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_weather",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"location": {"type": "string"}},
+                },
+            },
+        }
+    ]
+    with dspy.context(lm=dspy.LM("openai/gpt-4o-mini")):
+        predict = Predict("question -> answer")
+        outputs = predict(question="what's the weather in Paris?", tools=tools)
+        assert "tool_calls" in outputs
diff --git a/tests/predict/test_predict_with_tools.py b/tests/predict/test_predict_with_tools.py
new file mode 100644
index 000000000..73885e617
--- /dev/null
+++ b/tests/predict/test_predict_with_tools.py
@@ -0,0 +1,26 @@
+import os
+
+import pytest
+
+import dspy
+from dspy.predict.predict_with_tools import PredictWithTools
+
+
+@pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OpenAI API key is not set")
+def test_basic_predict_with_tools():
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_weather",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"location": {"type": "string"}},
+                },
+            },
+        }
+    ]
+    with dspy.context(lm=dspy.LM("openai/gpt-4o-mini")):
+        predict = PredictWithTools("question -> answer", tools=tools)
+        outputs = predict(question="what's the weather in Paris?", tools=tools)
+        assert "tool_calls" in outputs
\ No newline at end of file
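Reviewer note: executing the returned tool calls is left to the caller. A hypothetical dispatch step, with a stand-in `get_weather` implementation and an OpenAI-style `tool_calls` payload (function arguments arrive as a JSON-encoded string):

```python
import json

# Hypothetical dispatch step. The tool_calls payload below mimics the OpenAI
# response shape; get_weather and its result handling are illustrative only.
def get_weather(location: str) -> str:
    return f"Sunny in {location}"

available_tools = {"get_weather": get_weather}

tool_calls = [
    {
        "id": "call_1",
        "type": "function",
        "function": {"name": "get_weather", "arguments": '{"location": "Paris"}'},
    }
]

for call in tool_calls:
    fn = call["function"]
    args = json.loads(fn["arguments"] or "{}")  # arguments arrive as a JSON string
    print(available_tools[fn["name"]](**args))  # -> Sunny in Paris
```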