Add tools calling for dspy.LM
chenmoneygithub committed Jan 10, 2025
1 parent fea2d38 commit 6a60524
Showing 2 changed files with 42 additions and 2 deletions.
12 changes: 10 additions & 2 deletions dspy/clients/lm.py
@@ -86,7 +86,13 @@ def __init__(
             ), "OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`"
 
     @with_callbacks
-    def __call__(self, prompt=None, messages=None, **kwargs):
+    def __call__(self, prompt=None, messages=None, tools=None, **kwargs):
+        if tools is not None and not litellm.supports_function_calling(self.model):
+            raise ValueError(
+                f"The model {self.model} does not support function calling, but tools were provided. Please use a "
+                "model that supports function calling."
+            )
+
         # Build the request.
         cache = kwargs.pop("cache", self.cache)
         messages = messages or [{"role": "user", "content": prompt}]
@@ -99,7 +105,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         completion = cached_litellm_text_completion if cache else litellm_text_completion
 
         response = completion(
-            request=dict(model=self.model, messages=messages, **kwargs),
+            request=dict(model=self.model, messages=messages, tools=tools, **kwargs),
             num_retries=self.num_retries,
         )
         if kwargs.get("logprobs"):
@@ -110,6 +116,8 @@ def __call__(self, prompt=None, messages=None, **kwargs):
                 }
                 for c in response["choices"]
             ]
+        elif tools:
+            outputs = [{"text": c.message.content, "tool_calls": c.message.tool_calls} for c in response["choices"]]
         else:
             outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]

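With these changes, an OpenAI-style tools list passed to dspy.LM's __call__ is forwarded through LiteLLM, and any tool calls the model makes come back next to the generated text. The following is a minimal usage sketch, not part of the commit: the model name, API key handling, and the get_weather schema are illustrative assumptions, and it requires a model that LiteLLM reports as supporting function calling.

import dspy

# Minimal sketch (not from this commit). Assumes OPENAI_API_KEY is set and
# that the chosen model supports function calling per litellm.
lm = dspy.LM(model="openai/gpt-4o-mini", model_type="chat", cache=False)

# OpenAI-style tool schema; "get_weather" is a hypothetical tool.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
            },
        },
    }
]

outputs = lm("What is the weather in Tokyo?", tools=tools)
# When tools are passed, each output is a dict with "text" and "tool_calls".
print(outputs[0]["text"])
print(outputs[0]["tool_calls"])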
32 changes: 32 additions & 0 deletions tests/clients/test_lm.py
@@ -156,3 +156,35 @@ def test_lm_text_calls_are_retried_for_expected_failures(
 
     request_logs = read_litellm_test_server_request_logs(server_log_file_path)
     assert len(request_logs) == expected_num_retries + 1 # 1 initial request + 1 retries
+
+
+def test_tools_rejected_for_non_function_models():
+    with pytest.raises(ValueError):
+        lm = dspy.LM(model="palm/chat-bison")
+        lm("query", tools=[{"type": "function", "function": {}}])
+
+
+def test_lm_tool_calls_are_returned(litellm_test_server):
+    api_base, server_log_file_path = litellm_test_server
+
+    openai_lm = dspy.LM(
+        model="openai/dspy-test-model",
+        api_base=api_base,
+        api_key="fakekey",
+        model_type="chat",
+        cache=False,
+    )
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_weather",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"location": {"type": "string"}},
+                },
+            },
+        }
+    ]
+    outputs = openai_lm("Query", tools=tools)
+    assert "tool_calls" in outputs[0]

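The tests stop at asserting that a tool_calls entry is returned; executing the call remains the caller's job. Below is a hedged sketch of one way to dispatch the returned entries, assuming they follow LiteLLM's OpenAI-compatible shape (a function attribute carrying a name and a JSON-encoded arguments string); that shape is an assumption about LiteLLM's response objects, not something this commit defines.

import json

# Hypothetical local implementation matching the get_weather schema used above.
def get_weather(location: str) -> str:
    return f"It is sunny in {location}."

AVAILABLE_TOOLS = {"get_weather": get_weather}

def run_tool_calls(output):
    """Dispatch each returned tool call to a local Python function."""
    results = []
    for call in output.get("tool_calls") or []:
        # Assumed OpenAI-compatible fields on each entry: function.name and
        # function.arguments (a JSON string).
        name = call.function.name
        kwargs = json.loads(call.function.arguments or "{}")
        results.append(AVAILABLE_TOOLS[name](**kwargs))
    return results

# Example: run_tool_calls(outputs[0]) after calling lm("...", tools=tools).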