Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add csv& json #30

Merged
merged 9 commits into from
Apr 6, 2024
74 changes: 72 additions & 2 deletions notebooks/redis_stats.ipynb
ProKil marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,89 @@
"import sys\n",
"import os\n",
"import json\n",
"from typing import get_args\n",
"from sotopia.generation_utils.generate import LLM_Name\n",
"from tqdm.notebook import tqdm\n",
"import rich\n",
"import logging\n",
"from pydantic import ValidationError\n",
"from collections import defaultdict, Counter\n",
"from sotopia.database.persistent_profile import AgentProfile, EnvironmentProfile, RelationshipProfile\n",
"from sotopia.database.logs import EpisodeLog\n",
"from sotopia.database import AgentProfile, EnvironmentProfile, RelationshipProfile, EpisodeLog, episodes_to_csv, episodes_to_json \n",
"from sotopia.database.env_agent_combo_storage import EnvAgentComboStorage\n",
"from collections import Counter \n",
"from redis_om import Migrator\n",
"from rich.console import Console\n",
"from rich.terminal_theme import MONOKAI "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Episodes to CSV"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def _is_valid_episode_log_pk(pk: str) -> bool:\n",
" try:\n",
" episode = EpisodeLog.get(pk=pk)\n",
" except ValidationError:\n",
" return False\n",
" try:\n",
" tag = episode.tag\n",
" model_1, model_2, version = tag.split(\"_\", maxsplit=2)\n",
" if (\n",
" model_1 in get_args(LLM_Name)\n",
" and model_2 in get_args(LLM_Name)\n",
" and version == \"v0.0.1_clean\"\n",
" ):\n",
" return True\n",
" else:\n",
" return False\n",
" except (ValueError, AttributeError):\n",
" # ValueError: tag has less than 3 parts\n",
" # AttributeError: tag is None\n",
" return False\n",
"\n",
"\n",
"episodes: list[EpisodeLog] = [\n",
" EpisodeLog.get(pk=pk)\n",
" for pk in filter(_is_valid_episode_log_pk, EpisodeLog.all_pks())\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"len(episodes)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"episodes_to_csv(episodes, \"../data/sotopia_episodes_v1.csv\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"episodes_to_json(episodes, \"../data/sotopia_episodes_v1.json\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
8 changes: 8 additions & 0 deletions sotopia/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
RelationshipType,
)
from .session_transaction import MessageTransaction, SessionTransaction
from .utils import (
episodes_to_csv,
episodes_to_json,
get_rewards_from_episode,
)
from .waiting_room import MatchingInWaitingRoom

__all__ = [
Expand All @@ -23,4 +28,7 @@
"SessionTransaction",
"MessageTransaction",
"MatchingInWaitingRoom",
"episodes_to_csv",
"episodes_to_json",
"get_rewards_from_episodes",
]
181 changes: 181 additions & 0 deletions sotopia/database/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import json

import pandas as pd

from .logs import EpisodeLog
from .persistent_profile import AgentProfile, EnvironmentProfile


def _map_gender_to_adj(gender: str) -> str:
gender_to_adj = {
"Man": "male",
"Woman": "female",
"Nonbinary": "nonbinary",
}
if gender:
return gender_to_adj[gender]
else:
return ""


def get_rewards_from_episode(episode: EpisodeLog) -> list[dict[str, float]]:
assert (
len(episode.rewards) == 2
and (not isinstance(episode.rewards[0], float))
and (not isinstance(episode.rewards[1], float))
)
return [episode.rewards[0][1], episode.rewards[1][1]]


def get_scenario_from_episode(episode: EpisodeLog) -> str:
"""Get the scenario from the episode.

Args:
episode (EpisodeLog): The episode.

Returns:
str: The scenario.
"""
return EnvironmentProfile.get(pk=episode.environment).scenario


def get_codename_from_episode(episode: EpisodeLog) -> str:
"""Get the codename from the episode.

Args:
episode (EpisodeLog): The episode.

Returns:
str: The codename.
"""
return EnvironmentProfile.get(pk=episode.environment).codename


