Skip to content

Commit

Permalink
feat: allow specifying specific tool within toolgroup
Browse files Browse the repository at this point in the history
Summary:

E.g. `builtin::rag::knowledge_search`

Test Plan:

LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/ --safety-shield meta-llama/Llama-Guard-3-8B
  • Loading branch information
ehhuang committed Feb 25, 2025
1 parent 71eb6e1 commit 65c8e2d
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 52 deletions.
4 changes: 2 additions & 2 deletions docs/getting_started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@
}
],
"source": [
"model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
"model_id = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
"\n",
"model_id\n"
]
Expand Down Expand Up @@ -1688,7 +1688,7 @@
" enable_session_persistence=False,\n",
" toolgroups = [\n",
" {\n",
" \"name\": \"builtin::rag\",\n",
" \"name\": \"builtin::rag::knowledge_search\",\n",
" \"args\" : {\n",
" \"vector_db_ids\": [vector_db_id],\n",
" }\n",
Expand Down
5 changes: 4 additions & 1 deletion docs/source/building_applications/agent_execution_loop.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ agent_config = AgentConfig(
instructions="You are a helpful assistant",
# Enable both RAG and tool usage
toolgroups=[
{"name": "builtin::rag", "args": {"vector_db_ids": ["my_docs"]}},
{
"name": "builtin::rag::knowledge_search",
"args": {"vector_db_ids": ["my_docs"]}
},
"builtin::code_interpreter",
],
# Configure safety
Expand Down
2 changes: 1 addition & 1 deletion docs/source/building_applications/rag.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ agent_config = AgentConfig(
enable_session_persistence=False,
toolgroups=[
{
"name": "builtin::rag",
"name": "builtin::rag::knowledge_search",
"args": {
"vector_db_ids": [vector_db_id],
},
Expand Down
2 changes: 1 addition & 1 deletion docs/source/getting_started/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ agent_config = AgentConfig(
# Define tools available to the agent
toolgroups=[
{
"name": "builtin::rag",
"name": "builtin::rag::knowledge_search",
"args": {
"vector_db_ids": [vector_db_id],
},
Expand Down
2 changes: 1 addition & 1 deletion llama_stack/distribution/ui/page/playground/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def rag_chat_page():
},
toolgroups=[
dict(
name="builtin::rag",
name="builtin::rag::knowledge_search",
args={
"vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs],
},
Expand Down
102 changes: 63 additions & 39 deletions llama_stack/providers/inline/agents/meta_reference/agent_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def make_random_string(length: int = 8):
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
MEMORY_QUERY_TOOL = "knowledge_search"
WEB_SEARCH_TOOL = "web_search"
RAG_TOOL_GROUP = "builtin::rag"
RAG_TOOL_GROUP = "rag"


class ChatAgent(ShieldRunnerMixin):
Expand Down Expand Up @@ -497,19 +497,14 @@ async def _run(
# TODO: simplify all of this code, it can be simpler
toolgroup_args = {}
toolgroups = set()
for toolgroup in self.agent_config.toolgroups:
for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []):
if isinstance(toolgroup, AgentToolGroupWithArgs):
toolgroups.add(toolgroup.name)
toolgroup_args[toolgroup.name] = toolgroup.args
tool_type, tool_group, tool_name = self._parse_toolgroup_name(toolgroup.name)
tool_group_name = "::".join([tool_type, tool_group])
toolgroups.add(tool_group_name)
toolgroup_args[tool_group_name] = toolgroup.args
else:
toolgroups.add(toolgroup)
if toolgroups_for_turn:
for toolgroup in toolgroups_for_turn:
if isinstance(toolgroup, AgentToolGroupWithArgs):
toolgroups.add(toolgroup.name)
toolgroup_args[toolgroup.name] = toolgroup.args
else:
toolgroups.add(toolgroup)

tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
if documents:
Expand Down Expand Up @@ -542,7 +537,7 @@ async def _run(
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=[tool for tool in tool_defs.values()],
tools=tool_defs,
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
response_format=self.agent_config.response_format,
stream=True,
Expand Down Expand Up @@ -764,7 +759,7 @@ async def _run(

async def _get_tool_defs(
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]:
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
# Determine which tools to include
agent_config_toolgroups = set(
(toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
Expand All @@ -779,13 +774,13 @@ async def _get_tool_defs(
}
)

tool_def_map = {}
tool_name_to_def = {}
tool_to_group = {}

for tool_def in self.agent_config.client_tools:
if tool_def_map.get(tool_def.name, None):
if tool_name_to_def.get(tool_def.name, None):
raise ValueError(f"Tool {tool_def.name} already exists")
tool_def_map[tool_def.name] = ToolDefinition(
tool_name_to_def[tool_def.name] = ToolDefinition(
tool_name=tool_def.name,
description=tool_def.description,
parameters={
Expand All @@ -802,41 +797,70 @@ async def _get_tool_defs(
for toolgroup_name in agent_config_toolgroups:
if toolgroup_name not in toolgroups_for_turn_set:
continue
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)

tool_type, tool_group, tool_name = self._parse_toolgroup_name(toolgroup_name)
tools = await self.tool_groups_api.list_tools(toolgroup_id="::".join([tool_type, tool_group]))
if tool_name != "" and not any(tool.identifier == tool_name for tool in tools.data):
raise ValueError(
f"Tool {tool_name} not found in toolgroup {'::'.join([tool_type, tool_group])}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
)

for tool_def in tools.data:
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
if tool_type == "builtin" and tool_group != RAG_TOOL_GROUP:
tool_name = tool_def.identifier
built_in_type = BuiltinTool.brave_search
if tool_name == "web_search":
built_in_type = BuiltinTool.brave_search
else:
built_in_type = BuiltinTool(tool_name)

if tool_def_map.get(built_in_type, None):
if tool_name_to_def.get(built_in_type, None):
raise ValueError(f"Tool {built_in_type} already exists")

tool_def_map[built_in_type] = ToolDefinition(tool_name=built_in_type)
tool_name_to_def[built_in_type] = ToolDefinition(tool_name=built_in_type)
tool_to_group[built_in_type] = tool_def.toolgroup_id
continue

if tool_def_map.get(tool_def.identifier, None):
if tool_name_to_def.get(tool_def.identifier, None):
raise ValueError(f"Tool {tool_def.identifier} already exists")
tool_def_map[tool_def.identifier] = ToolDefinition(
tool_name=tool_def.identifier,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool_def.parameters
},
)
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id

return tool_def_map, tool_to_group
if tool_name in ("", tool_def.identifier):
tool_name_to_def[tool_def.identifier] = ToolDefinition(
tool_name=tool_def.identifier,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool_def.parameters
},
)
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id

return list(tool_name_to_def.values()), tool_to_group

def _parse_toolgroup_name(self, toolgroup_name: str) -> tuple[str, str, Optional[str]]:
"""Parse a toolgroup name into its components.
Args:
toolgroup_name: The toolgroup name to parse (e.g. "builtin::rag::knowledge_search")
Returns:
A tuple of (tool_type, tool_group, tool_name)
"""
tool_type, tool_group, tool_name = "", "", ""
split_names = toolgroup_name.split("::")
if len(split_names) == 3:
# e.g. "builtin::rag::knowledge_search"
tool_type, tool_group, tool_name = split_names
elif len(split_names) == 2:
# e.g. "builtin::rag"
tool_type, tool_group = split_names
else:
tool_group = split_names[0]
return tool_type, tool_group, tool_name

async def handle_documents(
self,
Expand All @@ -845,8 +869,8 @@ async def handle_documents(
input_messages: List[Message],
tool_defs: Dict[str, ToolDefinition],
) -> None:
memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in tool_defs)
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in tool_defs)
content_items = []
url_items = []
pattern = re.compile("^(https?://|file://|data:)")
Expand Down
11 changes: 4 additions & 7 deletions tests/client-sdk/agents/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,8 @@ def xtest_override_system_message_behavior(llama_stack_client, agent_config):
assert "get_boiling_point" in logs_str


def test_rag_agent(llama_stack_client, agent_config):
@pytest.mark.parametrize("rag_tool_name", ["builtin::rag::knowledge_search", "builtin::rag"])
def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"]
documents = [
Document(
Expand Down Expand Up @@ -469,7 +470,7 @@ def test_rag_agent(llama_stack_client, agent_config):
**agent_config,
"toolgroups": [
dict(
name="builtin::rag",
name=rag_tool_name,
args={
"vector_db_ids": [vector_db_id],
},
Expand All @@ -483,10 +484,6 @@ def test_rag_agent(llama_stack_client, agent_config):
"Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
"grouped",
),
(
"What `tune` command to use for getting access to Llama3-8B-Instruct ?",
"download",
),
]
for prompt, expected_kw in user_prompts:
response = rag_agent.create_turn(
Expand Down Expand Up @@ -541,7 +538,7 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
**agent_config,
"toolgroups": [
dict(
name="builtin::rag",
name="builtin::rag::knowledge_search",
args={"vector_db_ids": [vector_db_id]},
),
"builtin::code_interpreter",
Expand Down

0 comments on commit 65c8e2d

Please sign in to comment.