Skip to content

Commit

Permalink
Fix agent deletion
Browse files Browse the repository at this point in the history
  • Loading branch information
RussellLuo committed Jan 12, 2025
1 parent b35ff6b commit ddee4b7
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 8 deletions.
4 changes: 1 addition & 3 deletions coagent/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ async def delete(self) -> None:

if self.factory_address:
msg = DeleteAgent(session_id=self.address.id).encode()
await self.channel.publish(self.factory_address, msg)
await self.channel.publish(self.factory_address, msg, probe=False)

async def started(self) -> None:
"""This handler is called after the agent is started."""
Expand Down Expand Up @@ -243,8 +243,6 @@ async def _handle_control(self, msg: ControlMessage) -> None:
"""Handle CONTROL messages."""
match msg:
case Cancel():
if self._handle_data_task:
self._handle_data_task.cancel()
# Delete the agent when cancelled.
await self.delete()

Expand Down
5 changes: 3 additions & 2 deletions coagent/core/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ async def start(self) -> None:
await super().start()

# Generate a unique address and create an instance subscription.
self._instance_address = Address(name=self.address.name, id=uuid.uuid4().hex)
self._instance_sub = self.channel.subscribe(
unique_id = uuid.uuid4().hex
self._instance_address = Address(name=f"{self.address.name}_{unique_id}")
self._instance_sub = await self.channel.subscribe(
self._instance_address, handler=self.receive
)

Expand Down
30 changes: 27 additions & 3 deletions tests/core/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from coagent.core.types import Address
from coagent.core.types import Address, Agent, Channel, RawMessage
from coagent.core.agent import BaseAgent, Context, handler
from coagent.core.exceptions import BaseError
from coagent.core.messages import Cancel, Message
Expand Down Expand Up @@ -43,6 +43,22 @@ async def handle(self, msg: Query, ctx: Context) -> AsyncIterator[Reply]:
yield Reply()


class _TestFactory:
def __init__(self, channel: Channel, address: Address):
self.channel = channel
self.address = address

self.agent = None
self.sub = None

async def receive(self, msg: RawMessage) -> None:
await self.agent.stop()

async def start(self, agent: Agent) -> None:
self.agent = agent
self.sub = await self.channel.subscribe(self.address, self.receive)


class TestTrivialAgent:
@pytest.mark.asyncio
async def test_normal(self, local_channel, run_agent_in_task, yield_control):
Expand All @@ -60,9 +76,13 @@ async def test_normal(self, local_channel, run_agent_in_task, yield_control):

@pytest.mark.asyncio
async def test_cancel(self, local_channel, run_agent_in_task, yield_control):
test_factory = _TestFactory(local_channel, Address(name="test_1"))

agent = TrivialAgent(wait_s=10)
addr = Address(name="test", id="1")
agent.init(local_channel, addr)
agent.init(local_channel, addr, test_factory.address)

await test_factory.start(agent)

_task = run_agent_in_task(agent)
await yield_control()
Expand Down Expand Up @@ -97,9 +117,13 @@ async def test_normal(self, local_channel, run_agent_in_task, yield_control):

@pytest.mark.asyncio
async def test_cancel(self, local_channel, run_agent_in_task, yield_control):
test_factory = _TestFactory(local_channel, Address(name="test_3"))

agent = StreamAgent(wait_s=10)
addr = Address(name="test", id="3")
agent.init(local_channel, addr)
agent.init(local_channel, addr, test_factory.address)

await test_factory.start(agent)

_task = run_agent_in_task(agent)
await yield_control()
Expand Down

0 comments on commit ddee4b7

Please sign in to comment.