Skip to content

Commit

Permalink
Add Channel.cancel()
Browse files Browse the repository at this point in the history
  • Loading branch information
RussellLuo committed Jan 14, 2025
1 parent f8b12b1 commit 1af0804
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 9 deletions.
3 changes: 1 addition & 2 deletions coagent/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from coagent.core import Address, RawMessage, set_stderr_logger
from coagent.core.exceptions import BaseError
from coagent.core.messages import Cancel
from coagent.runtimes import NATSRuntime, HTTPRuntime


Expand Down Expand Up @@ -70,7 +69,7 @@ async def run(
async for chunk in runtime.channel.publish_multi(addr, msg):
print_msg(chunk, oneline, filter)
except asyncio.CancelledError:
await runtime.channel.publish(addr, Cancel().encode(), probe=False)
await runtime.channel.cancel(addr)
except BaseError as exc:
print(f"Error: {exc}")

Expand Down
8 changes: 7 additions & 1 deletion coagent/core/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from .discovery import Discovery
from .exceptions import BaseError
from .messages import StopIteration, Error
from .messages import Cancel, Error, StopIteration
from .factory import Factory
from .types import (
AgentSpec,
Expand Down Expand Up @@ -101,6 +101,12 @@ async def publish_multi(
finally:
await sub.unsubscribe()

async def cancel(self, addr: Address) -> None:
"""Cancel the agent with the given address."""

# A shortcut for sending a Cancel message to the agent.
await self.publish(addr, Cancel().encode(), probe=False)


class QueueSubscriptionIterator:
"""A Queue-based async iterator that receives messages from a subscription and yields them."""
Expand Down
5 changes: 5 additions & 0 deletions coagent/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,11 @@ async def subscribe(
async def new_reply_topic(self) -> str:
pass

@abc.abstractmethod
async def cancel(self, addr: Address) -> None:
"""Cancel the agent with the given address."""
pass


@dataclasses.dataclass
class AgentSpec:
Expand Down
5 changes: 2 additions & 3 deletions coagent/cos/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
RawMessage,
logger,
)
from coagent.core.messages import Cancel
from coagent.core.exceptions import BaseError
from coagent.core.types import Runtime
from coagent.core.util import clear_queue
Expand Down Expand Up @@ -149,7 +148,7 @@ async def publish(self, request: Request):
# Disconnected from the client.

# Cancel the ongoing operation.
await self._runtime.channel.publish(addr, Cancel().encode())
await self._runtime.channel.cancel(addr)

if resp is None:
return Response(status_code=204)
Expand Down Expand Up @@ -178,7 +177,7 @@ async def event_stream() -> AsyncIterator[str]:
# Disconnected from the client.

# Cancel the ongoing operation.
await self._runtime.channel.publish(addr, Cancel().encode())
await self._runtime.channel.cancel(addr)

return EventSourceResponse(event_stream())

Expand Down
6 changes: 3 additions & 3 deletions tests/core/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from coagent.core.types import Address, Agent, Channel, RawMessage
from coagent.core.agent import BaseAgent, Context, handler
from coagent.core.exceptions import BaseError
from coagent.core.messages import Cancel, Message
from coagent.core.messages import Message


class Query(Message):
Expand Down Expand Up @@ -89,7 +89,7 @@ async def test_cancel(self, local_channel, run_agent_in_task, yield_control):

async def cancel():
await asyncio.sleep(0.01)
await local_channel.publish(addr, Cancel().encode())
await local_channel.cancel(addr)

_ = asyncio.create_task(cancel())
await yield_control()
Expand Down Expand Up @@ -130,7 +130,7 @@ async def test_cancel(self, local_channel, run_agent_in_task, yield_control):

async def cancel():
await asyncio.sleep(0.01)
await local_channel.publish(addr, Cancel().encode())
await local_channel.cancel(addr)

_ = asyncio.create_task(cancel())
await yield_control()
Expand Down

0 comments on commit 1af0804

Please sign in to comment.