Skip to content

Commit

Permalink
memory
Browse files Browse the repository at this point in the history
Summary:

Test Plan:
  • Loading branch information
ehhuang committed Feb 8, 2025
1 parent 7766e68 commit 6be512a
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -381,92 +381,6 @@ async def _run(
if documents:
await self.handle_documents(session_id, documents, input_messages, tool_defs)

if RAG_TOOL_GROUP in toolgroups and len(input_messages) > 0:
with tracing.span(MEMORY_QUERY_TOOL) as span:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
)
)
)

args = toolgroup_args.get(RAG_TOOL_GROUP, {})
vector_db_ids = args.get("vector_db_ids", [])
query_config = args.get("query_config")
if query_config:
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
else:
# handle someone passing an empty dict
query_config = RAGQueryConfig()

session_info = await self.storage.get_session_info(session_id)

# if the session has a memory bank id, let the memory tool use it
if session_info.vector_db_id:
vector_db_ids.append(session_info.vector_db_id)

yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.succeeded,
tool_call=ToolCall(
call_id="",
tool_name=MEMORY_QUERY_TOOL,
arguments={},
),
),
)
)
)
result = await self.tool_runtime_api.rag_tool.query(
content=concat_interleaved_content([msg.content for msg in input_messages]),
vector_db_ids=vector_db_ids,
query_config=query_config,
)
retrieved_context = result.content

yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
step_details=ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[
ToolCall(
call_id="",
tool_name=MEMORY_QUERY_TOOL,
arguments={},
)
],
tool_responses=[
ToolResponse(
call_id="",
tool_name=MEMORY_QUERY_TOOL,
content=retrieved_context or [],
)
],
),
)
)
)
span.set_attribute("input", [m.model_dump_json() for m in input_messages])
span.set_attribute("output", retrieved_context)
span.set_attribute("tool_name", MEMORY_QUERY_TOOL)

# append retrieved_context to the last user message
for message in input_messages[::-1]:
if isinstance(message, UserMessage):
message.context = retrieved_context
break

output_attachments = []

n_iter = 0
Expand All @@ -493,9 +407,7 @@ async def _run(
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=[
tool for tool in tool_defs.values() if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
],
tools=[tool for tool in tool_defs.values()],
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
response_format=self.agent_config.response_format,
stream=True,
Expand Down Expand Up @@ -601,6 +513,7 @@ async def _run(
else:
log.info(f"{str(message)}")
tool_call = message.tool_calls[0]
# yield control back to the client to handle the client tool call
if tool_call.tool_name in client_tools:
yield message
return
Expand Down Expand Up @@ -672,8 +585,9 @@ async def _run(

# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially

if out_attachment := _interpret_content_as_attachment(result_message.content):
if (type(result_message.content) is str) and (
out_attachment := _interpret_content_as_attachment(result_message.content)
):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to final message
Expand Down
42 changes: 30 additions & 12 deletions llama_stack/providers/inline/tool_runtime/rag/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
RAGToolRuntime,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from pydantic import TypeAdapter
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
Expand Down Expand Up @@ -151,21 +153,37 @@ async def query(
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
# Return the list of ToolDef entries this runtime advertises.
# Neither tool_group_id nor mcp_endpoint is read anywhere in the visible body.
# NOTE(review): this text is a flattened commit-diff view (no indentation, no +/- markers),
# so pre-commit and post-commit lines are interleaved below. The per-line attributions in
# the notes that follow are inferred from the code's structure -- confirm against the real file.
# NOTE(review): the next three comment lines look like pre-commit text that the commit
# deletes; the "search" ToolDef further down does declare a `query` parameter.
# Parameters are not listed since these methods are not yet invoked automatically
# by the LLM. The method is only implemented so things like /tools can list without
# encountering fatals.
return [
ToolDef(
# NOTE(review): presumably removed by the commit -- two tool definitions without parameters
# ("query_from_memory" and "insert_into_memory").
name="query_from_memory",
description="Retrieve context from memory",
),
ToolDef(
name="insert_into_memory",
description="Insert documents into memory",
),
# NOTE(review): presumably added by the commit -- a single "search" tool that declares one
# string parameter named "query".
name="search",
description="Search for information in a database",
parameters=[
ToolParameter(
name="query",
description="The query to search for. Can be a natural language sentence or keywords.",
parameter_type="string",
),
],
)
]

async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
# Run a retrieval query described by `kwargs` and wrap the result's content in a
# ToolInvocationResult.
# Keys read from kwargs: "query" (indexed directly, so it must be present),
# "vector_db_ids" (optional, defaults to an empty list) and "query_config" (optional).
# `tool_name` is not read anywhere in the visible body.
# NOTE(review): this text is a flattened commit-diff view (no indentation, no +/- markers).
# The `raise RuntimeError(` line and its message string just below look like the pre-commit
# body that the commit deletes; the lines after them look like its replacement, and the lone
# `)` after `query_config=query_config,` closes the `self.query(` call. Confirm against the real file.
raise RuntimeError(
"This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol"
vector_db_ids = kwargs.get("vector_db_ids", [])
query_config = kwargs.get("query_config")
# A truthy query_config is passed through TypeAdapter(RAGQueryConfig).validate_python;
# a missing, None or empty value falls through to a default-constructed RAGQueryConfig.
if query_config:
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
else:
# handle someone passing an empty dict
query_config = RAGQueryConfig()

# Direct indexing rather than .get(): assuming kwargs is a plain dict, a call without
# a "query" key raises KeyError here.
query = kwargs["query"]
result = await self.query(
content=query,
vector_db_ids=vector_db_ids,
query_config=query_config,
)
# Only the `content` attribute of the query result is forwarded to the caller.
retrieved_context = result.content

return ToolInvocationResult(
content=retrieved_context,
)

0 comments on commit 6be512a

Please sign in to comment.