From 3a8ad9cf2e092baecead4706123f02afda23280c Mon Sep 17 00:00:00 2001 From: ThaiND Date: Sun, 10 Nov 2024 12:53:56 +0700 Subject: [PATCH 1/4] Add AgentNameTermination to terminate conversation --- .../src/autogen_ext/task/__init__.py | 3 ++ .../task/_agent_name_termination.py | 38 +++++++++++++++++++ .../tests/task/test_termination_condition.py | 29 ++++++++++++++ 3 files changed, 70 insertions(+) create mode 100644 python/packages/autogen-ext/src/autogen_ext/task/__init__.py create mode 100644 python/packages/autogen-ext/src/autogen_ext/task/_agent_name_termination.py create mode 100644 python/packages/autogen-ext/tests/task/test_termination_condition.py diff --git a/python/packages/autogen-ext/src/autogen_ext/task/__init__.py b/python/packages/autogen-ext/src/autogen_ext/task/__init__.py new file mode 100644 index 00000000000..25c0b587ddc --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/task/__init__.py @@ -0,0 +1,3 @@ +from ._agent_name_termination import AgentNameTermination + +__all__ = ["AgentNameTermination"] diff --git a/python/packages/autogen-ext/src/autogen_ext/task/_agent_name_termination.py b/python/packages/autogen-ext/src/autogen_ext/task/_agent_name_termination.py new file mode 100644 index 00000000000..ac4d30f0c4b --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/task/_agent_name_termination.py @@ -0,0 +1,38 @@ +from typing import Sequence + +from autogen_agentchat.base import TerminationCondition, TerminatedException +from autogen_agentchat.messages import StopMessage, AgentMessage, ChatMessage + + +class AgentNameTermination(TerminationCondition): + """Terminate the conversation after a specific agent responds. + + Args: + agent_name (str): The name of the agent whose response will trigger the termination. + + Raises: + TerminatedException: If the termination condition has already been reached. + """ + + def __init__(self, agent_name: str) -> None: + self._agent_name = agent_name + self._terminated = False + + @property + def terminated(self) -> bool: + return self._terminated + + async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None: + if self._terminated: + raise TerminatedException("Termination condition has already been reached") + if not messages: + return None + last_message = messages[-1] + if last_message.source == self._agent_name: + if isinstance(last_message, ChatMessage): + self._terminated = True + return StopMessage(content=f"Agent '{self._agent_name}' answered", source="AgentNameTermination") + return None + + async def reset(self) -> None: + self._terminated = False diff --git a/python/packages/autogen-ext/tests/task/test_termination_condition.py b/python/packages/autogen-ext/tests/task/test_termination_condition.py new file mode 100644 index 00000000000..a8f63530a20 --- /dev/null +++ b/python/packages/autogen-ext/tests/task/test_termination_condition.py @@ -0,0 +1,29 @@ +import pytest + +from autogen_agentchat.base import TerminatedException +from autogen_agentchat.messages import TextMessage, StopMessage +from autogen_ext.task import AgentNameTermination + + +@pytest.mark.asyncio +async def test_agent_name_termination() -> None: + termination = AgentNameTermination(agent_name="Assistant") + assert await termination([]) is None + + continue_messages = [ + TextMessage(content="Hello", source="Assistant"), + TextMessage(content="Hello", source="user") + ] + assert await termination(continue_messages) is None + + terminate_messages = [ + TextMessage(content="Hello", source="user"), + TextMessage(content="Hello", source="Assistant") + ] + result = await termination(terminate_messages) + assert isinstance(result, StopMessage) + assert termination.terminated + with pytest.raises(TerminatedException): + await termination([]) + await termination.reset() + assert not termination.terminated From d09163b30f69228b915c811d1bf8a00060983591 Mon Sep 17 00:00:00 2001 From: thainduy Date: Sun, 10 Nov 2024 19:15:34 +0700 Subject: [PATCH 2/4] Terminate by list of agent name --- .../src/autogen_ext/task/_agent_name_termination.py | 12 ++++++------ .../tests/task/test_termination_condition.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/packages/autogen-ext/src/autogen_ext/task/_agent_name_termination.py b/python/packages/autogen-ext/src/autogen_ext/task/_agent_name_termination.py index ac4d30f0c4b..e4f1ae46059 100644 --- a/python/packages/autogen-ext/src/autogen_ext/task/_agent_name_termination.py +++ b/python/packages/autogen-ext/src/autogen_ext/task/_agent_name_termination.py @@ -1,4 +1,4 @@ -from typing import Sequence +from typing import Sequence, List from autogen_agentchat.base import TerminationCondition, TerminatedException from autogen_agentchat.messages import StopMessage, AgentMessage, ChatMessage @@ -8,14 +8,14 @@ class AgentNameTermination(TerminationCondition): """Terminate the conversation after a specific agent responds. Args: - agent_name (str): The name of the agent whose response will trigger the termination. + agents (List[str]): List of agent names to terminate the conversation. Raises: TerminatedException: If the termination condition has already been reached. """ - def __init__(self, agent_name: str) -> None: - self._agent_name = agent_name + def __init__(self, agents: List[str]) -> None: + self._agents = agents self._terminated = False @property @@ -28,10 +28,10 @@ async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None if not messages: return None last_message = messages[-1] - if last_message.source == self._agent_name: + if last_message.source in self._agents: if isinstance(last_message, ChatMessage): self._terminated = True - return StopMessage(content=f"Agent '{self._agent_name}' answered", source="AgentNameTermination") + return StopMessage(content=f"Agent '{last_message.source}' answered", source="AgentNameTermination") return None async def reset(self) -> None: diff --git a/python/packages/autogen-ext/tests/task/test_termination_condition.py b/python/packages/autogen-ext/tests/task/test_termination_condition.py index a8f63530a20..5aabaeaf83f 100644 --- a/python/packages/autogen-ext/tests/task/test_termination_condition.py +++ b/python/packages/autogen-ext/tests/task/test_termination_condition.py @@ -7,7 +7,7 @@ @pytest.mark.asyncio async def test_agent_name_termination() -> None: - termination = AgentNameTermination(agent_name="Assistant") + termination = AgentNameTermination(agents=["Assistant"]) assert await termination([]) is None continue_messages = [ From f3470bcec3511f5def02625302e82bb04698cb4b Mon Sep 17 00:00:00 2001 From: thainduy Date: Wed, 13 Nov 2024 07:57:38 +0700 Subject: [PATCH 3/4] - Remove unnecessary message type check - Rename class to `SourceMatchTermination` --- .../autogen-ext/src/autogen_ext/task/__init__.py | 4 ++-- ..._name_termination.py => _source_match_termination.py} | 9 ++++----- .../autogen-ext/tests/task/test_termination_condition.py | 4 ++-- 3 files changed, 8 insertions(+), 9 deletions(-) rename python/packages/autogen-ext/src/autogen_ext/task/{_agent_name_termination.py => _source_match_termination.py} (79%) diff --git a/python/packages/autogen-ext/src/autogen_ext/task/__init__.py b/python/packages/autogen-ext/src/autogen_ext/task/__init__.py index 25c0b587ddc..d9152d17e8c 100644 --- a/python/packages/autogen-ext/src/autogen_ext/task/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/task/__init__.py @@ -1,3 +1,3 @@ -from ._agent_name_termination import AgentNameTermination +from ._source_match_termination import SourceMatchTermination -__all__ = ["AgentNameTermination"] +__all__ = ["SourceMatchTermination"] diff --git a/python/packages/autogen-ext/src/autogen_ext/task/_agent_name_termination.py b/python/packages/autogen-ext/src/autogen_ext/task/_source_match_termination.py similarity index 79% rename from python/packages/autogen-ext/src/autogen_ext/task/_agent_name_termination.py rename to python/packages/autogen-ext/src/autogen_ext/task/_source_match_termination.py index e4f1ae46059..ea915e0f7e1 100644 --- a/python/packages/autogen-ext/src/autogen_ext/task/_agent_name_termination.py +++ b/python/packages/autogen-ext/src/autogen_ext/task/_source_match_termination.py @@ -1,10 +1,10 @@ from typing import Sequence, List from autogen_agentchat.base import TerminationCondition, TerminatedException -from autogen_agentchat.messages import StopMessage, AgentMessage, ChatMessage +from autogen_agentchat.messages import StopMessage, AgentMessage -class AgentNameTermination(TerminationCondition): +class SourceMatchTermination(TerminationCondition): """Terminate the conversation after a specific agent responds. Args: @@ -29,9 +29,8 @@ async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None return None last_message = messages[-1] if last_message.source in self._agents: - if isinstance(last_message, ChatMessage): - self._terminated = True - return StopMessage(content=f"Agent '{last_message.source}' answered", source="AgentNameTermination") + self._terminated = True + return StopMessage(content=f"Agent '{last_message.source}' answered", source="SourceMatchTermination") return None async def reset(self) -> None: diff --git a/python/packages/autogen-ext/tests/task/test_termination_condition.py b/python/packages/autogen-ext/tests/task/test_termination_condition.py index 5aabaeaf83f..604aae2ec4a 100644 --- a/python/packages/autogen-ext/tests/task/test_termination_condition.py +++ b/python/packages/autogen-ext/tests/task/test_termination_condition.py @@ -2,12 +2,12 @@ from autogen_agentchat.base import TerminatedException from autogen_agentchat.messages import TextMessage, StopMessage -from autogen_ext.task import AgentNameTermination +from autogen_ext.task import SourceMatchTermination @pytest.mark.asyncio async def test_agent_name_termination() -> None: - termination = AgentNameTermination(agents=["Assistant"]) + termination = SourceMatchTermination(agents=["Assistant"]) assert await termination([]) is None continue_messages = [ From 7fe5daf15e9a762afe765cb19a29c181b3b9d0ef Mon Sep 17 00:00:00 2001 From: thainduy Date: Wed, 13 Nov 2024 09:27:40 +0700 Subject: [PATCH 4/4] Rename `agents` to `sources`, avoid ambiguous naming --- .../autogen_ext/task/_source_match_termination.py | 12 ++++++------ .../tests/task/test_termination_condition.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/packages/autogen-ext/src/autogen_ext/task/_source_match_termination.py b/python/packages/autogen-ext/src/autogen_ext/task/_source_match_termination.py index ea915e0f7e1..a0905405f1f 100644 --- a/python/packages/autogen-ext/src/autogen_ext/task/_source_match_termination.py +++ b/python/packages/autogen-ext/src/autogen_ext/task/_source_match_termination.py @@ -5,17 +5,17 @@ class SourceMatchTermination(TerminationCondition): - """Terminate the conversation after a specific agent responds. + """Terminate the conversation after a specific source responds. Args: - agents (List[str]): List of agent names to terminate the conversation. + sources (List[str]): List of source names to terminate the conversation. Raises: TerminatedException: If the termination condition has already been reached. """ - def __init__(self, agents: List[str]) -> None: - self._agents = agents + def __init__(self, sources: List[str]) -> None: + self._sources = sources self._terminated = False @property @@ -28,9 +28,9 @@ async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None if not messages: return None last_message = messages[-1] - if last_message.source in self._agents: + if last_message.source in self._sources: self._terminated = True - return StopMessage(content=f"Agent '{last_message.source}' answered", source="SourceMatchTermination") + return StopMessage(content=f"'{last_message.source}' answered", source="SourceMatchTermination") return None async def reset(self) -> None: diff --git a/python/packages/autogen-ext/tests/task/test_termination_condition.py b/python/packages/autogen-ext/tests/task/test_termination_condition.py index 604aae2ec4a..d92102abf88 100644 --- a/python/packages/autogen-ext/tests/task/test_termination_condition.py +++ b/python/packages/autogen-ext/tests/task/test_termination_condition.py @@ -7,7 +7,7 @@ @pytest.mark.asyncio async def test_agent_name_termination() -> None: - termination = SourceMatchTermination(agents=["Assistant"]) + termination = SourceMatchTermination(sources=["Assistant"]) assert await termination([]) is None continue_messages = [