def get_agents_background_from_episode(episode: EpisodeLog) -> dict[str, str]:
"""Get the agents' background from the episode.

Args:
episode (EpisodeLog): The episode.

Returns:
list[str]: The agents' background.
"""
agents = [AgentProfile.get(pk=agent) for agent in episode.agents]

return {
f"{profile.first_name} {profile.last_name}": f"{profile.first_name} {profile.last_name} is a {profile.age}-year-old {_map_gender_to_adj(profile.gender)} {profile.occupation.lower()}. {profile.gender_pronoun} pronouns. {profile.public_info} Personality and values description: {profile.personality_and_values} {profile.first_name}'s secrets: {profile.secret}"
for profile in agents
}


def get_social_goals_from_episode(
XuhuiZhou marked this conversation as resolved.
Show resolved Hide resolved
epsidoes: list[EpisodeLog],
) -> list[dict[str, str]]:
"""Obtain social goals from episodes.

Args:
epsidoes (list[EpisodeLog]): List of episodes.

Returns:
list[dict[str, str]]: List of social goals with agent names as the index.
"""
social_goals = []
for episode in epsidoes:
agents = [AgentProfile.get(agent) for agent in episode.agents]
agent_names = [
agent.first_name + " " + agent.last_name for agent in agents
]
environment = EnvironmentProfile.get(episode.environment)
agent_goals = {
agent_names[0]: environment.agent_goals[0],
agent_names[1]: environment.agent_goals[1],
}
social_goals.append(agent_goals)
return social_goals


def get_social_interactions_from_episode(
epsidoes: list[EpisodeLog],
) -> list[str]:
"""Obtain pure social interactions from episodes.
Args:
epsidoes (list[EpisodeLog]): List of episodes.
Returns:
list[str]: List of social interactions.
"""
social_interactions = []
for episode in epsidoes:
assert isinstance(episode.tag, str)
list_of_social_interactions = episode.render_for_humans()[1]
if len(list_of_social_interactions) < 3:
continue
if "script" in episode.tag.split("_"):
overall_social_interaction = list_of_social_interactions[1:-3]
else:
overall_social_interaction = list_of_social_interactions[0:-3]
# only get the sentence after "Conversation Starts:\n\n"
starter_msg_list = overall_social_interaction[0].split(
"Conversation Starts:\n\n"
)
if len(starter_msg_list) < 3:
overall_social_interaction = list_of_social_interactions[1:-3]
# raise ValueError("The starter message is not in the expected format")
else:
overall_social_interaction[0] = starter_msg_list[-1]

social_interactions.append("\n\n".join(overall_social_interaction))
return social_interactions


def episodes_to_csv(
episodes: list[EpisodeLog], filepath: str = "episodes.csv"
) -> None:
"""Save episodes to a csv file.

Args:
episodes (list[EpisodeLog]): List of episodes.
filepath (str, optional): The file path. Defaults to "episodes.csv".
"""
data = {
"episode_id": [episode.pk for episode in episodes],
"scenario": [
get_scenario_from_episode(episode) for episode in episodes
],
"codename": [
get_codename_from_episode(episode) for episode in episodes
],
"agents_background": [
get_agents_background_from_episode(episode) for episode in episodes
],
"social_goals": get_social_goals_from_episode(episodes),
XuhuiZhou marked this conversation as resolved.
Show resolved Hide resolved
"social_interactions": get_social_interactions_from_episode(episodes),
}
df = pd.DataFrame(data)
df.to_csv(filepath, index=False)


def episodes_to_json(
XuhuiZhou marked this conversation as resolved.
Show resolved Hide resolved
episodes: list[EpisodeLog], filepath: str = "episodes.json"
) -> None:
"""Save episodes to a json file.

Args:
episodes (list[EpisodeLog]): List of episodes.
filepath (str, optional): The file path. Defaults to "episodes.json".
"""
with open(filepath, "w") as f:
for episode in episodes:
data = {
"episode_id": episode.pk,
"scenario": get_scenario_from_episode(episode),
"codename": get_codename_from_episode(episode),
"agents_background": get_agents_background_from_episode(
episode
),
"social_goals": get_social_goals_from_episode([episode]),
"social_interactions": get_social_interactions_from_episode(
[episode]
),
}
json.dump(data, f)
f.write("\n")
Loading