add csv & json #30

Merged
merged 9 commits on Apr 6, 2024
4 changes: 4 additions & 0 deletions docs/all_the_issues.md
@@ -3,6 +3,10 @@

A large batch size may cause some episodes to be skipped because the server may not be able to handle the load. Try reducing the batch size, or use the script in `examples/fix_missing_episodes.py` to fix the missing episodes.

## How to serialize the data saved in the database?

See the `Episodes to CSV/JSON` section in the `notebooks/redis_stats.ipynb` notebook.
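
For a quick start outside the notebook, here is a minimal sketch, assuming a running Redis instance already populated with `EpisodeLog` entries (the notebook additionally filters out episodes whose tags fail validation):

```python
from sotopia.database import EpisodeLog, episodes_to_csv, episodes_to_json

# Load every episode stored in the database.
episodes = [EpisodeLog.get(pk=pk) for pk in EpisodeLog.all_pks()]

# Serialize to CSV, and to JSONL (one JSON object per line).
episodes_to_csv(episodes, "episodes.csv")
episodes_to_json(episodes, "episodes.jsonl")
```

`episodes_to_json` writes one JSON object per line, which `jsonl_to_episodes` can read back into `TwoAgentEpisodeWithScenarioBackgroundGoals` objects.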

## Where can I find the data?

For the full data:
75 changes: 73 additions & 2 deletions notebooks/redis_stats.ipynb
@@ -9,19 +9,90 @@
"import sys\n",
"import os\n",
"import json\n",
"from typing import get_args\n",
"from tqdm.notebook import tqdm\n",
"import rich\n",
"import logging\n",
"from pydantic import ValidationError\n",
"from collections import defaultdict, Counter\n",
"from sotopia.database.persistent_profile import AgentProfile, EnvironmentProfile, RelationshipProfile\n",
"from sotopia.database.logs import EpisodeLog\n",
"from sotopia.database import AgentProfile, EnvironmentProfile, RelationshipProfile, EpisodeLog, episodes_to_csv, episodes_to_json \n",
"from sotopia.database.env_agent_combo_storage import EnvAgentComboStorage\n",
"from collections import Counter \n",
"from redis_om import Migrator\n",
"from rich.console import Console\n",
"from rich.terminal_theme import MONOKAI "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Episodes to CSV/JSON"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"LLM_Name = Literal[\n",
" \"togethercomputer/llama-2-7b-chat\",\n",
" \"togethercomputer/llama-2-70b-chat\",\n",
" \"togethercomputer/mpt-30b-chat\",\n",
" \"gpt-3.5-turbo\",\n",
" \"text-davinci-003\",\n",
" \"gpt-4\",\n",
" \"gpt-4-turbo\",\n",
" \"human\",\n",
" \"redis\",\n",
"]\n",
"def _is_valid_episode_log_pk(pk: str) -> bool:\n",
" try:\n",
" episode = EpisodeLog.get(pk=pk)\n",
" except ValidationError:\n",
" return False\n",
" try:\n",
" tag = episode.tag\n",
" model_1, model_2, version = tag.split(\"_\", maxsplit=2)\n",
" if (\n",
" model_1 in get_args(LLM_Name)\n",
" and model_2 in get_args(LLM_Name)\n",
" and version == \"v0.0.1_clean\"\n",
" ):\n",
" return True\n",
" else:\n",
" return False\n",
" except (ValueError, AttributeError):\n",
" # ValueError: tag has less than 3 parts\n",
" # AttributeError: tag is None\n",
" return False\n",
"\n",
"\n",
"episodes: list[EpisodeLog] = [\n",
" EpisodeLog.get(pk=pk)\n",
" for pk in filter(_is_valid_episode_log_pk, EpisodeLog.all_pks())\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"episodes_to_csv(episodes, \"../data/sotopia_episodes_v1.csv\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"episodes_to_json(episodes, \"../data/sotopia_episodes_v1.json\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
8 changes: 8 additions & 0 deletions sotopia/database/__init__.py
@@ -7,6 +7,11 @@
RelationshipProfile,
RelationshipType,
)
from .serialization import (
episodes_to_csv,
episodes_to_json,
get_rewards_from_episode,
)
from .session_transaction import MessageTransaction, SessionTransaction
from .waiting_room import MatchingInWaitingRoom

@@ -23,4 +28,7 @@
"SessionTransaction",
"MessageTransaction",
"MatchingInWaitingRoom",
"episodes_to_csv",
"episodes_to_json",
"get_rewards_from_episodes",
]
200 changes: 200 additions & 0 deletions sotopia/database/serialization.py
@@ -0,0 +1,200 @@
import json

import pandas as pd
from pydantic import BaseModel, Field

from .logs import EpisodeLog
from .persistent_profile import AgentProfile, EnvironmentProfile


class TwoAgentEpisodeWithScenarioBackgroundGoals(BaseModel):
episode_id: str = Field(required=True)
scenario: str = Field(required=True)
codename: str = Field(required=True)
agents_background: dict[str, str] = Field(required=True)
social_goals: dict[str, str] = Field(required=True)
social_interactions: str = Field(required=True)
reasoning: str = Field(required=False)
rewards: list[dict[str, float]] = Field(required=False)


def _map_gender_to_adj(gender: str) -> str:
gender_to_adj = {
"Man": "male",
"Woman": "female",
"Nonbinary": "nonbinary",
}
if gender:
return gender_to_adj[gender]
else:
return ""


def get_rewards_from_episode(episode: EpisodeLog) -> list[dict[str, float]]:
assert (
len(episode.rewards) == 2
and (not isinstance(episode.rewards[0], float))
and (not isinstance(episode.rewards[1], float))
)
return [episode.rewards[0][1], episode.rewards[1][1]]


def get_scenario_from_episode(episode: EpisodeLog) -> str:
"""Get the scenario from the episode.

Args:
episode (EpisodeLog): The episode.

Returns:
str: The scenario.
"""
return EnvironmentProfile.get(pk=episode.environment).scenario


def get_codename_from_episode(episode: EpisodeLog) -> str:
"""Get the codename from the episode.

Args:
episode (EpisodeLog): The episode.

Returns:
str: The codename.
"""
return EnvironmentProfile.get(pk=episode.environment).codename


def get_agents_background_from_episode(episode: EpisodeLog) -> dict[str, str]:
"""Get the agents' background from the episode.

Args:
episode (EpisodeLog): The episode.

Returns:
list[str]: The agents' background.
"""
agents = [AgentProfile.get(pk=agent) for agent in episode.agents]

return {
f"{profile.first_name} {profile.last_name}": f"{profile.first_name} {profile.last_name} is a {profile.age}-year-old {_map_gender_to_adj(profile.gender)} {profile.occupation.lower()}. {profile.gender_pronoun} pronouns. {profile.public_info} Personality and values description: {profile.personality_and_values} {profile.first_name}'s secrets: {profile.secret}"
for profile in agents
}


def get_agent_name_to_social_goal_from_episode(
episode: EpisodeLog,
) -> dict[str, str]:
agents = [AgentProfile.get(agent) for agent in episode.agents]
agent_names = [
agent.first_name + " " + agent.last_name for agent in agents
]
environment = EnvironmentProfile.get(episode.environment)
agent_goals = {
agent_names[0]: environment.agent_goals[0],
agent_names[1]: environment.agent_goals[1],
}
return agent_goals


def get_social_interactions_from_episode(
episode: EpisodeLog,
) -> str:
assert isinstance(episode.tag, str)
list_of_social_interactions = episode.render_for_humans()[1]
if len(list_of_social_interactions) < 3:
return ""
if "script" in episode.tag.split("_"):
overall_social_interaction = list_of_social_interactions[1:-3]
else:
overall_social_interaction = list_of_social_interactions[0:-3]
# only get the sentence after "Conversation Starts:\n\n"
starter_msg_list = overall_social_interaction[0].split(
"Conversation Starts:\n\n"
)
if len(starter_msg_list) < 3:
overall_social_interaction = list_of_social_interactions[1:-3]
# raise ValueError("The starter message is not in the expected format")
else:
overall_social_interaction[0] = starter_msg_list[-1]
return "\n\n".join(overall_social_interaction)


def episodes_to_csv(
episodes: list[EpisodeLog], csv_file_path: str = "episodes.csv"
) -> None:
"""Save episodes to a csv file.

Args:
episodes (list[EpisodeLog]): List of episodes.
filepath (str, optional): The file path. Defaults to "episodes.csv".
"""
data = {
"episode_id": [episode.pk for episode in episodes],
"scenario": [
get_scenario_from_episode(episode) for episode in episodes
],
"codename": [
get_codename_from_episode(episode) for episode in episodes
],
"agents_background": [
get_agents_background_from_episode(episode) for episode in episodes
],
"social_goals": [
get_agent_name_to_social_goal_from_episode(episode)
for episode in episodes
],
"social_interactions": [
get_social_interactions_from_episode(episode)
for episode in episodes
],
}
df = pd.DataFrame(data)
df.to_csv(csv_file_path, index=False)


def episodes_to_json(
episodes: list[EpisodeLog], jsonl_file_path: str = "episodes.jsonl"
) -> None:
"""Save episodes to a json file.

Args:
episodes (list[EpisodeLog]): List of episodes.
filepath (str, optional): The file path. Defaults to "episodes.json".
"""
with open(jsonl_file_path, "w") as f:
for episode in episodes:
data = TwoAgentEpisodeWithScenarioBackgroundGoals(
episode_id=episode.pk,
scenario=get_scenario_from_episode(episode),
codename=get_codename_from_episode(episode),
agents_background=get_agents_background_from_episode(episode),
social_goals=get_agent_name_to_social_goal_from_episode(
episode
),
social_interactions=get_social_interactions_from_episode(
episode
),
reasoning=episode.reasoning,
rewards=get_rewards_from_episode(episode),
)
json.dump(dict(data), f)
f.write("\n")


def jsonl_to_episodes(
jsonl_file_path: str,
) -> list[TwoAgentEpisodeWithScenarioBackgroundGoals]:
"""Load episodes from a jsonl file.

Args:
jsonl_file_path (str): The file path.

Returns:
list[TwoAgentEpisodeWithScenarioBackgroundGoals]: List of episodes.
"""
episodes = []
with open(jsonl_file_path, "r") as f:
for line in f:
data = json.loads(line)
episode = TwoAgentEpisodeWithScenarioBackgroundGoals(**data)
episodes.append(episode)
return episodes
53 changes: 29 additions & 24 deletions tests/database/test_database.py
@@ -54,30 +54,8 @@ def _test_create_episode_log_setup_and_tear_down() -> Generator[
EpisodeLog.delete("tmppk_episode_log")


def test_get_agent_by_name(
_test_create_episode_log_setup_and_tear_down: Any,
) -> None:
agent_profile = AgentProfile.find(AgentProfile.first_name == "John").all()
assert agent_profile[0].pk == "tmppk_agent1"


def test_create_episode_log(
_test_create_episode_log_setup_and_tear_down: Any,
) -> None:
try:
_ = EpisodeLog(
environment="",
agents=["", ""],
messages=[],
rewards=[[0, 0, 0]],
reasoning=[""],
rewards_prompt="",
)
assert False
except Exception as e:
assert isinstance(e, ValidationError)

episode_log = EpisodeLog(
def create_dummy_episode_log() -> EpisodeLog:
episode = EpisodeLog(
environment="env",
agents=["tmppk_agent1", "tmppk_agent2"],
messages=[
@@ -126,6 +104,33 @@ def test_create_episode_log(
pk="tmppk_episode_log",
rewards_prompt="",
)
return episode


def test_get_agent_by_name(
_test_create_episode_log_setup_and_tear_down: Any,
) -> None:
agent_profile = AgentProfile.find(AgentProfile.first_name == "John").all()
assert agent_profile[0].pk == "tmppk_agent1"


def test_create_episode_log(
_test_create_episode_log_setup_and_tear_down: Any,
) -> None:
try:
_ = EpisodeLog(
environment="",
agents=["", ""],
messages=[],
rewards=[[0, 0, 0]],
reasoning=[""],
rewards_prompt="",
)
assert False
except Exception as e:
assert isinstance(e, ValidationError)

episode_log = create_dummy_episode_log()
episode_log.save()
assert episode_log.pk == "tmppk_episode_log"
retrieved_episode_log: EpisodeLog = EpisodeLog.get(episode_log.pk)