diff --git a/coagent/cli/main.py b/coagent/cli/main.py index af59b10..b518aa5 100644 --- a/coagent/cli/main.py +++ b/coagent/cli/main.py @@ -7,7 +7,6 @@ from coagent.core import Address, RawMessage, set_stderr_logger from coagent.core.exceptions import BaseError -from coagent.core.messages import Cancel from coagent.runtimes import NATSRuntime, HTTPRuntime @@ -70,7 +69,7 @@ async def run( async for chunk in runtime.channel.publish_multi(addr, msg): print_msg(chunk, oneline, filter) except asyncio.CancelledError: - await runtime.channel.publish(addr, Cancel().encode(), probe=False) + await runtime.channel.cancel(addr) except BaseError as exc: print(f"Error: {exc}") diff --git a/coagent/core/runtime.py b/coagent/core/runtime.py index 5d6c27e..c5ddd59 100644 --- a/coagent/core/runtime.py +++ b/coagent/core/runtime.py @@ -5,7 +5,7 @@ from .discovery import Discovery from .exceptions import BaseError -from .messages import StopIteration, Error +from .messages import Cancel, Error, StopIteration from .factory import Factory from .types import ( AgentSpec, @@ -101,6 +101,12 @@ async def publish_multi( finally: await sub.unsubscribe() + async def cancel(self, addr: Address) -> None: + """Cancel the agent with the given address.""" + + # A shortcut for sending a Cancel message to the agent. + await self.publish(addr, Cancel().encode(), probe=False) + class QueueSubscriptionIterator: """A Queue-based async iterator that receives messages from a subscription and yields them.""" diff --git a/coagent/core/types.py b/coagent/core/types.py index 8e6a2e1..97f46bf 100644 --- a/coagent/core/types.py +++ b/coagent/core/types.py @@ -268,6 +268,11 @@ async def subscribe( async def new_reply_topic(self) -> str: pass + @abc.abstractmethod + async def cancel(self, addr: Address) -> None: + """Cancel the agent with the given address.""" + pass + @dataclasses.dataclass class AgentSpec: diff --git a/coagent/cos/runtime.py b/coagent/cos/runtime.py index c79f812..5a21e50 100644 --- a/coagent/cos/runtime.py +++ b/coagent/cos/runtime.py @@ -14,7 +14,6 @@ RawMessage, logger, ) -from coagent.core.messages import Cancel from coagent.core.exceptions import BaseError from coagent.core.types import Runtime from coagent.core.util import clear_queue @@ -149,7 +148,7 @@ async def publish(self, request: Request): # Disconnected from the client. # Cancel the ongoing operation. - await self._runtime.channel.publish(addr, Cancel().encode()) + await self._runtime.channel.cancel(addr) if resp is None: return Response(status_code=204) @@ -178,7 +177,7 @@ async def event_stream() -> AsyncIterator[str]: # Disconnected from the client. # Cancel the ongoing operation. - await self._runtime.channel.publish(addr, Cancel().encode()) + await self._runtime.channel.cancel(addr) return EventSourceResponse(event_stream()) diff --git a/tests/core/test_agent.py b/tests/core/test_agent.py index 0e30cbd..1c091f3 100644 --- a/tests/core/test_agent.py +++ b/tests/core/test_agent.py @@ -6,7 +6,7 @@ from coagent.core.types import Address, Agent, Channel, RawMessage from coagent.core.agent import BaseAgent, Context, handler from coagent.core.exceptions import BaseError -from coagent.core.messages import Cancel, Message +from coagent.core.messages import Message class Query(Message): @@ -89,7 +89,7 @@ async def test_cancel(self, local_channel, run_agent_in_task, yield_control): async def cancel(): await asyncio.sleep(0.01) - await local_channel.publish(addr, Cancel().encode()) + await local_channel.cancel(addr) _ = asyncio.create_task(cancel()) await yield_control() @@ -130,7 +130,7 @@ async def test_cancel(self, local_channel, run_agent_in_task, yield_control): async def cancel(): await asyncio.sleep(0.01) - await local_channel.publish(addr, Cancel().encode()) + await local_channel.cancel(addr) _ = asyncio.create_task(cancel()) await yield_control()