diff --git a/examples/python-demo/main-stream.py b/examples/python-demo/main-stream.py new file mode 100644 index 00000000..83b14dd0 --- /dev/null +++ b/examples/python-demo/main-stream.py @@ -0,0 +1,167 @@ + +import uuid +import asyncio +from typing import Optional, List, Dict, Any +import json +import sys + +from tools import weather_tool + +from multi_agent_orchestrator.orchestrator import MultiAgentOrchestrator, OrchestratorConfig +from multi_agent_orchestrator.agents import (BedrockLLMAgent, + BedrockLLMAgentOptions, + AgentResponse, + AgentStreamResponse, + AgentCallbacks) +from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole +from multi_agent_orchestrator.utils import AgentTools + +class LLMAgentCallbacks(AgentCallbacks): + def on_llm_new_token(self, token: str) -> None: + print(token, end='', flush=True) + + +async def handle_request(_orchestrator: MultiAgentOrchestrator, _user_input:str, _user_id:str, _session_id:str): + stream_response = True + response:AgentResponse = await _orchestrator.route_request(_user_input, _user_id, _session_id, {}, stream_response) + + # Print metadata + print("\nMetadata:") + print(f"Selected Agent: {response.metadata.agent_name}") + if stream_response and response.streaming: + async for chunk in response.output: + if isinstance(chunk, AgentStreamResponse): + if chunk.text: + print(chunk.text, end='', flush=True) + else: + if isinstance(response.output, ConversationMessage): + print(response.output.content[0]['text']) + elif isinstance(response.output, str): + print(response.output) + else: + print(response.output) + +def custom_input_payload_encoder(input_text: str, + chat_history: List[Any], + user_id: str, + session_id: str, + additional_params: Optional[Dict[str, str]] = None) -> str: + return json.dumps({ + 'hello':'world' + }) + +def custom_output_payload_decoder(response: Dict[str, Any]) -> Any: + decoded_response = json.loads( + json.loads( + response['Payload'].read().decode('utf-8') + )['body'])['response'] + return ConversationMessage( + role=ParticipantRole.ASSISTANT.value, + content=[{'text': decoded_response}] + ) + +if __name__ == "__main__": + + # Initialize the orchestrator with some options + orchestrator = MultiAgentOrchestrator(options=OrchestratorConfig( + LOG_AGENT_CHAT=True, + LOG_CLASSIFIER_CHAT=True, + LOG_CLASSIFIER_RAW_OUTPUT=True, + LOG_CLASSIFIER_OUTPUT=True, + LOG_EXECUTION_TIMES=True, + MAX_RETRIES=3, + USE_DEFAULT_AGENT_IF_NONE_IDENTIFIED=True, + MAX_MESSAGE_PAIRS_PER_AGENT=10, + )) + + # Add a tech agent + tech_agent = BedrockLLMAgent(BedrockLLMAgentOptions( + name="Tech Agent", + streaming=True, + description="Specializes in technology areas including software development, hardware, AI, \ cybersecurity, blockchain, cloud computing, emerging tech innovations, and pricing/costs \ related to technology products and services.", + model_id="anthropic.claude-3-sonnet-20240229-v1:0", + # callbacks=LLMAgentCallbacks() + )) + orchestrator.add_agent(tech_agent) + + # Add a health agent + health_agent = BedrockLLMAgent(BedrockLLMAgentOptions( + name="Health Agent", + streaming=False, + description="Specializes in health and well being.", + model_id="anthropic.claude-3-sonnet-20240229-v1:0", + )) + orchestrator.add_agent(health_agent) + + # Add an Anthropic weather agent with a tool in anthropic's tool format + # weather_agent = AnthropicAgent(AnthropicAgentOptions( + # api_key='api-key', + # name="Weather Agent", + # streaming=False, + # description="Specialized agent for giving weather condition
from a city.", + # tool_config={ + # 'tool': [tool.to_claude_format() for tool in weather_tool.weather_tools.tools], + # 'toolMaxRecursions': 5, + # 'useToolHandler': weather_tool.anthropic_weather_tool_handler + # }, + # callbacks=LLMAgentCallbacks() + # )) + + # Add an Anthropic weather agent with Tools class + # weather_agent = AnthropicAgent(AnthropicAgentOptions( + # api_key='api-key', + # name="Weather Agent", + # streaming=True, + # description="Specialized agent for giving weather condition from a city.", + # tool_config={ + # 'tool': weather_tool.weather_tools, + # 'toolMaxRecursions': 5, + # }, + # callbacks=LLMAgentCallbacks() + # )) + + # Add a Bedrock weather agent with Tools class + # weather_agent = BedrockLLMAgent(BedrockLLMAgentOptions( + # name="Weather Agent", + # streaming=False, + # description="Specialized agent for giving weather condition from a city.", + # tool_config={ + # 'tool': weather_tool.weather_tools, + # 'toolMaxRecursions': 5, + # }, + # callbacks=LLMAgentCallbacks(), + # )) + + # Add a Bedrock weather agent with custom handler and bedrock's tool format + weather_agent = BedrockLLMAgent(BedrockLLMAgentOptions( + name="Weather Agent", + streaming=False, + description="Specialized agent for giving weather condition from a city.", + tool_config={ + 'tool': [tool.to_bedrock_format() for tool in weather_tool.weather_tools.tools], + 'toolMaxRecursions': 5, + 'useToolHandler': weather_tool.bedrock_weather_tool_handler + } + )) + + + weather_agent.set_system_prompt(weather_tool.weather_tool_prompt) + orchestrator.add_agent(weather_agent) + + USER_ID = "user123" + SESSION_ID = str(uuid.uuid4()) + + print("Welcome to the interactive Multi-Agent system. Type 'quit' to exit.") + + while True: + # Get user input + user_input = input("\nYou: ").strip() + + if user_input.lower() == 'quit': + print("Exiting the program. Goodbye!") + sys.exit() + + # Run the async function + asyncio.run(handle_request(orchestrator, user_input, USER_ID, SESSION_ID)) diff --git a/python/src/multi_agent_orchestrator/agents/__init__.py b/python/src/multi_agent_orchestrator/agents/__init__.py index a96a151e..12c3ea4a 100644 --- a/python/src/multi_agent_orchestrator/agents/__init__.py +++ b/python/src/multi_agent_orchestrator/agents/__init__.py @@ -1,7 +1,7 @@ """ Code for Agents. 
""" -from .agent import Agent, AgentOptions, AgentCallbacks, AgentProcessingResult, AgentResponse +from .agent import Agent, AgentOptions, AgentCallbacks, AgentProcessingResult, AgentResponse, AgentStreamResponse try: @@ -32,16 +32,16 @@ from .supervisor_agent import SupervisorAgent, SupervisorAgentOptions - __all__ = [ 'Agent', 'AgentOptions', 'AgentCallbacks', 'AgentProcessingResult', 'AgentResponse', + 'AgentStreamResponse', 'SupervisorAgent', 'SupervisorAgentOptions' - ] +] if _AWS_AVAILABLE : diff --git a/python/src/multi_agent_orchestrator/agents/agent.py b/python/src/multi_agent_orchestrator/agents/agent.py index ec6badd6..09363400 100644 --- a/python/src/multi_agent_orchestrator/agents/agent.py +++ b/python/src/multi_agent_orchestrator/agents/agent.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Union, AsyncIterable, Optional, Any +from typing import Union, AsyncIterable, Optional, Any from abc import ABC, abstractmethod from dataclasses import dataclass, field from multi_agent_orchestrator.types import ConversationMessage @@ -12,13 +12,19 @@ class AgentProcessingResult: agent_name: str user_id: str session_id: str - additional_params: Dict[str, Any] = field(default_factory=dict) + additional_params: dict[str, Any] = field(default_factory=dict) + + +class AgentStreamResponse: + def __init__(self, text: str = '', final_message: ConversationMessage = None): + self.text = text + self.final_message = final_message @dataclass class AgentResponse: metadata: AgentProcessingResult - output: Union[Any, str] + output: Union[Any, str, AgentStreamResponse] streaming: bool @@ -72,9 +78,10 @@ async def process_request( input_text: str, user_id: str, session_id: str, - chat_history: List[ConversationMessage], - additional_params: Optional[Dict[str, str]] = None, + chat_history: list[ConversationMessage], + additional_params: Optional[dict[str, str]] = None, ) -> Union[ConversationMessage, AsyncIterable[Any]]: + pass def log_debug(self, class_name, message, data=None): diff --git a/python/src/multi_agent_orchestrator/agents/amazon_bedrock_agent.py b/python/src/multi_agent_orchestrator/agents/amazon_bedrock_agent.py index 7bcea532..7652e281 100644 --- a/python/src/multi_agent_orchestrator/agents/amazon_bedrock_agent.py +++ b/python/src/multi_agent_orchestrator/agents/amazon_bedrock_agent.py @@ -6,12 +6,12 @@ AWS Bedrock's agent runtime capabilities. 
""" -from typing import Dict, List, Optional, Any +from typing import Any, Optional from dataclasses import dataclass import os import boto3 from botocore.exceptions import BotoCoreError, ClientError -from multi_agent_orchestrator.agents import Agent, AgentOptions +from multi_agent_orchestrator.agents import Agent, AgentOptions, AgentStreamResponse from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole from multi_agent_orchestrator.utils import Logger @@ -31,9 +31,9 @@ class AmazonBedrockAgentOptions(AgentOptions): region: Optional[str] = None agent_id: str = None agent_alias_id: str = None - client: Optional[Any] = None - streaming: Optional[bool] = False - enableTrace: Optional[bool] = False + client: Any | None = None + streaming: bool | None = False + enableTrace: bool | None = False class AmazonBedrockAgent(Agent): @@ -88,8 +88,8 @@ async def process_request( input_text: str, user_id: str, session_id: str, - chat_history: List[ConversationMessage], - additional_params: Optional[Dict[str, str]] = None + chat_history: list[ConversationMessage], + additional_params: dict[str, str] | None = None ) -> ConversationMessage: """ Process a user request through the Bedrock agent runtime. @@ -129,32 +129,39 @@ async def process_request( streamingConfigurations=streamingConfigurations if self.streaming else {} ) - # Process response, handling both streaming and non-streaming modes completion = "" - for event in response['completion']: - if 'chunk' in event: - # Process streaming chunk - chunk = event['chunk'] - decoded_response = chunk['bytes'].decode('utf-8') - - # Trigger callback for each token (useful for real-time updates) - self.callbacks.on_llm_new_token(decoded_response) - completion += decoded_response - - elif 'trace' in event: - # Log trace events if tracing is enabled - Logger.info(f"Received event: {event}") if self.enableTrace else None - - else: - # Ignore unrecognized event types - pass - - # Construct and return the conversation message - return ConversationMessage( - role=ParticipantRole.ASSISTANT.value, - content=[{"text": completion}] - ) + if self.streaming: + async def generate_chunks(): + nonlocal completion + for event in response['completion']: + if 'chunk' in event: + chunk = event['chunk'] + decoded_response = chunk['bytes'].decode('utf-8') + self.callbacks.on_llm_new_token(decoded_response) + completion += decoded_response + yield AgentStreamResponse(text=decoded_response) + elif 'trace' in event and self.enableTrace: + Logger.info(f"Received event: {event}") + yield AgentStreamResponse( + final_message=ConversationMessage( + role=ParticipantRole.ASSISTANT.value, + content=[{'text':completion}])) + return generate_chunks() + else: + for event in response['completion']: + if 'chunk' in event: + chunk = event['chunk'] + decoded_response = chunk['bytes'].decode('utf-8') + self.callbacks.on_llm_new_token(decoded_response) + completion += decoded_response + elif 'trace' in event and self.enableTrace: + Logger.info(f"Received event: {event}") + + return ConversationMessage( + role=ParticipantRole.ASSISTANT.value, + content=[{"text": completion}] + ) except (BotoCoreError, ClientError) as error: # Comprehensive error logging and propagation Logger.error(f"Error processing request: {str(error)}") diff --git a/python/src/multi_agent_orchestrator/agents/anthropic_agent.py b/python/src/multi_agent_orchestrator/agents/anthropic_agent.py index 2dc509ef..f489863d 100644 --- a/python/src/multi_agent_orchestrator/agents/anthropic_agent.py +++ 
b/python/src/multi_agent_orchestrator/agents/anthropic_agent.py @@ -1,14 +1,15 @@ import json -from typing import List, Dict, Any, AsyncIterable, Optional, Union +from typing import AsyncIterable, Optional, Any, AsyncGenerator +from typing import Any, AsyncIterable, Optional from dataclasses import dataclass, field import re from anthropic import AsyncAnthropic, Anthropic -from multi_agent_orchestrator.agents import Agent, AgentOptions +from multi_agent_orchestrator.agents import Agent, AgentOptions, AgentStreamResponse from multi_agent_orchestrator.types import (ConversationMessage, ParticipantRole, TemplateVariables, AgentProviderType) -from multi_agent_orchestrator.utils import Logger, AgentTools +from multi_agent_orchestrator.utils import Logger, AgentTools, AgentTool from multi_agent_orchestrator.retrievers import Retriever @dataclass @@ -17,10 +18,10 @@ class AnthropicAgentOptions(AgentOptions): client: Optional[Any] = None model_id: str = "claude-3-5-sonnet-20240620" streaming: Optional[bool] = False - inference_config: Optional[Dict[str, Any]] = None + inference_config: Optional[dict[str, Any]] = None retriever: Optional[Retriever] = None - tool_config: Optional[Union[dict[str, Any], AgentTools]] = None - custom_system_prompt: Optional[Dict[str, Any]] = None + tool_config: Optional[dict[str, Any] | AgentTools] = None + custom_system_prompt: Optional[dict[str, Any]] = None @@ -100,27 +101,51 @@ def __init__(self, options: AnthropicAgentOptions): def is_streaming_enabled(self) -> bool: return self.streaming is True - async def process_request( + async def _prepare_system_prompt(self, input_text: str) -> str: + """Prepare the system prompt with optional retrieval context.""" + + self.update_system_prompt() + system_prompt = self.system_prompt + + if self.retriever: + response = await self.retriever.retrieve_and_combine_results(input_text) + system_prompt += f"\nHere is the context to use to answer the user's question:\n{response}" + + return system_prompt + + def _prepare_conversation( self, input_text: str, - user_id: str, - session_id: str, - chat_history: List[ConversationMessage], - additional_params: Optional[Dict[str, str]] = None - ) -> Union[ConversationMessage, AsyncIterable[Any]]: + chat_history: list[ConversationMessage] + ) -> list[Any]: + """Prepare the conversation history with the new user message.""" messages = [{"role": "user" if msg.role == ParticipantRole.USER.value else "assistant", "content": msg.content[0]['text'] if msg.content else ''} for msg in chat_history] messages.append({"role": "user", "content": input_text}) - self.update_system_prompt() - system_prompt = self.system_prompt + return messages - if self.retriever: - response = await self.retriever.retrieve_and_combine_results(input_text) - context_prompt = f"\nHere is the context to use to answer the user's question:\n{response}" - system_prompt += context_prompt + def _prepare_tool_config(self) -> dict: + """Prepare tool configuration based on the tool type.""" + + if isinstance(self.tool_config["tool"], AgentTools): + return self.tool_config["tool"].to_claude_format() + + if isinstance(self.tool_config["tool"], list): + return [ + tool.to_claude_format() if isinstance(tool, AgentTool) else tool + for tool in self.tool_config['tool'] + ] + raise RuntimeError("Invalid tool config") + + def _build_input( + self, + messages: list[Any], + system_prompt: str + ) -> dict: + """Build the conversation command with all necessary configurations.""" input = { "model": self.model_id, "max_tokens": 
self.inference_config.get('maxTokens'), @@ -131,63 +156,114 @@ async def process_request( "stop_sequences": self.inference_config.get('stopSequences'), } - try: - if self.tool_config: - tools = self.tool_config["tool"] if not isinstance(self.tool_config["tool"], AgentTools) else self.tool_config["tool"].to_claude_format() - - input['tools'] = tools - final_message = '' - tool_use = True - recursions = self.tool_config.get('toolMaxRecursions', self.default_max_recursions) - - while tool_use and recursions > 0: - if self.streaming: - response = await self.handle_streaming_response(input) - else: - response = await self.handle_single_response(input) - - tool_use_blocks = [content for content in response.content if content.type == 'tool_use'] - - if tool_use_blocks: - input['messages'].append({"role": "assistant", "content": response.content}) - if not self.tool_config: - raise ValueError("No tools available for tool use") - - if self.tool_config.get('useToolHandler'): - # user is handling the tool blocks itself - tool_response = await self.tool_config['useToolHandler'](response, input['messages']) - else: - tools:AgentTools = self.tool_config["tool"] - # no handler has been provided, we can use the default implementation - tool_response = await tools.tool_handler(AgentProviderType.ANTHROPIC.value, response, messages) - input['messages'].append(tool_response) - tool_use = True - else: - text_content = next((content for content in response.content if content.type == 'text'), None) - final_message = text_content.text if text_content else '' - - if response.stop_reason == 'end_turn': - tool_use = False - - recursions -= 1 - - return ConversationMessage(role=ParticipantRole.ASSISTANT.value, content=[{'text': final_message}]) - else: - if self.streaming: - response = await self.handle_streaming_response(input) + if self.tool_config: + input["tools"] = self._prepare_tool_config() + + return input + + def _get_max_recursions(self) -> int: + """Get the maximum number of recursions based on tool configuration.""" + if not self.tool_config: + return 1 + return self.tool_config.get('toolMaxRecursions', self.default_max_recursions) + + async def _handle_streaming( + self, + input: dict, + messages: list[Any], + max_recursions: int + ) -> AsyncIterable[Any]: + """Handle streaming response processing with tool recursion.""" + continue_with_tools = True + final_response = None + + async def stream_generator(): + nonlocal continue_with_tools, final_response, max_recursions + + while continue_with_tools and max_recursions > 0: + response = self.handle_streaming_response(input) + + async for chunk in response: + if chunk.final_message: + final_response = chunk.final_message + yield chunk + + if any('toolUse' in content for content in final_response.content): + tool_response = await self._process_tool_block(final_response, messages) + input['messages'].append(tool_response) else: - response = await self.handle_single_response(input) + continue_with_tools = False - return ConversationMessage( - role=ParticipantRole.ASSISTANT.value, - content=[{'text':response.content[0].text}] - ) + max_recursions -= 1 - except Exception as error: - Logger.error(f"Error processing request: {error}") - raise error + return stream_generator() + + async def _process_with_strategy( + self, + streaming: bool, + input: dict, + messages: list[Any] + ) -> ConversationMessage | AsyncIterable[Any]: + """Process the request using the specified strategy.""" + + max_recursions = self._get_max_recursions() + + if streaming: + return await 
self._handle_streaming(input, messages, max_recursions) + return await self._handle_single_response_loop(input, messages, max_recursions) + + async def _process_tool_block(self, llm_response: Any, conversation: list[Any]) -> Any: + if 'useToolHandler' in self.tool_config: + # tool process logic is handled elsewhere + tool_response = await self.tool_config['useToolHandler'](llm_response, conversation) + else: + # tool process logic is handled in AgentTools class + if isinstance(self.tool_config['tool'], AgentTools): + tool_response = await self.tool_config['tool'].tool_handler(AgentProviderType.ANTHROPIC.value, llm_response, conversation) + else: + raise ValueError("You must use AgentTools class when not providing a custom tool handler") + return tool_response + + async def _handle_single_response_loop( + self, + input: Any, + messages: list[Any], + max_recursions: int + ) -> ConversationMessage: + """Handle single response processing with tool recursion.""" + + continue_with_tools = True + llm_response = None + + while continue_with_tools and max_recursions > 0: + llm_response = await self.handle_single_response(input) + if any(content.type == 'tool_use' for content in llm_response.content): + input['messages'].append({"role": "assistant", "content": llm_response.content}) + tool_response = await self._process_tool_block(llm_response, messages) + input['messages'].append(tool_response) + else: + continue_with_tools = False + + max_recursions -= 1 + + return ConversationMessage(role=ParticipantRole.ASSISTANT.value, content=[{"text": llm_response.content[0].text}]) + + async def process_request( + self, + input_text: str, + user_id: str, + session_id: str, + chat_history: list[ConversationMessage], + additional_params: Optional[dict[str, str]] = None + ) -> ConversationMessage | AsyncIterable[Any]: + + messages = self._prepare_conversation(input_text, chat_history) + system_prompt = await self._prepare_system_prompt(input_text) + input = self._build_input(messages, system_prompt) + + return await self._process_with_strategy(self.streaming, input, messages) - async def handle_single_response(self, input_data: Dict) -> Any: + async def handle_single_response(self, input_data: dict) -> Any: try: response = self.client.messages.create(**input_data) return response @@ -195,7 +271,7 @@ async def handle_single_response(self, input_data: Dict) -> Any: Logger.error(f"Error invoking Anthropic: {error}") raise error - async def handle_streaming_response(self, input) -> Any: + async def handle_streaming_response(self, input) -> AsyncGenerator[AgentStreamResponse, None]: message = {} content = [] accumulated = {} @@ -206,6 +282,7 @@ async def handle_streaming_response(self, input) -> Any: async for event in stream: if event.type == "text": self.callbacks.on_llm_new_token(event.text) + yield AgentStreamResponse(text=event.text) elif event.type == "input_json": message['input'] = json.loads(event.partial_json) elif event.type == "content_block_stop": @@ -216,7 +293,9 @@ async def handle_streaming_response(self, input) -> Any: # the context manager, as long as the entire stream was consumed # inside of the context manager accumulated = await stream.get_final_message() - return accumulated + yield AgentStreamResponse( + final_message=ConversationMessage(role=ParticipantRole.ASSISTANT.value, + content=[{"text": accumulated.content[0].text}])) except Exception as error: Logger.error(f"Error getting stream from Anthropic model: {str(error)}") diff --git a/python/src/multi_agent_orchestrator/agents/bedrock_llm_agent.py
b/python/src/multi_agent_orchestrator/agents/bedrock_llm_agent.py index 724501d4..06c55b04 100644 --- a/python/src/multi_agent_orchestrator/agents/bedrock_llm_agent.py +++ b/python/src/multi_agent_orchestrator/agents/bedrock_llm_agent.py @@ -1,16 +1,16 @@ -from typing import Any, AsyncIterable, Optional, Union +from typing import Any, Optional, AsyncGenerator, AsyncIterable from dataclasses import dataclass import re import json import os import boto3 -from multi_agent_orchestrator.agents import Agent, AgentOptions +from multi_agent_orchestrator.agents import Agent, AgentOptions, AgentStreamResponse from multi_agent_orchestrator.types import (ConversationMessage, ParticipantRole, BEDROCK_MODEL_ID_CLAUDE_3_HAIKU, TemplateVariables, AgentProviderType) -from multi_agent_orchestrator.utils import conversation_to_dict, Logger, AgentTools +from multi_agent_orchestrator.utils import conversation_to_dict, Logger, AgentTools, AgentTool from multi_agent_orchestrator.retrievers import Retriever @@ -22,7 +22,7 @@ class BedrockLLMAgentOptions(AgentOptions): inference_config: Optional[dict[str, Any]] = None guardrail_config: Optional[dict[str, str]] = None retriever: Optional[Retriever] = None - tool_config: Optional[Union[dict[str, Any], AgentTools]] = None + tool_config: dict[str, Any] | AgentTools | None = None custom_system_prompt: Optional[dict[str, Any]] = None client: Optional[Any] = None @@ -95,32 +95,39 @@ def __init__(self, options: BedrockLLMAgentOptions): def is_streaming_enabled(self) -> bool: return self.streaming is True - async def process_request( + async def _prepare_system_prompt(self, input_text: str) -> str: + """Prepare the system prompt with optional retrieval context.""" + + self.update_system_prompt() + system_prompt = self.system_prompt + + if self.retriever: + response = await self.retriever.retrieve_and_combine_results(input_text) + system_prompt += f"\nHere is the context to use to answer the user's question:\n{response}" + + return system_prompt + + def _prepare_conversation( self, input_text: str, - user_id: str, - session_id: str, - chat_history: list[ConversationMessage], - additional_params: Optional[dict[str, str]] = None - ) -> Union[ConversationMessage, AsyncIterable[Any]]: + chat_history: list[ConversationMessage] + ) -> list[ConversationMessage]: + """Prepare the conversation history with the new user message.""" user_message = ConversationMessage( role=ParticipantRole.USER.value, content=[{'text': input_text}] ) + return [*chat_history, user_message] - conversation = [*chat_history, user_message] - - self.update_system_prompt() + def _build_conversation_command( + self, + conversation: list[ConversationMessage], + system_prompt: str + ) -> dict: + """Build the conversation command with all necessary configurations.""" - system_prompt = self.system_prompt - - if self.retriever: - response = await self.retriever.retrieve_and_combine_results(input_text) - context_prompt = "\nHere is the context to use to answer the user's question:\n" + response - system_prompt += context_prompt - - converse_cmd = { + command = { 'modelId': self.model_id, 'messages': conversation_to_dict(conversation), 'system': [{'text': system_prompt}], @@ -133,54 +140,146 @@ async def process_request( } if self.guardrail_config: - converse_cmd["guardrailConfig"] = self.guardrail_config + command["guardrailConfig"] = self.guardrail_config if self.tool_config: - converse_cmd["toolConfig"] = { - 'tools': self.tool_config["tool"] if not isinstance(self.tool_config["tool"], AgentTools) else 
self.tool_config["tool"].to_bedrock_format() + command["toolConfig"] = self._prepare_tool_config() + + return command + + def _prepare_tool_config(self) -> dict: + """Prepare tool configuration based on the tool type.""" + + if isinstance(self.tool_config["tool"], AgentTools): + return {'tools': self.tool_config["tool"].to_bedrock_format()} + + if isinstance(self.tool_config["tool"], list): + return { + 'tools': [ + tool.to_bedrock_format() if isinstance(tool, AgentTool) else tool + for tool in self.tool_config['tool'] + ] } - if self.tool_config: - continue_with_tools = True - final_message: ConversationMessage = {'role': ParticipantRole.USER.value, 'content': []} - max_recursions = self.tool_config.get('toolMaxRecursions', self.default_max_recursions) + raise RuntimeError("Invalid tool config") + + def _get_max_recursions(self) -> int: + """Get the maximum number of recursions based on tool configuration.""" + if not self.tool_config: + return 1 + return self.tool_config.get('toolMaxRecursions', self.default_max_recursions) + + async def _handle_single_response_loop( + self, + command: dict, + conversation: list[ConversationMessage], + max_recursions: int + ) -> ConversationMessage: + """Handle single response processing with tool recursion.""" + + continue_with_tools = True + llm_response = None + + while continue_with_tools and max_recursions > 0: + llm_response = await self.handle_single_response(command) + conversation.append(llm_response) + + if any('toolUse' in content for content in llm_response.content): + tool_response = await self._process_tool_block(llm_response, conversation) + conversation.append(tool_response) + command['messages'] = conversation_to_dict(conversation) + else: + continue_with_tools = False + + max_recursions -= 1 + + return llm_response + + async def _handle_streaming( + self, + command: dict, + conversation: list[ConversationMessage], + max_recursions: int + ) -> AsyncIterable[Any]: + """Handle streaming response processing with tool recursion.""" + continue_with_tools = True + final_response = None + + async def stream_generator(): + nonlocal continue_with_tools, final_response, max_recursions while continue_with_tools and max_recursions > 0: - if self.streaming: - bedrock_response = await self.handle_streaming_response(converse_cmd) - else: - bedrock_response = await self.handle_single_response(converse_cmd) + response = self.handle_streaming_response(command) - conversation.append(bedrock_response) + async for chunk in response: + if chunk.final_message: + final_response = chunk.final_message + yield chunk + + conversation.append(final_response) + + if any('toolUse' in content for content in final_response.content): + tool_response = await self._process_tool_block(final_response, conversation) - if any('toolUse' in content for content in bedrock_response.content): - if 'useToolHandler' in self.tool_config: - # user is handling the tool blocks itself - tool_response = await self.tool_config['useToolHandler'](bedrock_response, conversation) - else: - tools:AgentTools = self.tool_config["tool"] - # no handler has been provided, we can use the default implementation - tool_response = await tools.tool_handler(AgentProviderType.BEDROCK.value, bedrock_response, conversation) conversation.append(tool_response) + command['messages'] = conversation_to_dict(conversation) else: continue_with_tools = False - final_message = bedrock_response max_recursions -= 1 - converse_cmd['messages'] = conversation_to_dict(conversation) - return final_message + return 
stream_generator() + + async def _process_with_strategy( + self, + streaming: bool, + command: dict, + conversation: list[ConversationMessage] + ) -> ConversationMessage | AsyncIterable[Any]: + """Process the request using the specified strategy.""" + + max_recursions = self._get_max_recursions() + + if streaming: + return await self._handle_streaming(command, conversation, max_recursions) + return await self._handle_single_response_loop(command, conversation, max_recursions) + + async def process_request( + self, + input_text: str, + user_id: str, + session_id: str, + chat_history: list[ConversationMessage], + additional_params: Optional[dict[str, str]] = None + ) -> ConversationMessage | AsyncIterable[Any]: + """ + Process a conversation request either in streaming or single response mode. + """ + conversation = self._prepare_conversation(input_text, chat_history) + system_prompt = await self._prepare_system_prompt(input_text) - if self.streaming: - return await self.handle_streaming_response(converse_cmd) + command = self._build_conversation_command(conversation, system_prompt) - return await self.handle_single_response(converse_cmd) + return await self._process_with_strategy(self.streaming, command, conversation) + + async def _process_tool_block(self, llm_response: ConversationMessage, conversation: list[ConversationMessage]) -> (ConversationMessage): + if 'useToolHandler' in self.tool_config: + # tool process logic is handled elsewhere + tool_response = await self.tool_config['useToolHandler'](llm_response, conversation) + else: + # tool process logic is handled in AgentTools class + if isinstance(self.tool_config['tool'], AgentTools): + tool_response = await self.tool_config['tool'].tool_handler(AgentProviderType.BEDROCK.value, llm_response, conversation) + else: + raise ValueError("You must use AgentTools class when not providing a custom tool handler") + return tool_response async def handle_single_response(self, converse_input: dict[str, Any]) -> ConversationMessage: try: response = self.client.converse(**converse_input) if 'output' not in response: raise ValueError("No output received from Bedrock model") + return ConversationMessage( role=response['output']['message']['role'], content=response['output']['message']['content'] @@ -189,7 +288,20 @@ async def handle_single_response(self, converse_input: dict[str, Any]) -> Conver Logger.error(f"Error invoking Bedrock model:{str(error)}") raise error - async def handle_streaming_response(self, converse_input: dict[str, Any]) -> ConversationMessage: + async def handle_streaming_response( + self, + converse_input: dict[str, Any] + ) -> AsyncGenerator[AgentStreamResponse, None]: + """ + Handle streaming response from Bedrock model. + Yields StreamChunk objects containing either text chunks or the final message. + + Args: + converse_input: Input for the conversation + + Yields: + StreamChunk: Contains either a text chunk or the final complete message + """ try: response = self.client.converse_stream(**converse_input) @@ -199,7 +311,6 @@ async def handle_streaming_response(self, converse_input: dict[str, Any]) -> Con text = '' tool_use = {} - #stream the response into a message. 
for chunk in response['stream']: if 'messageStart' in chunk: message['role'] = chunk['messageStart']['role'] @@ -216,6 +327,8 @@ async def handle_streaming_response(self, converse_input: dict[str, Any]) -> Con elif 'text' in delta: text += delta['text'] self.callbacks.on_llm_new_token(delta['text']) + # yield the text chunk + yield AgentStreamResponse(text=delta['text']) elif 'contentBlockStop' in chunk: if 'input' in tool_use: tool_use['input'] = json.loads(tool_use['input']) @@ -224,10 +337,13 @@ async def handle_streaming_response(self, converse_input: dict[str, Any]) -> Con else: content.append({'text': text}) text = '' - return ConversationMessage( + + final_message = ConversationMessage( role=ParticipantRole.ASSISTANT.value, content=message['content'] ) + # yield the final message + yield AgentStreamResponse(final_message=final_message) except Exception as error: Logger.error(f"Error getting stream from Bedrock model: {str(error)}") diff --git a/python/src/multi_agent_orchestrator/agents/chain_agent.py b/python/src/multi_agent_orchestrator/agents/chain_agent.py index a008aa3e..8139ad7b 100644 --- a/python/src/multi_agent_orchestrator/agents/chain_agent.py +++ b/python/src/multi_agent_orchestrator/agents/chain_agent.py @@ -1,10 +1,10 @@ -from typing import List, Dict, Union, AsyncIterable, Optional, Any +from typing import Union, AsyncIterable, Optional, Any from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole from multi_agent_orchestrator.utils.logger import Logger from .agent import Agent, AgentOptions class ChainAgentOptions(AgentOptions): - def __init__(self, agents: List[Agent], default_output: Optional[str] = None, **kwargs): + def __init__(self, agents: list[Agent], default_output: Optional[str] = None, **kwargs): super().__init__(**kwargs) self.agents = agents self.default_output = default_output @@ -22,8 +22,8 @@ async def process_request( input_text: str, user_id: str, session_id: str, - chat_history: List[ConversationMessage], - additional_params: Optional[Dict[str, str]] = None + chat_history: list[ConversationMessage], + additional_params: Optional[dict[str, str]] = None ) -> Union[ConversationMessage, AsyncIterable[Any]]: current_input = input_text final_response: Union[ConversationMessage, AsyncIterable[Any]] diff --git a/python/src/multi_agent_orchestrator/agents/comprehend_filter_agent.py b/python/src/multi_agent_orchestrator/agents/comprehend_filter_agent.py index 4d7f07f0..d21d4a6f 100644 --- a/python/src/multi_agent_orchestrator/agents/comprehend_filter_agent.py +++ b/python/src/multi_agent_orchestrator/agents/comprehend_filter_agent.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Union, Optional, Callable, Any +from typing import Optional, Callable, Any from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole from multi_agent_orchestrator.utils.logger import Logger from .agent import Agent, AgentOptions @@ -34,7 +34,7 @@ def __init__(self, options: ComprehendFilterAgentOptions): config = Config(region_name=options.region) if options.region else None self.comprehend_client = boto3.client('comprehend', config=config) - self.custom_checks: List[CheckFunction] = [] + self.custom_checks: list[CheckFunction] = [] self.enable_sentiment_check = options.enable_sentiment_check self.enable_pii_check = options.enable_pii_check @@ -52,10 +52,10 @@ async def process_request(self, input_text: str, user_id: str, session_id: str, - chat_history: List[ConversationMessage], - additional_params: Optional[Dict[str, str]] = None) -> 
Optional[ConversationMessage]: + chat_history: list[ConversationMessage], + additional_params: Optional[dict[str, str]] = None) -> Optional[ConversationMessage]: try: - issues: List[str] = [] + issues: list[str] = [] # Run all checks sentiment_result = self.detect_sentiment(input_text) if self.enable_sentiment_check else None @@ -101,43 +101,43 @@ async def process_request(self, def add_custom_check(self, check: CheckFunction): self.custom_checks.append(check) - def check_sentiment(self, result: Dict[str, Any]) -> Optional[str]: + def check_sentiment(self, result: dict[str, Any]) -> Optional[str]: if result['Sentiment'] == 'NEGATIVE' and result['SentimentScore']['Negative'] > self.sentiment_threshold: return f"Negative sentiment detected ({result['SentimentScore']['Negative']:.2f})" return None - def check_pii(self, result: Dict[str, Any]) -> Optional[str]: + def check_pii(self, result: dict[str, Any]) -> Optional[str]: if not self.allow_pii and result.get('Entities'): return f"PII detected: {', '.join(e['Type'] for e in result['Entities'])}" return None - def check_toxicity(self, result: Dict[str, Any]) -> Optional[str]: + def check_toxicity(self, result: dict[str, Any]) -> Optional[str]: toxic_labels = self.get_toxic_labels(result) if toxic_labels: return f"Toxic content detected: {', '.join(toxic_labels)}" return None - def detect_sentiment(self, text: str) -> Dict[str, Any]: + def detect_sentiment(self, text: str) -> dict[str, Any]: return self.comprehend_client.detect_sentiment( Text=text, LanguageCode=self.language_code ) - def detect_pii_entities(self, text: str) -> Dict[str, Any]: + def detect_pii_entities(self, text: str) -> dict[str, Any]: return self.comprehend_client.detect_pii_entities( Text=text, LanguageCode=self.language_code ) - def detect_toxic_content(self, text: str) -> Dict[str, Any]: + def detect_toxic_content(self, text: str) -> dict[str, Any]: return self.comprehend_client.detect_toxic_content( TextSegments=[{"Text": text}], LanguageCode=self.language_code ) - def get_toxic_labels(self, toxicity_result: Dict[str, Any]) -> List[str]: + def get_toxic_labels(self, toxicity_result: dict[str, Any]) -> list[str]: toxic_labels = [] - for result in toxicity_result.get('ResultList', []): + for result in toxicity_result.get('Resultlist', []): for label in result.get('Labels', []): if label['Score'] > self.toxicity_threshold: toxic_labels.append(label['Name']) diff --git a/python/src/multi_agent_orchestrator/agents/openai_agent.py b/python/src/multi_agent_orchestrator/agents/openai_agent.py index 8b205724..6a3c54a3 100644 --- a/python/src/multi_agent_orchestrator/agents/openai_agent.py +++ b/python/src/multi_agent_orchestrator/agents/openai_agent.py @@ -1,7 +1,11 @@ -from typing import Dict, List, Union, AsyncIterable, Optional, Any +from typing import AsyncIterable, Optional, Any, AsyncGenerator from dataclasses import dataclass from openai import OpenAI -from multi_agent_orchestrator.agents import Agent, AgentOptions +from multi_agent_orchestrator.agents import ( + Agent, + AgentOptions, + AgentStreamResponse +) from multi_agent_orchestrator.types import ( ConversationMessage, ParticipantRole, @@ -18,8 +22,8 @@ class OpenAIAgentOptions(AgentOptions): api_key: str = None model: Optional[str] = None streaming: Optional[bool] = None - inference_config: Optional[Dict[str, Any]] = None - custom_system_prompt: Optional[Dict[str, Any]] = None + inference_config: Optional[dict[str, Any]] = None + custom_system_prompt: Optional[dict[str, Any]] = None retriever: Optional[Retriever] = 
None client: Optional[Any] = None @@ -30,13 +34,13 @@ def __init__(self, options: OpenAIAgentOptions): super().__init__(options) if not options.api_key: raise ValueError("OpenAI API key is required") - + if options.client: self.client = options.client else: self.client = OpenAI(api_key=options.api_key) - + self.model = options.model or OPENAI_MODEL_ID_GPT_O_MINI self.streaming = options.streaming or False self.retriever: Optional[Retriever] = options.retriever @@ -83,7 +87,7 @@ def __init__(self, options: OpenAIAgentOptions): options.custom_system_prompt.get('template'), options.custom_system_prompt.get('variables') ) - + def is_streaming_enabled(self) -> bool: @@ -94,9 +98,9 @@ async def process_request( input_text: str, user_id: str, session_id: str, - chat_history: List[ConversationMessage], - additional_params: Optional[Dict[str, str]] = None - ) -> Union[ConversationMessage, AsyncIterable[Any]]: + chat_history: list[ConversationMessage], + additional_params: Optional[dict[str, str]] = None + ) -> ConversationMessage | AsyncIterable[Any]: try: self.update_system_prompt() @@ -129,7 +133,7 @@ async def process_request( "stream": self.streaming } if self.streaming: - return await self.handle_streaming_response(request_options) + return self.handle_streaming_response(request_options) else: return await self.handle_single_response(request_options) @@ -137,7 +141,7 @@ async def process_request( Logger.error(f"Error in OpenAI API call: {str(error)}") raise error - async def handle_single_response(self, request_options: Dict[str, Any]) -> ConversationMessage: + async def handle_single_response(self, request_options: dict[str, Any]) -> ConversationMessage: try: request_options['stream'] = False chat_completion = self.client.chat.completions.create(**request_options) @@ -159,30 +163,29 @@ async def handle_single_response(self, request_options: Dict[str, Any]) -> Conve Logger.error(f'Error in OpenAI API call: {str(error)}') raise error - async def handle_streaming_response(self, request_options: Dict[str, Any]) -> ConversationMessage: + async def handle_streaming_response(self, request_options: dict[str, Any]) -> AsyncGenerator[AgentStreamResponse, None]: try: stream = self.client.chat.completions.create(**request_options) accumulated_message = [] - + for chunk in stream: if chunk.choices[0].delta.content: chunk_content = chunk.choices[0].delta.content accumulated_message.append(chunk_content) - if self.callbacks: - self.callbacks.on_llm_new_token(chunk_content) - #yield chunk_content + self.callbacks.on_llm_new_token(chunk_content) + yield AgentStreamResponse(text=chunk_content) # Store the complete message in the instance for later access if needed - return ConversationMessage( + yield AgentStreamResponse(final_message=ConversationMessage( role=ParticipantRole.ASSISTANT.value, content=[{"text": ''.join(accumulated_message)}] - ) + )) except Exception as error: Logger.error(f"Error getting stream from OpenAI model: {str(error)}") raise error - def set_system_prompt(self, + def set_system_prompt(self, template: Optional[str] = None, variables: Optional[TemplateVariables] = None) -> None: if template: diff --git a/python/src/multi_agent_orchestrator/orchestrator.py b/python/src/multi_agent_orchestrator/orchestrator.py index 49e577c4..0aec2494 100644 --- a/python/src/multi_agent_orchestrator/orchestrator.py +++ b/python/src/multi_agent_orchestrator/orchestrator.py @@ -1,12 +1,16 @@ -from typing import Dict, Any, AsyncIterable, Optional, Union +from typing import Any, AsyncIterable from dataclasses 
import dataclass, fields, asdict, replace import time from multi_agent_orchestrator.utils.logger import Logger -from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole, OrchestratorConfig +from multi_agent_orchestrator.types import (ConversationMessage, + ParticipantRole, + OrchestratorConfig, + TimestampedMessage) from multi_agent_orchestrator.classifiers import Classifier,ClassifierResult from multi_agent_orchestrator.agents import (Agent, - AgentResponse, - AgentProcessingResult) + AgentStreamResponse, + AgentResponse, + AgentProcessingResult) from multi_agent_orchestrator.storage import ChatStorage from multi_agent_orchestrator.storage import InMemoryChatStorage try: @@ -18,11 +22,11 @@ @dataclass class MultiAgentOrchestrator: def __init__(self, - options: Optional[OrchestratorConfig] = None, - storage: Optional[ChatStorage] = None, - classifier: Optional[Classifier] = None, - logger: Optional[Logger] = None, - default_agent: Optional[Agent] = None): + options: OrchestratorConfig | None = None, + storage: ChatStorage | None = None, + classifier: Classifier | None = None, + logger: Logger | None = None, + default_agent: Agent | None = None): DEFAULT_CONFIG=OrchestratorConfig() @@ -42,7 +46,7 @@ def __init__(self, self.logger = Logger(self.config, logger) - self.agents: Dict[str, Agent] = {} + self.agents: dict[str, Agent] = {} self.storage = storage or InMemoryChatStorage() if classifier: @@ -51,8 +55,8 @@ def __init__(self, self.classifier = BedrockClassifier(options=BedrockClassifierOptions()) else: raise ValueError("No classifier provided and BedrockClassifier is not available. Please provide a classifier.") - - self.execution_times: Dict[str, float] = {} + + self.execution_times: dict[str, float] = {} self.default_agent: Agent = default_agent @@ -70,17 +74,16 @@ def set_default_agent(self, agent: Agent): def set_classifier(self, intent_classifier: Classifier): self.classifier = intent_classifier + self.classifier.set_agents(self.agents) - def get_all_agents(self) -> Dict[str, Dict[str, str]]: + def get_all_agents(self) -> dict[str, dict[str, str]]: return {key: { "name": agent.name, "description": agent.description } for key, agent in self.agents.items()} - async def dispatch_to_agent(self, - params: Dict[str, Any]) -> Union[ - ConversationMessage, AsyncIterable[Any] - ]: + async def dispatch_to_agent(self, params: dict[str, Any] + ) -> ConversationMessage | AsyncIterable[Any]: user_input = params['user_input'] user_id = params['user_id'] session_id = params['session_id'] @@ -132,13 +135,15 @@ async def classify_request(self, except Exception as error: self.logger.error(f"Error during intent classification: {str(error)}") raise error - + async def agent_process_request(self, user_input: str, user_id: str, session_id: str, classifier_result: ClassifierResult, - additional_params: Dict[str, str] = {}) -> AgentResponse: + additional_params: dict[str, str] = {}, + stream_response: bool | None = False # whether to stream back the response from the agent + ) -> AgentResponse: """Process agent response and handle chat storage.""" try: agent_response = await self.dispatch_to_agent({ @@ -165,33 +170,79 @@ async def agent_process_request(self, classifier_result.selected_agent ) - if isinstance(agent_response, ConversationMessage): - await self.save_message(agent_response, - user_id, - session_id, - classifier_result.selected_agent) + final_response = None + if classifier_result.selected_agent.is_streaming_enabled(): + if stream_response: + if isinstance(agent_response,
AsyncIterable): + # Create an async generator function to handle the streaming + async def process_stream(): + full_message = None + async for chunk in agent_response: + if isinstance(chunk, AgentStreamResponse): + if chunk.final_message: + full_message = chunk.final_message + yield chunk + else: + Logger.error("Invalid response type from agent. Expected AgentStreamResponse") + pass + + if full_message: + await self.save_message(full_message, + user_id, + session_id, + classifier_result.selected_agent) + + + final_response = process_stream() + else: + async def process_stream() -> ConversationMessage: + full_message = None + async for chunk in agent_response: + if isinstance(chunk, AgentStreamResponse): + if chunk.final_message: + full_message = chunk.final_message + else: + Logger.error("Invalid response type from agent. Expected AgentStreamResponse") + pass + + if full_message: + await self.save_message(full_message, + user_id, + session_id, + classifier_result.selected_agent) + return full_message + final_response = await process_stream() + + + else: # Non-streaming response + final_response = agent_response + await self.save_message(final_response, + user_id, + session_id, + classifier_result.selected_agent) return AgentResponse( metadata=metadata, - output=agent_response, + output=final_response, streaming=classifier_result.selected_agent.is_streaming_enabled() ) except Exception as error: self.logger.error(f"Error during agent processing: {str(error)}") raise error - + async def route_request(self, user_input: str, user_id: str, - session_id: str, - additional_params: Dict[str, str] = {}) -> AgentResponse: + session_id: str, + additional_params: dict[str, str] = {}, + stream_response: bool | None = False) -> AgentResponse: """Route user request to appropriate agent.""" self.execution_times.clear() try: classifier_result = await self.classify_request(user_input, user_id, session_id) - + if not classifier_result.selected_agent: return AgentResponse( metadata=self.create_metadata(classifier_result, user_input, user_id, session_id, additional_params), @@ -203,11 +254,12 @@ async def route_request(self, ) return await self.agent_process_request( - user_input, + user_input, user_id, - session_id, + session_id, classifier_result, - additional_params + additional_params, + stream_response ) except Exception as error: @@ -252,11 +304,11 @@ async def measure_execution_time(self, timer_name: str, fn): raise error def create_metadata(self, - intent_classifier_result: Optional[ClassifierResult], + intent_classifier_result: ClassifierResult | None, user_input: str, user_id: str, session_id: str, - additional_params: Dict[str, str]) -> AgentProcessingResult: + additional_params: dict[str, str]) -> AgentProcessingResult: base_metadata = AgentProcessingResult( user_input=user_input, agent_id="no_agent_selected", @@ -287,3 +339,15 @@ async def save_message(self, agent.id, message, self.config.MAX_MESSAGE_PAIRS_PER_AGENT) + async def save_messages(self, + messages: list[ConversationMessage] | list[TimestampedMessage], + user_id: str, session_id: str, + agent: Agent): + if agent and agent.save_chat: + for message in messages: + # TODO: change this to self.storage.save_chat_messages() when SupervisorAgent is merged + await self.storage.save_chat_message(user_id, + session_id, + agent.id, + message, + self.config.MAX_MESSAGE_PAIRS_PER_AGENT) diff --git a/python/src/multi_agent_orchestrator/storage/dynamodb_chat_storage.py b/python/src/multi_agent_orchestrator/storage/dynamodb_chat_storage.py index 
52ee3d5b..7d34a28b 100644 --- a/python/src/multi_agent_orchestrator/storage/dynamodb_chat_storage.py +++ b/python/src/multi_agent_orchestrator/storage/dynamodb_chat_storage.py @@ -100,7 +100,7 @@ async def save_chat_messages(self, max_history_size ) - item: dict[str, Union[str, list[TimestampedMessage], int]] = { + item: dict[str, str | list[TimestampedMessage] | int] = { 'PK': user_id, 'SK': key, 'conversation': conversation_to_dict(trimmed_conversation), diff --git a/python/src/multi_agent_orchestrator/storage/sql_chat_storage.py b/python/src/multi_agent_orchestrator/storage/sql_chat_storage.py index 109426cb..1e4e1853 100644 --- a/python/src/multi_agent_orchestrator/storage/sql_chat_storage.py +++ b/python/src/multi_agent_orchestrator/storage/sql_chat_storage.py @@ -1,4 +1,3 @@ -from typing import List, Dict, Optional, Union import time import json from libsql_client import Client, create_client @@ -8,14 +7,14 @@ class SqlChatStorage(ChatStorage): """SQL-based chat storage implementation supporting both local SQLite and remote Turso databases.""" - + def __init__( self, url: str, - auth_token: Optional[str] = None + auth_token: str | None = None ): """Initialize SQL storage. - + Args: url: Database URL (e.g., 'file:local.db' or 'libsql://your-db-url.com') auth_token: Authentication token for remote databases (optional) @@ -46,7 +45,7 @@ def _initialize_database(self) -> None: # Create index for faster queries self.client.execute(""" - CREATE INDEX IF NOT EXISTS idx_conversations_lookup + CREATE INDEX IF NOT EXISTS idx_conversations_lookup ON conversations(user_id, session_id, agent_id) """) except Exception as error: @@ -59,8 +58,8 @@ async def save_chat_message( session_id: str, agent_id: str, new_message: ConversationMessage, - max_history_size: Optional[int] = None - ) -> List[ConversationMessage]: + max_history_size: int | None = None + ) -> list[ConversationMessage]: """Save a new chat message.""" try: # Fetch existing conversation @@ -76,7 +75,7 @@ async def save_chat_message( FROM conversations WHERE user_id = ? AND session_id = ? AND agent_id = ? 
""", [user_id, session_id, agent_id]) - + next_index = result.rows[0]['next_index'] timestamp = int(time.time() * 1000) content = json.dumps(new_message.content) @@ -125,8 +124,8 @@ async def fetch_chat( user_id: str, session_id: str, agent_id: str, - max_history_size: Optional[int] = None - ) -> List[ConversationMessage]: + max_history_size: int | None = None + ) -> list[ConversationMessage]: """Fetch chat messages.""" try: query = """ @@ -161,7 +160,7 @@ async def fetch_all_chats( self, user_id: str, session_id: str - ) -> List[ConversationMessage]: + ) -> list[ConversationMessage]: """Fetch all chat messages for a user and session.""" try: result = self.client.execute(""" @@ -188,9 +187,9 @@ async def fetch_all_chats( def _format_content( self, role: str, - content: Union[List, str], + content: list | str, agent_id: str - ) -> List[Dict[str, str]]: + ) -> list[dict[str, str]]: """Format message content with agent ID for assistant messages.""" if role == ParticipantRole.ASSISTANT.value: text = content[0]['text'] if isinstance(content, list) else content diff --git a/python/src/multi_agent_orchestrator/types/types.py b/python/src/multi_agent_orchestrator/types/types.py index d7ca9fda..2b4bfc69 100644 --- a/python/src/multi_agent_orchestrator/types/types.py +++ b/python/src/multi_agent_orchestrator/types/types.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import List, Dict, Union, TypedDict, Optional, Any +from typing import TypedDict, Optional, Any from dataclasses import dataclass import time @@ -31,7 +31,7 @@ class RequestMetadata(TypedDict): agent_name: str user_id: str session_id: str - additional_params :Optional[Dict[str, str]] + additional_params :Optional[dict[str, str]] error_type: Optional[str] @@ -42,21 +42,21 @@ class ParticipantRole(Enum): class ConversationMessage: role: ParticipantRole - content: List[Any] + content: list[Any] - def __init__(self, role: ParticipantRole, content: Optional[List[Any]] = None): + def __init__(self, role: ParticipantRole, content: Optional[list[Any]] = None): self.role = role self.content = content class TimestampedMessage(ConversationMessage): def __init__(self, role: ParticipantRole, - content: Optional[List[Any]] = None, - timestamp: Optional[int] = None): + content: Optional[list[Any]] = None, + timestamp: int = 0): super().__init__(role, content) # Call the parent constructor self.timestamp = timestamp or int(time.time() * 1000) # Initialize the timestamp attribute (in ms) -TemplateVariables = Dict[str, Union[str, List[str]]] +TemplateVariables = dict[str, str | list[str]] @dataclass class OrchestratorConfig: diff --git a/python/src/multi_agent_orchestrator/utils/helpers.py b/python/src/multi_agent_orchestrator/utils/helpers.py index 80a31c5b..178af5cd 100644 --- a/python/src/multi_agent_orchestrator/utils/helpers.py +++ b/python/src/multi_agent_orchestrator/utils/helpers.py @@ -1,7 +1,7 @@ """ Helpers method """ -from typing import Any, List, Dict, Union +from typing import Any from multi_agent_orchestrator.types import ConversationMessage, TimestampedMessage def is_tool_input(input_obj: Any) -> bool: @@ -13,18 +13,17 @@ def is_tool_input(input_obj: Any) -> bool: ) def conversation_to_dict( - conversation: Union[ - ConversationMessage, - TimestampedMessage, - List[Union[ConversationMessage, TimestampedMessage]] - ] -) -> Union[Dict[str, Any], List[Dict[str, Any]]]: + conversation: + ConversationMessage | + TimestampedMessage | + list[ConversationMessage | TimestampedMessage] +) -> dict[str, Any] | list[dict[str, Any]]: """Convert 
conversation to dictionary format.""" if isinstance(conversation, list): return [message_to_dict(msg) for msg in conversation] return message_to_dict(conversation) -def message_to_dict(message: Union[ConversationMessage, TimestampedMessage]) -> Dict[str, Any]: +def message_to_dict(message: ConversationMessage | TimestampedMessage) -> dict[str, Any]: """Convert a single message to dictionary format.""" result = { "role": message.role.value if hasattr(message.role, 'value') else str(message.role), diff --git a/python/src/multi_agent_orchestrator/utils/tool.py b/python/src/multi_agent_orchestrator/utils/tool.py index b617ef83..4ec75b52 100644 --- a/python/src/multi_agent_orchestrator/utils/tool.py +++ b/python/src/multi_agent_orchestrator/utils/tool.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Callable, get_type_hints, Union +from typing import Any, Optional, Callable, get_type_hints import inspect from functools import wraps import re @@ -230,7 +230,7 @@ async def tool_handler(self, provider_type, response: Any, _conversation: list[d 'content': tool_results } - def _get_tool_use_block(self, provider_type:AgentProviderType, block: dict) -> Union[dict, None]: + def _get_tool_use_block(self, provider_type:AgentProviderType, block: dict) -> dict | None: """Extract tool use block based on platform format.""" if provider_type == AgentProviderType.BEDROCK.value and "toolUse" in block: return block["toolUse"] diff --git a/python/src/tests/agents/test_agent.py b/python/src/tests/agents/test_agent.py index ef66cf4f..f050300b 100644 --- a/python/src/tests/agents/test_agent.py +++ b/python/src/tests/agents/test_agent.py @@ -111,3 +111,7 @@ async def test_process_request(self, mock_agent): assert isinstance(result, ConversationMessage) assert result.role == "assistant" assert result.content == "Mock response" + + + def test_streaming(self, mock_agent): + assert mock_agent.is_streaming_enabled() is False diff --git a/python/src/tests/agents/test_amazon_bedrock_agent.py b/python/src/tests/agents/test_amazon_bedrock_agent.py index e51c074d..43bec690 100644 --- a/python/src/tests/agents/test_amazon_bedrock_agent.py +++ b/python/src/tests/agents/test_amazon_bedrock_agent.py @@ -120,3 +120,32 @@ async def test_process_request_with_additional_params(bedrock_agent): assert isinstance(result, ConversationMessage) assert result.role == ParticipantRole.ASSISTANT.value assert result.content == [{"text": "Response with additional params"}] + + +def test_streaming(mock_boto3_client): + options = AmazonBedrockAgentOptions( + name="TestAgent", + description="A test agent", + streaming=True + ) + + agent = AmazonBedrockAgent(options) + assert(agent.is_streaming_enabled() == True) + + options = AmazonBedrockAgentOptions( + name="TestAgent", + description="A test agent", + streaming=False + ) + + agent = AmazonBedrockAgent(options) + assert(agent.is_streaming_enabled() == False) + + options = AmazonBedrockAgentOptions( + name="TestAgent", + description="A test agent", + ) + + agent = AmazonBedrockAgent(options) + assert(agent.is_streaming_enabled() == False) + diff --git a/python/src/tests/agents/test_anthropic_agent.py b/python/src/tests/agents/test_anthropic_agent.py index 219b4d25..1fbfbabb 100644 --- a/python/src/tests/agents/test_anthropic_agent.py +++ b/python/src/tests/agents/test_anthropic_agent.py @@ -3,7 +3,7 @@ from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole from multi_agent_orchestrator.agents import AnthropicAgent, AnthropicAgentOptions from 
multi_agent_orchestrator.utils import Logger - +from anthropic import Anthropic, AsyncAnthropic logger = Logger() @pytest.fixture @@ -23,6 +23,92 @@ def test_no_api_key_init(mock_anthropic): except Exception as e: assert(str(e) == "Anthropic API key or Anthropic client is required") +def test_client(mock_anthropic): + try: + options = AnthropicAgentOptions( + name="TestAgent", + description="A test agent", + client=Anthropic(), + streaming=True + + ) + + _anthropic_llm_agent = AnthropicAgent(options) + assert(_anthropic_llm_agent.api_key is not None) + except Exception as e: + assert(str(e) == "If streaming is enabled, the provided client must be an AsyncAnthropic client") + + try: + options = AnthropicAgentOptions( + name="TestAgent", + description="A test agent", + client=AsyncAnthropic(), + streaming=False + + ) + + _anthropic_llm_agent = AnthropicAgent(options) + assert(_anthropic_llm_agent.api_key is not None) + except Exception as e: + assert(str(e) == "If streaming is disabled, the provided client must be an Anthropic client") + + options = AnthropicAgentOptions( + name="TestAgent", + description="A test agent", + client=AsyncAnthropic(), + streaming=True + + ) + _anthropic_llm_agent = AnthropicAgent(options) + + +def test_inference_config(mock_anthropic): + + options = AnthropicAgentOptions( + name="TestAgent", + description="A test agent", + client=Anthropic(), + streaming=False, + inference_config={ + 'temperature': 0.5, + 'topP': 0.5, + 'topK': 0.5, + 'maxTokens': 1000, + } + ) + + _anthropic_llm_agent = AnthropicAgent(options) + assert _anthropic_llm_agent.inference_config == { + 'temperature': 0.5, + 'topP': 0.5, + 'topK': 0.5, + 'maxTokens': 1000, + 'stopSequences': [] + } + + options = AnthropicAgentOptions( + name="TestAgent", + description="A test agent", + client=Anthropic(), + streaming=False, + inference_config={ + 'temperature': 0.5, + 'topK': 0.5, + 'maxTokens': 1000, + } + ) + + _anthropic_llm_agent = AnthropicAgent(options) + assert _anthropic_llm_agent.inference_config == { + 'temperature': 0.5, + 'topP': 0.9, + 'topK': 0.5, + 'maxTokens': 1000, + 'stopSequences': [] + } + + + def test_custom_system_prompt_with_variable(mock_anthropic): options = AnthropicAgentOptions( api_key='test-api-key', @@ -74,10 +160,10 @@ async def test_process_request_single_response(): # Verify the mock was called mock_instance.messages.create.assert_called_once_with( - model="claude-3-sonnet-20240229", + model='claude-3-sonnet-20240229', max_tokens=1000, - messages=[{"role": "user", "content": "Test prompt"}], - system=anthropic_llm_agent.system_prompt, + messages=[{'role': 'user', 'content': 'Test prompt'}], + system="You are a TestAgent.\n A test agent\n Provide helpful and accurate information based on your expertise.\n You will engage in an open-ended conversation,\n providing helpful and accurate information based on your expertise.\n The conversation will proceed as follows:\n - The human may ask an initial question or provide a prompt on any topic.\n - You will provide a relevant and informative response.\n - The human may then follow up with additional questions or prompts related to your previous\n response, allowing for a multi-turn dialogue on that topic.\n - Or, the human may switch to a completely new and unrelated topic at any point.\n - You will seamlessly shift your focus to the new topic, providing thoughtful and\n coherent responses based on your broad knowledge base.\n Throughout the conversation, you should aim to:\n - Understand the context and intent behind each new 
question or prompt.\n - Provide substantive and well-reasoned responses that directly address the query.\n - Draw insights and connections from your extensive knowledge when appropriate.\n - Ask for clarification if any part of the question or prompt is ambiguous.\n - Maintain a consistent, respectful, and engaging tone tailored\n to the human's communication style.\n - Seamlessly transition between topics as the human introduces new subjects.", temperature=0.1, top_p=0.9, stop_sequences=[] @@ -85,3 +171,46 @@ async def test_process_request_single_response(): assert isinstance(response, ConversationMessage) assert response.content[0].get('text') == "Test response" assert response.role == ParticipantRole.ASSISTANT.value + + +def test_streaming(mock_anthropic): + options = AnthropicAgentOptions( + api_key='test-api-key', + name="TestAgent", + description="A test agent", + custom_system_prompt={ + 'template': """This is my new prompt with this {{variable}}""", + 'variables': {'variable': 'value'} + }, + streaming=True + ) + + _anthropic_llm_agent = AnthropicAgent(options) + assert(_anthropic_llm_agent.is_streaming_enabled() == True) + + options = AnthropicAgentOptions( + api_key='test-api-key', + name="TestAgent", + description="A test agent", + custom_system_prompt={ + 'template': """This is my new prompt with this {{variable}}""", + 'variables': {'variable': 'value'} + }, + streaming=False + ) + + _anthropic_llm_agent = AnthropicAgent(options) + assert(_anthropic_llm_agent.is_streaming_enabled() == False) + + options = AnthropicAgentOptions( + api_key='test-api-key', + name="TestAgent", + description="A test agent", + custom_system_prompt={ + 'template': """This is my new prompt with this {{variable}}""", + 'variables': {'variable': 'value'} + } + ) + + _anthropic_llm_agent = AnthropicAgent(options) + assert(_anthropic_llm_agent.is_streaming_enabled() == False) diff --git a/python/src/tests/agents/test_bedrock_llm_agent.py b/python/src/tests/agents/test_bedrock_llm_agent.py index c92ea9e2..5ad2301e 100644 --- a/python/src/tests/agents/test_bedrock_llm_agent.py +++ b/python/src/tests/agents/test_bedrock_llm_agent.py @@ -1,8 +1,12 @@ import pytest -from unittest.mock import AsyncMock, patch +from unittest.mock import Mock, AsyncMock, patch +from typing import AsyncIterable from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole -from multi_agent_orchestrator.agents import BedrockLLMAgent, BedrockLLMAgentOptions -from multi_agent_orchestrator.utils import Logger +from multi_agent_orchestrator.agents import ( + BedrockLLMAgent, + BedrockLLMAgentOptions, + AgentStreamResponse) +from multi_agent_orchestrator.utils import Logger, AgentTools, AgentTool logger = Logger() @@ -121,14 +125,23 @@ async def test_process_request_streaming(bedrock_llm_agent, mock_boto3_client): result = await bedrock_llm_agent.process_request(input_text, user_id, session_id, chat_history) - assert isinstance(result, ConversationMessage) - assert result.role == ParticipantRole.ASSISTANT.value - assert result.content[0]['text'] == 'This is a test response' + assert isinstance(result, AsyncIterable) + + async for chunk in result: + assert isinstance(chunk, AgentStreamResponse) + if chunk.final_message: + assert chunk.final_message.role == ParticipantRole.ASSISTANT.value + assert chunk.final_message.content[0]['text'] == 'This is a test response' + @pytest.mark.asyncio async def test_process_request_with_tool_use(bedrock_llm_agent, mock_boto3_client): + async def _handler(message, conversation): + 
return ConversationMessage(role=ParticipantRole.ASSISTANT, content=[{'text': 'Tool response'}]) bedrock_llm_agent.tool_config = { - "tool": {"name": "test_tool"}, + "tool": [ + AgentTool(name='test_tool', func=_handler, description='This is a test handler') + ], "toolMaxRecursions": 2, "useToolHandler": AsyncMock() } @@ -175,4 +188,42 @@ def test_set_system_prompt(bedrock_llm_agent): assert bedrock_llm_agent.custom_variables == variables assert bedrock_llm_agent.system_prompt == "You are a test agent. Your task is to run tests." -# Add more tests as needed for other methods and edge cases \ No newline at end of file +def test_streaming(mock_boto3_client): + options = BedrockLLMAgentOptions( + name="TestAgent", + description="A test agent", + custom_system_prompt={ + 'template': """This is my new prompt with this {{variable}}""", + 'variables': {'variable': 'value'} + }, + streaming=True + ) + + agent = BedrockLLMAgent(options) + assert(agent.is_streaming_enabled() == True) + + options = BedrockLLMAgentOptions( + name="TestAgent", + description="A test agent", + custom_system_prompt={ + 'template': """This is my new prompt with this {{variable}}""", + 'variables': {'variable': 'value'} + }, + streaming=False + ) + + agent = BedrockLLMAgent(options) + assert(agent.is_streaming_enabled() == False) + + options = BedrockLLMAgentOptions( + name="TestAgent", + description="A test agent", + custom_system_prompt={ + 'template': """This is my new prompt with this {{variable}}""", + 'variables': {'variable': 'value'} + } + ) + + agent = BedrockLLMAgent(options) + assert(agent.is_streaming_enabled() == False) + diff --git a/python/src/tests/agents/test_openai_agent.py b/python/src/tests/agents/test_openai_agent.py index 077a05ee..af36aa84 100644 --- a/python/src/tests/agents/test_openai_agent.py +++ b/python/src/tests/agents/test_openai_agent.py @@ -1,7 +1,8 @@ import pytest from unittest.mock import Mock, AsyncMock, patch +from typing import AsyncIterable from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole -from multi_agent_orchestrator.agents import OpenAIAgent, OpenAIAgentOptions +from multi_agent_orchestrator.agents import OpenAIAgent, OpenAIAgentOptions, AgentStreamResponse @pytest.fixture def mock_openai_client(): @@ -89,21 +90,25 @@ def __init__(self, content): ] mock_openai_client.chat.completions.create.return_value = mock_stream - result = await openai_agent.process_request( + result:AgentStreamResponse = await openai_agent.process_request( "Test question", "test_user", "test_session", [] ) - # chunks = [] - # async for chunk in result: - # chunks.append(chunk) - # assert chunks == ["This ", "is ", "a ", "test response"] - assert isinstance(result, ConversationMessage) - assert result.role == ParticipantRole.ASSISTANT.value - assert result.content[0]['text'] == 'This is a test response' + assert isinstance(result, AsyncIterable) + chunks = [] + async for chunk in result: + assert isinstance(chunk, AgentStreamResponse) + if chunk.text: + chunks.append(chunk.text) + elif chunk.final_message: + assert chunk.final_message.role == ParticipantRole.ASSISTANT.value + assert chunk.final_message.content[0]['text'] == 'This is a test response' + assert chunks == ["This ", "is ", "a ", "test response"] + @pytest.mark.asyncio @@ -164,4 +169,5 @@ async def test_handle_single_response_no_choices(openai_agent, mock_openai_clien def test_is_streaming_enabled(openai_agent): assert not openai_agent.is_streaming_enabled() openai_agent.streaming = True - assert 
openai_agent.is_streaming_enabled() \ No newline at end of file + assert openai_agent.is_streaming_enabled() + diff --git a/python/src/tests/test_orchestrator.py b/python/src/tests/test_orchestrator.py new file mode 100644 index 00000000..27bc5ca9 --- /dev/null +++ b/python/src/tests/test_orchestrator.py @@ -0,0 +1,350 @@ +import pytest +from unittest.mock import Mock, AsyncMock, patch +from typing import AsyncIterable +import pytest_asyncio +from dataclasses import dataclass + +from multi_agent_orchestrator.types import ( + ConversationMessage, + ParticipantRole, + OrchestratorConfig, + TimestampedMessage +) +from multi_agent_orchestrator.classifiers import Classifier, ClassifierResult +from multi_agent_orchestrator.agents import ( + Agent, + AgentStreamResponse, + AgentResponse, + AgentProcessingResult +) +from multi_agent_orchestrator.storage import ChatStorage, InMemoryChatStorage +from multi_agent_orchestrator.utils.logger import Logger +from multi_agent_orchestrator.orchestrator import MultiAgentOrchestrator + +@pytest.fixture +def mock_boto3_client(): + with patch('boto3.client') as mock_client: + yield mock_client + +# Fixtures +@pytest.fixture +def mock_logger(): + return Mock(spec=Logger) + +@pytest.fixture +def mock_storage(): + storage = AsyncMock(spec=ChatStorage) + storage.fetch_chat = AsyncMock(return_value=[]) + storage.fetch_all_chats = AsyncMock(return_value=[]) + storage.save_chat_message = AsyncMock() + return storage + +@pytest.fixture +def mock_classifier(): + classifier = AsyncMock(spec=Classifier) + classifier.set_agents = Mock() + return classifier + +@pytest.fixture +def mock_agent(): + agent = AsyncMock(spec=Agent) + agent.id = "test_agent" + agent.name = "Test Agent" + agent.description = "Test Agent Description" + agent.save_chat = True + agent.is_streaming_enabled = Mock(return_value=False) + return agent + +@pytest.fixture +def mock_streaming_agent(): + agent = AsyncMock(spec=Agent) + agent.id = "streaming_agent" + agent.name = "Streaming Agent" + agent.description = "Streaming Agent Description" + agent.save_chat = True + agent.is_streaming_enabled = Mock(return_value=True) + return agent + +@pytest.fixture +def orchestrator(mock_storage, mock_classifier, mock_logger, mock_agent, mock_boto3_client): + return MultiAgentOrchestrator( + storage=mock_storage, + classifier=mock_classifier, + logger=mock_logger, + default_agent=mock_agent + ) + +def test_init_with_dict_options(mock_boto3_client): + options = {"MAX_MESSAGE_PAIRS_PER_AGENT": 10} + orchestrator = MultiAgentOrchestrator( + options=options, + classifier=Mock(spec=Classifier) + ) + assert orchestrator.config.MAX_MESSAGE_PAIRS_PER_AGENT == 10 + +def test_init_with_invalid_options(mock_boto3_client): + with pytest.raises(ValueError): + MultiAgentOrchestrator(options="invalid") + +# Test agent management +def test_add_agent(orchestrator, mock_agent): + orchestrator.add_agent(mock_agent) + assert orchestrator.agents[mock_agent.id] == mock_agent + orchestrator.classifier.set_agents.assert_called_once_with(orchestrator.agents) + +def test_add_duplicate_agent(orchestrator, mock_agent): + orchestrator.add_agent(mock_agent) + with pytest.raises(ValueError): + orchestrator.add_agent(mock_agent) + +def test_get_all_agents(orchestrator, mock_agent): + orchestrator.add_agent(mock_agent) + agents = orchestrator.get_all_agents() + assert agents[mock_agent.id]["name"] == mock_agent.name + assert agents[mock_agent.id]["description"] == mock_agent.description + +# Test default agent management +def 
test_get_default_agent(orchestrator, mock_agent): + assert orchestrator.get_default_agent() == mock_agent + +def test_set_default_agent(orchestrator, mock_agent): + new_agent = AsyncMock(spec=Agent) + orchestrator.set_default_agent(new_agent) + assert orchestrator.get_default_agent() == new_agent + +# Test classifier management +def test_set_classifier(orchestrator): + new_classifier = Mock(spec=Classifier) + orchestrator.set_classifier(new_classifier) + assert orchestrator.classifier == new_classifier + new_classifier.set_agents.assert_called_once_with(orchestrator.agents) + +# Test request classification +@pytest.mark.asyncio +async def test_classify_request_success(orchestrator, mock_agent): + expected_result = ClassifierResult(selected_agent=mock_agent, confidence=0.9) + orchestrator.classifier.classify.return_value = expected_result + + result = await orchestrator.classify_request("test input", "user1", "session1") + assert result == expected_result + +@pytest.mark.asyncio +async def test_classify_request_no_agent_with_default(orchestrator): + orchestrator.classifier.classify.return_value = ClassifierResult(selected_agent=None, confidence=0) + orchestrator.config.USE_DEFAULT_AGENT_IF_NONE_IDENTIFIED = True + + result = await orchestrator.classify_request("test input", "user1", "session1") + assert result.selected_agent == orchestrator.default_agent + +@pytest.mark.asyncio +async def test_classify_request_error(orchestrator): + orchestrator.classifier.classify.side_effect = Exception("Classification error") + + with pytest.raises(Exception): + await orchestrator.classify_request("test input", "user1", "session1") + +# Test dispatch to agent +@pytest.mark.asyncio +async def test_dispatch_to_agent_success(orchestrator, mock_agent): + classifier_result = ClassifierResult(selected_agent=mock_agent, confidence=0.9) + expected_response = ConversationMessage( + role=ParticipantRole.ASSISTANT.value, + content=[{"text": "Test response"}] + ) + mock_agent.process_request.return_value = expected_response + + response = await orchestrator.dispatch_to_agent({ + "user_input": "test", + "user_id": "user1", + "session_id": "session1", + "classifier_result": classifier_result, + "additional_params": {} + }) + + assert response == expected_response + +@pytest.mark.asyncio +async def test_dispatch_to_agent_no_agent(orchestrator): + classifier_result = ClassifierResult(selected_agent=None, confidence=0) + + response = await orchestrator.dispatch_to_agent({ + "user_input": "test", + "user_id": "user1", + "session_id": "session1", + "classifier_result": classifier_result, + "additional_params": {} + }) + + assert isinstance(response, str) + assert "more information" in response + +# Test streaming functionality +@pytest.mark.asyncio +async def test_agent_process_request_streaming(orchestrator, mock_streaming_agent): + classifier_result = ClassifierResult(selected_agent=mock_streaming_agent, confidence=0.9) + + async def mock_stream(): + yield AgentStreamResponse( + chunk="Test chunk", + final_message=ConversationMessage( + role=ParticipantRole.ASSISTANT.value, + content=[{"text": "Final message"}] + ) + ) + + mock_streaming_agent.process_request.return_value = mock_stream() + + response = await orchestrator.agent_process_request( + "test input", + "user1", + "session1", + classifier_result, + stream_response=True + ) + + assert response.streaming == True + assert isinstance(response.output, AsyncIterable) + +# Test route request +@pytest.mark.asyncio +async def test_route_request_success(orchestrator, 
mock_agent): + classifier_result = ClassifierResult(selected_agent=mock_agent, confidence=0.9) + orchestrator.classifier.classify.return_value = classifier_result + + expected_response = ConversationMessage( + role=ParticipantRole.ASSISTANT.value, + content=[{"text": "Test response"}] + ) + mock_agent.process_request.return_value = expected_response + + response = await orchestrator.route_request( + "test input", + "user1", + "session1" + ) + + assert response.output == expected_response + assert response.metadata.agent_id == mock_agent.id + +@pytest.mark.asyncio +async def test_route_request_error(orchestrator): + orchestrator.classifier.classify.side_effect = Exception("Test error") + + response = await orchestrator.route_request( + "test input", + "user1", + "session1" + ) + + assert isinstance(response.output, str) + assert "Test error" in response.output + +# Test chat storage +@pytest.mark.asyncio +async def test_save_message(orchestrator, mock_agent): + message = ConversationMessage( + role=ParticipantRole.ASSISTANT.value, + content=[{"text": "Test message"}] + ) + + await orchestrator.save_message( + message, + "user1", + "session1", + mock_agent + ) + + orchestrator.storage.save_chat_message.assert_called_once_with( + "user1", + "session1", + mock_agent.id, + message, + orchestrator.config.MAX_MESSAGE_PAIRS_PER_AGENT + ) + +@pytest.mark.asyncio +async def test_save_messages(orchestrator, mock_agent): + messages = [ + ConversationMessage( + role=ParticipantRole.ASSISTANT.value, + content=[{"text": "Message 1"}] + ), + ConversationMessage( + role=ParticipantRole.USER.value, + content=[{"text": "Message 2"}] + ) + ] + + await orchestrator.save_messages( + messages, + "user1", + "session1", + mock_agent + ) + + assert orchestrator.storage.save_chat_message.call_count == 2 + +# Test execution time measurement +@pytest.mark.asyncio +async def test_measure_execution_time(orchestrator): + async def test_fn(): + return "test result" + + orchestrator.config.LOG_EXECUTION_TIMES = True + result = await orchestrator.measure_execution_time("test_timer", test_fn) + + assert result == "test result" + assert "test_timer" in orchestrator.execution_times + assert isinstance(orchestrator.execution_times["test_timer"], float) + +@pytest.mark.asyncio +async def test_measure_execution_time_error(orchestrator): + async def test_fn(): + raise Exception("Test error") + + orchestrator.config.LOG_EXECUTION_TIMES = True + + with pytest.raises(Exception): + await orchestrator.measure_execution_time("test_timer", test_fn) + + assert "test_timer" in orchestrator.execution_times + assert isinstance(orchestrator.execution_times["test_timer"], float) + +# Test metadata creation +def test_create_metadata(orchestrator, mock_agent): + classifier_result = ClassifierResult(selected_agent=mock_agent, confidence=0.9) + + metadata = orchestrator.create_metadata( + classifier_result, + "test input", + "user1", + "session1", + {"param1": "value1"} + ) + + assert metadata.user_input == "test input" + assert metadata.agent_id == mock_agent.id + assert metadata.agent_name == mock_agent.name + assert metadata.user_id == "user1" + assert metadata.session_id == "session1" + assert metadata.additional_params == {"param1": "value1"} + +def test_create_metadata_no_agent(orchestrator): + metadata = orchestrator.create_metadata( + None, + "test input", + "user1", + "session1", + {} + ) + + assert metadata.agent_id == "no_agent_selected" + assert metadata.agent_name == "No Agent" + assert "error_type" in metadata.additional_params + 
assert metadata.additional_params["error_type"] == "classification_failed" + +# Test fallback functionality +def test_get_fallback_result(orchestrator, mock_agent): + result = orchestrator.get_fallback_result() + assert result.selected_agent == mock_agent + assert result.confidence == 0 \ No newline at end of file diff --git a/python/src/tests/utils/test_tool.py b/python/src/tests/utils/test_tool.py index a95f682d..22c402d3 100644 --- a/python/src/tests/utils/test_tool.py +++ b/python/src/tests/utils/test_tool.py @@ -469,4 +469,69 @@ def test_tool_with_properties(): } } } - } \ No newline at end of file + } + +def test_tool_not_found(): + try: + tools = AgentTools([AgentTool( + name="weather", + func=fetch_weather_data + )]) + tools._process_tool("test", {'test':'value'}) + except Exception as e: + assert str(e) == f"Tool weather not found" + + +def test_get_tool_use_block(): + tools = AgentTools([AgentTool( + name="weather", + func=fetch_weather_data + )]) + response = tools._get_tool_use_block("test", {'test':'value'}) + assert response == None + + +def test_no_func(): + try: + tools = AgentTools([AgentTool( + name="weather", + )]) + except Exception as e: + assert str(e) == "Function must be provided" + +@pytest.mark.asyncio +async def test_no_tool_block(): + try: + tools = AgentTools([AgentTool( + name="weather", + func=fetch_weather_data + )]) + message = ConversationMessage(role=ParticipantRole.ASSISTANT.value, content=None) + response = await tools.tool_handler(AgentProviderType.BEDROCK.value, message, []) + except Exception as e: + assert str(e) == "No content blocks in response" + +@pytest.mark.asyncio +async def test_no_tool_use_block(): + tools = AgentTools([AgentTool( + name="weather", + func=fetch_weather_data + )]) + message = ConversationMessage(role=ParticipantRole.ASSISTANT.value, content=[{'text'}]) + response = await tools.tool_handler(AgentProviderType.BEDROCK.value, message, []) + assert isinstance(response, ConversationMessage) + assert response.role == ParticipantRole.USER.value + assert response.content == [] + + +def test_self_param(): + def _handler(self, tool_input): + return tool_input + tools = AgentTools([AgentTool( + name="test", + func=_handler + )]) + + + +
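A caveat on the negative-path tests added in test_tool.py (test_tool_not_found, test_no_func, test_no_tool_block, and likewise the client-validation branches of test_client in test_anthropic_agent.py): asserting on the exception message inside a try/except block means the test still passes silently if no exception is raised at all. A minimal sketch of the same checks written with pytest.raises, assuming the underlying code raises the messages the existing asserts expect; AgentTools, AgentTool, and fetch_weather_data are the names already used in these tests, everything else is illustrative only:

import pytest
from multi_agent_orchestrator.utils import AgentTools, AgentTool

def test_tool_not_found():
    # fetch_weather_data is the handler already defined earlier in test_tool.py.
    tools = AgentTools([AgentTool(name="weather", func=fetch_weather_data)])
    # pytest.raises fails the test when nothing is raised, and `match` checks
    # the exception message with a regex search.
    with pytest.raises(Exception, match="Tool weather not found"):
        tools._process_tool("test", {'test': 'value'})

def test_no_func():
    with pytest.raises(Exception, match="Function must be provided"):
        AgentTools([AgentTool(name="weather")])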
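The type-hint modernization in types.py, helpers.py, and tool.py (builtin generics and `X | Y` unions replacing typing.List/Dict/Union/Optional) also raises the runtime floor: a `|` union evaluated at import time, such as the TemplateVariables alias, needs Python 3.10+, and bare list[...]/dict[...] subscripting needs 3.9+, so the package's declared minimum Python version should be checked against that. A small illustration, independent of this change:

# Mirrors the alias updated in types.py; evaluated at import time, so the
# `|` union requires Python >= 3.10.
TemplateVariables = dict[str, str | list[str]]

# Inside annotations only, older interpreters (3.7-3.9) accept the same syntax
# when the module opts in to postponed evaluation with
# `from __future__ import annotations`; the annotation is then never evaluated.
def fetch_chat(max_history_size: int | None = None) -> list[str]:
    return []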