Skip to content

Commit

Permalink
to fully stop
Browse files Browse the repository at this point in the history
  • Loading branch information
XuhuiZhou committed Dec 14, 2024
1 parent 7ec769b commit c220664
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 13 deletions.
2 changes: 2 additions & 0 deletions examples/experimental/nodes/initial_message_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
input_tick_channel: str,
output_channels: list[str],
env_scenario: str,
node_name: str,
redis_url: str = "redis://localhost:6379/0",
):
super().__init__(
Expand All @@ -26,6 +27,7 @@ def __init__(
(output_channel, Text) for output_channel in output_channels
],
redis_url=redis_url,
node_name=node_name,
)
self.env_scenario = env_scenario
self.output_channels = output_channels
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
output_channel: str,
query_interval: int,
agent_name: str,
node_name: str,
goal: str,
model_name: str,
redis_url: str,
Expand All @@ -42,6 +43,7 @@ def __init__(
[(input_channel, Observation) for input_channel in input_channels],
[(output_channel, AgentAction)],
redis_url,
node_name,
)
self.output_channel = output_channel
self.query_interval = query_interval
Expand Down
2 changes: 1 addition & 1 deletion examples/experimental/sotopia_original_replica/origin.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ input_channels = ["Jane:moderator", "Jack:moderator"]
agent_backgrounds = {"Jane" = "", "Jack" = ""}
agent_mapping = {"moderator:Jane" = "Jane", "moderator:Jack" = "Jack"}
scenario = "Two friends are sitting in a cafe and catching up with each other's lives."
max_turns = 20
max_turns = 2
push_to_db = false

[[nodes]]
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ plugins = [
module = "transformers.*"
ignore_missing_imports = true

[tool.uv.sources]
aact = { git = "https://github.com/ProKil/aact" , branch = "feature/node-manager" }

[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = "test_*.py"
Expand Down
2 changes: 2 additions & 0 deletions sotopia/experimental/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ def __init__(
input_channel_types: list[tuple[str, type[T_agent_observation]]],
output_channel_types: list[tuple[str, type[T_agent_action]]],
redis_url: str = "redis://localhost:6379/0",
node_name: str = "base_agent",
):
super().__init__(
input_channel_types=input_channel_types,
output_channel_types=output_channel_types,
redis_url=redis_url,
node_name=node_name,
)

self.observation_queue: asyncio.Queue[T_agent_observation] = asyncio.Queue()
Expand Down
29 changes: 23 additions & 6 deletions sotopia/experimental/agents/moderator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
output_channels: list[str],
scenario: str,
agent_mapping: dict[str, str],
node_name: str,
agent_backgrounds: dict[str, str],
redis_url: str = "redis://localhost:6379/0",
action_order: Literal["simultaneous", "round-robin", "random"] = "round-robin",
Expand All @@ -55,6 +56,7 @@ def __init__(
(output_channel, Observation) for output_channel in output_channels
],
redis_url=redis_url,
node_name=node_name,
)
self.observation_queue: asyncio.Queue[AgentAction] = asyncio.Queue()
self.task_scheduler: asyncio.Task[None] | None = None
Expand Down Expand Up @@ -97,13 +99,13 @@ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None
self.task_scheduler.cancel()
return await super().__aexit__(exc_type, exc_value, traceback)

async def send(self, action: Observations) -> None:
async def send(self, observations: Observations) -> None:
for output_channel, output_channel_type in self.output_channel_types.items():
if output_channel in action.observations_map:
if output_channel in observations.observations_map:
await self.r.publish(
output_channel,
Message[output_channel_type]( # type:ignore[valid-type]
data=action.observations_map[output_channel]
data=observations.observations_map[output_channel]
).model_dump_json(),
)

Expand Down Expand Up @@ -172,6 +174,19 @@ async def booting(self) -> None:
)
self.current_agent_index += 1

async def wrap_up_and_stop(self) -> None:
if self.push_to_db:
await self.save()
await asyncio.sleep(0.5)
print("stopping all agents")
for output_channel, output_channel_type in self.output_channel_types.items():
await self.r.publish(
output_channel,
Message[output_channel_type]( # type:ignore[valid-type]
data=f"shutdown:{self.node_name}"
).model_dump_json(),
)

async def save(self) -> EpisodeLog:
"""
save the EpisodeLog to redis, without evaluating
Expand All @@ -196,6 +211,11 @@ async def save(self) -> EpisodeLog:
return epilog

async def aact(self, agent_action: AgentAction) -> Observations | None:
if agent_action.action_type == "leave":
self.agents_awake[agent_action.agent_name] = False
if True not in self.agents_awake.values():
await self.wrap_up_and_stop()
return None
if agent_action.action_type == "none":
return None

Expand All @@ -221,9 +241,6 @@ async def aact(self, agent_action: AgentAction) -> Observations | None:
if self.turn_number < self.max_turns:
self.turn_number += 1
else:
if self.push_to_db:
await self.save()
self.shutdown_event.set()
return Observations(
observations_map={
output_channel: Observation(
Expand Down
9 changes: 3 additions & 6 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit c220664

Please sign in to comment.