Commit ee0495c (1 parent: 6cba89a), showing 54 changed files with 3,082 additions and 51 deletions.
docs/docs/examples/workflow/function_calling_agent.ipynb (377 additions, 0 deletions)
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Workflow for a Function Calling Agent\n", | ||
"\n", | ||
"This notebook walks through setting up a `Workflow` to construct a function calling agent from scratch.\n", | ||
"\n", | ||
"Function calling agents work by using an LLM that supports tools/functions in its API (OpenAI, Ollama, Anthropic, etc.) to call functions an use tools.\n", | ||
"\n", | ||
"Our workflow will be stateful with memory, and will be able to call the LLM to select tools and process incoming user messages." | ||
] | ||
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install -U llama-index"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"sk-proj-...\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since workflows are async first, this all runs fine in a notebook. If you were running in your own code, you would want to use `asyncio.run()` to start an async event loop if one isn't already running.\n", | ||
"\n", | ||
"```python\n", | ||
"async def main():\n", | ||
" <async code>\n", | ||
"\n", | ||
"if __name__ == \"__main__\":\n", | ||
" import asyncio\n", | ||
" asyncio.run(main())\n", | ||
"```" | ||
] | ||
},
{
"cell_type": "markdown",
"metadata": {},
"source": [ | ||
"## Designing the Workflow\n", | ||
"\n", | ||
"An agent consists of several steps\n", | ||
"1. Handling the latest incoming user message, including adding to memory and getting the latest chat history\n", | ||
"2. Calling the LLM with tools + chat history\n", | ||
"3. Parsing out tool calls (if any)\n", | ||
"4. If there are tool calls, call them, and loop until there are none\n", | ||
"5. When there is no tool calls, return the LLM response\n", | ||
"\n", | ||
"### The Workflow Events\n", | ||
"\n", | ||
"To handle these steps, we need to define a few events:\n", | ||
"1. An event to handle new messages and prepare the chat history\n", | ||
"2. An event to trigger tool calls\n", | ||
"3. An event to handle the results of tool calls\n", | ||
"\n", | ||
"The other steps will use the built-in `StartEvent` and `StopEvent` events." | ||
] | ||
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from llama_index.core.llms import ChatMessage\n",
"from llama_index.core.tools import ToolSelection, ToolOutput\n",
"from llama_index.core.workflow import Event\n",
"\n",
"\n",
"class InputEvent(Event):\n",
"    input: list[ChatMessage]\n",
"\n",
"\n",
"class ToolCallEvent(Event):\n",
"    tool_calls: list[ToolSelection]\n",
"\n",
"\n",
"class FunctionOutputEvent(Event):\n",
"    output: ToolOutput"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### The Workflow Itself\n",
"\n",
"With our events defined, we can construct our workflow and steps.\n",
"\n",
"Note that the workflow automatically validates itself using type annotations, so the type annotations on our steps are very helpful!"
]
},
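{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before building the agent, here is a quick sketch of that validation in action. `BadWorkflow` below is a hypothetical workflow whose only step emits an event that no other step accepts, so it can never reach a `StopEvent`; running it should fail validation up front rather than hang. (The exact exception type may vary by version, so we catch `Exception` broadly.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from llama_index.core.workflow import Event, StartEvent, StopEvent, Workflow, step\n",
"\n",
"\n",
"class OrphanEvent(Event):\n",
"    pass\n",
"\n",
"\n",
"class BadWorkflow(Workflow):\n",
"    # The return annotation says this step emits OrphanEvent,\n",
"    # but no step accepts OrphanEvent and nothing emits StopEvent.\n",
"    @step()\n",
"    async def start(self, ev: StartEvent) -> OrphanEvent:\n",
"        return OrphanEvent()\n",
"\n",
"\n",
"try:\n",
"    await BadWorkflow(timeout=10).run()\n",
"except Exception as e:\n",
"    print(f\"{type(e).__name__}: {e}\")"
]
},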
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, List\n", | ||
"\n", | ||
"from llama_index.core.llms.function_calling import FunctionCallingLLM\n", | ||
"from llama_index.core.memory import ChatMemoryBuffer\n", | ||
"from llama_index.core.tools.types import BaseTool\n", | ||
"from llama_index.core.workflow import Workflow, StartEvent, StopEvent, step\n", | ||
"\n", | ||
"\n", | ||
"class FuncationCallingAgent(Workflow):\n", | ||
" def __init__(\n", | ||
" self,\n", | ||
" *args: Any,\n", | ||
" llm: FunctionCallingLLM | None = None,\n", | ||
" tools: List[BaseTool] | None = None,\n", | ||
" **kwargs: Any,\n", | ||
" ) -> None:\n", | ||
" super().__init__(*args, **kwargs)\n", | ||
" self.tools = tools or []\n", | ||
"\n", | ||
" self.llm = llm or OpenAI()\n", | ||
" assert self.llm.metadata.is_function_calling_model\n", | ||
"\n", | ||
" self.memory = ChatMemoryBuffer.from_defaults(llm=llm)\n", | ||
" self.sources = []\n", | ||
"\n", | ||
" @step()\n", | ||
" async def prepare_chat_history(self, ev: StartEvent) -> InputEvent:\n", | ||
" # clear sources\n", | ||
" self.sources = []\n", | ||
"\n", | ||
" # get user input\n", | ||
" user_input = ev.get(\"input\")\n", | ||
" user_msg = ChatMessage(role=\"user\", content=user_input)\n", | ||
" self.memory.put(user_msg)\n", | ||
"\n", | ||
" # get chat history\n", | ||
" chat_history = self.memory.get()\n", | ||
" return InputEvent(input=chat_history)\n", | ||
"\n", | ||
" @step()\n", | ||
" async def handle_llm_input(\n", | ||
" self, ev: InputEvent\n", | ||
" ) -> ToolCallEvent | StopEvent:\n", | ||
" chat_history = ev.input\n", | ||
"\n", | ||
" response = await self.llm.achat_with_tools(\n", | ||
" self.tools, chat_history=chat_history\n", | ||
" )\n", | ||
" self.memory.put(response.message)\n", | ||
"\n", | ||
" tool_calls = self.llm.get_tool_calls_from_response(\n", | ||
" response, error_on_no_tool_call=False\n", | ||
" )\n", | ||
"\n", | ||
" if not tool_calls:\n", | ||
" return StopEvent(\n", | ||
" result={\"response\": response, \"sources\": [*self.sources]}\n", | ||
" )\n", | ||
" else:\n", | ||
" return ToolCallEvent(tool_calls=tool_calls)\n", | ||
"\n", | ||
" @step()\n", | ||
" async def handle_tool_calls(self, ev: ToolCallEvent) -> InputEvent:\n", | ||
" tool_calls = ev.tool_calls\n", | ||
" tools_by_name = {tool.metadata.get_name(): tool for tool in self.tools}\n", | ||
"\n", | ||
" tool_msgs = []\n", | ||
"\n", | ||
" # call tools -- safely!\n", | ||
" for tool_call in tool_calls:\n", | ||
" tool = tools_by_name.get(tool_call.tool_name)\n", | ||
" additional_kwargs = {\n", | ||
" \"tool_call_id\": tool_call.tool_id,\n", | ||
" \"name\": tool.metadata.get_name(),\n", | ||
" }\n", | ||
" if not tool:\n", | ||
" tool_msgs.append(\n", | ||
" ChatMessage(\n", | ||
" role=\"tool\",\n", | ||
" content=f\"Tool {tool_call.tool_name} does not exist\",\n", | ||
" additional_kwargs=additional_kwargs,\n", | ||
" )\n", | ||
" )\n", | ||
" continue\n", | ||
"\n", | ||
" try:\n", | ||
" tool_output = tool(**tool_call.tool_kwargs)\n", | ||
" self.sources.append(tool_output)\n", | ||
" tool_msgs.append(\n", | ||
" ChatMessage(\n", | ||
" role=\"tool\",\n", | ||
" content=tool_output.content,\n", | ||
" additional_kwargs=additional_kwargs,\n", | ||
" )\n", | ||
" )\n", | ||
" except Exception as e:\n", | ||
" tool_msgs.append(\n", | ||
" ChatMessage(\n", | ||
" role=\"tool\",\n", | ||
" content=f\"Encountered error in tool call: {e}\",\n", | ||
" additional_kwargs=additional_kwargs,\n", | ||
" )\n", | ||
" )\n", | ||
"\n", | ||
" for msg in tool_msgs:\n", | ||
" self.memory.put(msg)\n", | ||
"\n", | ||
" chat_history = self.memory.get()\n", | ||
" return InputEvent(input=chat_history)" | ||
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And thats it! Let's explore the workflow we wrote a bit.\n", | ||
"\n", | ||
"`prepare_chat_history()`:\n", | ||
"This is our main entry point. It handles adding the user message to memory, and uses the memory to get the latest chat history. It returns an `InputEvent`.\n", | ||
"\n", | ||
"`handle_llm_input()`:\n", | ||
"Triggered by an `InputEvent`, it uses the chat history and tools to prompt the llm. If tool calls are found, a `ToolCallEvent` is emitted. Otherwise, we say the workflow is done an emit a `StopEvent`\n", | ||
"\n", | ||
"`handle_tool_calls()`:\n", | ||
"Triggered by `ToolCallEvent`, it calls tools with error handling and returns tool outputs. This event triggers a **loop** since it emits an `InputEvent`, which takes us back to `handle_llm_input()`" | ||
] | ||
}, | ||
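{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check on that loop, we can render every possible path through the workflow as a graph. This sketch assumes the optional `llama-index-utils-workflow` package (which provides `draw_all_possible_flows`) is installed; the import path may differ across versions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: visualize the workflow as an HTML graph.\n",
"# Assumes `pip install llama-index-utils-workflow` has been run.\n",
"from llama_index.utils.workflow import draw_all_possible_flows\n",
"\n",
"draw_all_possible_flows(FunctionCallingAgent, filename=\"function_calling_agent.html\")"
]
},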
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run the Workflow!\n",
"\n",
"**NOTE:** With loops, we need to be mindful of runtime. Here, we set a timeout of 120s."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running step prepare_chat_history\n",
"Step prepare_chat_history produced event InputEvent\n",
"Running step handle_llm_input\n",
"Step handle_llm_input produced event StopEvent\n"
]
}
],
"source": [
"from llama_index.core.tools import FunctionTool\n",
"from llama_index.llms.openai import OpenAI\n",
"\n",
"\n",
"def add(x: int, y: int) -> int:\n",
"    \"\"\"Useful function to add two numbers.\"\"\"\n",
"    return x + y\n",
"\n",
"\n",
"def multiply(x: int, y: int) -> int:\n",
"    \"\"\"Useful function to multiply two numbers.\"\"\"\n",
"    return x * y\n",
"\n",
"\n",
"tools = [\n",
"    FunctionTool.from_defaults(add),\n",
"    FunctionTool.from_defaults(multiply),\n",
"]\n",
"\n",
"agent = FuncationCallingAgent(\n", | ||
" llm=OpenAI(model=\"gpt-4o-mini\"), tools=tools, timeout=120, verbose=True\n", | ||
")\n", | ||
"\n", | ||
"ret = await agent.run(input=\"Hello!\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"assistant: Hello! How can I assist you today?\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"print(ret[\"response\"])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Running step prepare_chat_history\n", | ||
"Step prepare_chat_history produced event InputEvent\n", | ||
"Running step handle_llm_input\n", | ||
"Step handle_llm_input produced event ToolCallEvent\n", | ||
"Running step handle_tool_calls\n", | ||
"Step handle_tool_calls produced event InputEvent\n", | ||
"Running step handle_llm_input\n", | ||
"Step handle_llm_input produced event ToolCallEvent\n", | ||
"Running step handle_tool_calls\n", | ||
"Step handle_tool_calls produced event InputEvent\n", | ||
"Running step handle_llm_input\n", | ||
"Step handle_llm_input produced event StopEvent\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"ret = await agent.run(input=\"What is (2123 + 2321) * 312?\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"assistant: The result of \\((2123 + 2321) \\times 312\\) is \\(1,386,528\\).\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"print(ret[\"response\"])" | ||
] | ||
},
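{
"cell_type": "markdown",
"metadata": {},
"source": [
"Beyond the final response, the `StopEvent` result also carries the raw tool outputs we collected in `self.sources`. A minimal sketch of inspecting them, assuming the math query above was the last run:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Each entry is a ToolOutput collected during handle_tool_calls()\n",
"for source in ret[\"sources\"]:\n",
"    print(source.tool_name, \"->\", source.content)"
]
}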
],
"metadata": {
"kernelspec": {
"display_name": "llama-index-cDlKpkFt-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}