From 32fcc31eb8b22273d2f0dc51559f31d6a8b1143a Mon Sep 17 00:00:00 2001
From: Eric Huang
Date: Thu, 20 Feb 2025 14:32:20 -0800
Subject: [PATCH] feat: tool outputs metadata

Summary:
Allows tools to output metadata. This is useful for evaluating tool outputs,
e.g. the RAG tool can output document IDs, which can then be used to score
recall. A similar change will be needed on the client side to support
ClientTool outputting metadata.

Test Plan:
LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/test_agents.py
---
 docs/_static/llama-stack-spec.html            | 78 +++++++++++++++++++
 docs/_static/llama-stack-spec.yaml            | 32 ++++++++
 llama_stack/apis/inference/inference.py       |  1 +
 llama_stack/apis/tools/rag_tool.py            |  1 +
 llama_stack/apis/tools/tools.py               |  1 +
 .../agents/meta_reference/agent_instance.py   | 38 ++++-----
 .../inline/tool_runtime/rag/memory.py         |  7 +-
 tests/client-sdk/agents/test_agents.py        | 11 ++-
 8 files changed, 141 insertions(+), 28 deletions(-)

diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 40c1676855..a9023e3f69 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -4454,6 +4454,31 @@
             },
             "content": {
               "$ref": "#/components/schemas/InterleavedContent"
+            },
+            "metadata": {
+              "type": "object",
+              "additionalProperties": {
+                "oneOf": [
+                  {
+                    "type": "null"
+                  },
+                  {
+                    "type": "boolean"
+                  },
+                  {
+                    "type": "number"
+                  },
+                  {
+                    "type": "string"
+                  },
+                  {
+                    "type": "array"
+                  },
+                  {
+                    "type": "object"
+                  }
+                ]
+              }
             }
           },
           "additionalProperties": false,
@@ -6625,6 +6650,31 @@
             },
             "error_code": {
               "type": "integer"
+            },
+            "metadata": {
+              "type": "object",
+              "additionalProperties": {
+                "oneOf": [
+                  {
+                    "type": "null"
+                  },
+                  {
+                    "type": "boolean"
+                  },
+                  {
+                    "type": "number"
+                  },
+                  {
+                    "type": "string"
+                  },
+                  {
+                    "type": "array"
+                  },
+                  {
+                    "type": "object"
+                  }
+                ]
+              }
             }
           },
           "additionalProperties": false,
@@ -7474,9 +7524,37 @@
         "properties": {
           "content": {
             "$ref": "#/components/schemas/InterleavedContent"
+          },
+          "metadata": {
+            "type": "object",
+            "additionalProperties": {
+              "oneOf": [
+                {
+                  "type": "null"
+                },
+                {
+                  "type": "boolean"
+                },
+                {
+                  "type": "number"
+                },
+                {
+                  "type": "string"
+                },
+                {
+                  "type": "array"
+                },
+                {
+                  "type": "object"
+                }
+              ]
+            }
          }
        },
        "additionalProperties": false,
+        "required": [
+          "metadata"
+        ],
        "title": "RAGQueryResult"
      },
      "QueryChunksRequest": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index c5043665bb..9159993acd 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -2896,6 +2896,16 @@ components:
           - type: string
         content:
           $ref: '#/components/schemas/InterleavedContent'
+        metadata:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: 'null'
+              - type: boolean
+              - type: number
+              - type: string
+              - type: array
+              - type: object
       additionalProperties: false
       required:
         - call_id
@@ -4289,6 +4299,16 @@
           type: string
         error_code:
           type: integer
+        metadata:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: 'null'
+              - type: boolean
+              - type: number
+              - type: string
+              - type: array
+              - type: object
       additionalProperties: false
       required:
         - content
@@ -4862,7 +4882,19 @@
       properties:
         content:
           $ref: '#/components/schemas/InterleavedContent'
+        metadata:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: 'null'
+              - type: boolean
+              - type: number
+              - type: string
+              - type: array
+              - type: object
       additionalProperties: false
+      required:
+        - metadata
       title: RAGQueryResult
     QueryChunksRequest:
       type: object
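The metadata map added to these schemas is deliberately open-ended: values may
be null, boolean, number, string, array, or object, so a tool can attach any
JSON-serializable signal to its output without touching the model-visible
content. A minimal sketch of the recall-scoring use case mentioned in the
summary, assuming a non-streaming turn whose first step carries the RAG tool
response (as exercised by the test at the end of this patch); recall() and
expected_doc_ids are illustrative names, not part of this change:

    def recall(retrieved_ids: list[str], relevant_ids: set[str]) -> float:
        """Fraction of the relevant documents that were actually retrieved."""
        if not relevant_ids:
            return 0.0
        hits = sum(1 for doc_id in set(retrieved_ids) if doc_id in relevant_ids)
        return hits / len(relevant_ids)

    # retrieved_ids = response.steps[0].tool_responses[0].metadata["document_ids"]
    # score = recall(retrieved_ids, expected_doc_ids)
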
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index a3fb694776..b17a5abeba 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -165,6 +165,7 @@ class ToolResponse(BaseModel):
     call_id: str
     tool_name: Union[BuiltinTool, str]
     content: InterleavedContent
+    metadata: Optional[Dict[str, Any]] = None

     @field_validator("tool_name", mode="before")
     @classmethod
diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py
index cff8eeefe9..2b9ef10d8b 100644
--- a/llama_stack/apis/tools/rag_tool.py
+++ b/llama_stack/apis/tools/rag_tool.py
@@ -26,6 +26,7 @@ class RAGDocument(BaseModel):
 @json_schema_type
 class RAGQueryResult(BaseModel):
     content: Optional[InterleavedContent] = None
+    metadata: Dict[str, Any] = Field(default_factory=dict)


 @json_schema_type
diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py
index b83be127f2..a4d84edbe6 100644
--- a/llama_stack/apis/tools/tools.py
+++ b/llama_stack/apis/tools/tools.py
@@ -72,6 +72,7 @@ class ToolInvocationResult(BaseModel):
     content: InterleavedContent
     error_message: Optional[str] = None
     error_code: Optional[int] = None
+    metadata: Optional[Dict[str, Any]] = None


 class ToolStore(Protocol):
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 1c21df57f1..a51608ca57 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -60,7 +60,7 @@
     UserMessage,
 )
 from llama_stack.apis.safety import Safety
-from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
+from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolInvocationResult, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
 from llama_stack.providers.utils.kvstore import KVStore
@@ -456,6 +456,7 @@
                         call_id="",
                         tool_name=MEMORY_QUERY_TOOL,
                         content=retrieved_context or [],
+                        metadata=result.metadata,
                     )
                 ],
             ),
@@ -650,13 +651,21 @@
                     },
                ) as span:
                    tool_execution_start_time = datetime.now()
-                    result_messages = await execute_tool_call_maybe(
+                    tool_call = message.tool_calls[0]
+                    tool_result = await execute_tool_call_maybe(
                        self.tool_runtime_api,
                        session_id,
-                        [message],
+                        tool_call,
                        toolgroup_args,
                        tool_to_group,
                    )
+                    result_messages = [
+                        ToolResponseMessage(
+                            call_id=tool_call.call_id,
+                            tool_name=tool_call.tool_name,
+                            content=tool_result.content,
+                        )
+                    ]
                    assert len(result_messages) == 1, "Currently not supporting multiple messages"
                    result_message = result_messages[0]
                    span.set_attribute("output", result_message.model_dump_json())
@@ -675,6 +684,7 @@
                                call_id=result_message.call_id,
                                tool_name=result_message.tool_name,
                                content=result_message.content,
+                                metadata=tool_result.metadata,
                            )
                        ],
                        started_at=tool_execution_start_time,
@@ -913,19 +923,10 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
 async def execute_tool_call_maybe(
     tool_runtime_api: ToolRuntime,
     session_id: str,
-    messages: List[CompletionMessage],
+    tool_call: ToolCall,
     toolgroup_args: Dict[str, Dict[str, Any]],
     tool_to_group: Dict[str, str],
-) -> List[ToolResponseMessage]:
-    # While Tools.run interface takes a list of messages,
-    # All tools currently only run on a single message
-    # When this changes, we can drop this assert
-    # Whether to call tools on each message and aggregate
-    # or aggregate and call tool once, reamins to be seen.
-    assert len(messages) == 1, "Expected single message"
-    message = messages[0]
-
-    tool_call = message.tool_calls[0]
+) -> ToolInvocationResult:
     name = tool_call.tool_name
     group_name = tool_to_group.get(name, None)
     if group_name is None:
@@ -946,14 +947,7 @@ async def execute_tool_call_maybe(
             **tool_call_args,
         ),
     )
-
-    return [
-        ToolResponseMessage(
-            call_id=tool_call.call_id,
-            tool_name=tool_call.tool_name,
-            content=result.content,
-        )
-    ]
+    return result


 def _interpret_content_as_attachment(
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index a6cd579238..306bd78a60 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -119,10 +119,10 @@ async def query(

         # sort by score
         chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
-
+        chunks = chunks[: query_config.max_chunks]
         tokens = 0
         picked = []
-        for c in chunks[: query_config.max_chunks]:
+        for c in chunks:
             metadata = c.metadata
             tokens += metadata["token_count"]
             if tokens > query_config.max_tokens_in_context:
@@ -146,6 +146,9 @@
                     text="\n=== END-RETRIEVED-CONTEXT ===\n",
                 ),
             ],
+            metadata={
+                "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
+            },
         )

     async def list_runtime_tools(
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index e5380d357a..813a0e6079 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -453,6 +453,7 @@ def test_rag_agent(llama_stack_client, agent_config):
         vector_db_id=vector_db_id,
         embedding_model="all-MiniLM-L6-v2",
         embedding_dimension=384,
+        provider_id="faiss",
     )
     llama_stack_client.tool_runtime.rag_tool.insert(
         documents=documents,
@@ -488,11 +489,13 @@ def test_rag_agent(llama_stack_client, agent_config):
         response = rag_agent.create_turn(
             messages=[{"role": "user", "content": prompt}],
             session_id=session_id,
+            stream=False,
         )
-        logs = [str(log) for log in EventLogger().log(response) if log is not None]
-        logs_str = "".join(logs)
-        assert "Tool:query_from_memory" in logs_str
-        assert expected_kw in logs_str.lower()
+        # rag is called
+        assert response.steps[0].tool_calls[0].tool_name == "query_from_memory"
+        # document ids are present in metadata
+        assert "num-0" in response.steps[0].tool_responses[0].metadata["document_ids"]
+        assert expected_kw in response.output_message.content


 def test_rag_and_code_agent(llama_stack_client, agent_config):
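
On the provider side, returning metadata only requires populating the new field
on ToolInvocationResult, as the memory.py change above does for document IDs. A
minimal sketch of a hypothetical provider-side tool follows; the toy corpus and
the name lookup_tool are illustrative only, and content is passed as a plain
string, which InterleavedContent accepts:

    from llama_stack.apis.tools import ToolInvocationResult

    async def lookup_tool(query: str) -> ToolInvocationResult:
        # Toy in-memory corpus; a real provider would query a vector store or an API.
        docs = {"num-0": "Llama 3 was trained on 15 trillion tokens of data."}
        matched = {doc_id: text for doc_id, text in docs.items() if query.lower() in text.lower()}
        return ToolInvocationResult(
            content="\n".join(matched.values()),
            # Travels alongside content and is surfaced to clients as
            # ToolResponse.metadata by the agent loop changed above.
            metadata={"document_ids": list(matched.keys())},
        )

Because metadata rides next to the content rather than inside it, evaluation
signals such as document IDs never enter the model's context; only the agent
loop and API clients observe them.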