Merge pull request #18 from sotopia-lab/feature/add_finetuned_model
Add new model version for gpt-3.5-turbo-finetuned
XuhuiZhou authored Feb 4, 2024
2 parents eac12c8 + 4a35410 commit bfc2554
Showing 10 changed files with 127 additions and 14 deletions.
2 changes: 0 additions & 2 deletions .gitignore
@@ -141,5 +141,3 @@ deprecated/*

#backup
backup/*

scripts/*
90 changes: 83 additions & 7 deletions examples/generate_scenarios.py
@@ -1,13 +1,15 @@
import ast
import asyncio
import json
import random
from typing import Any, cast

import pandas as pd
import typer
from experiment_eval import _sample_env_agent_combo_and_push_to_db
from redis_om import Migrator

from sotopia.database import EnvironmentProfile
from sotopia.database import EnvAgentComboStorage, EnvironmentProfile
from sotopia.database.persistent_profile import RelationshipType
from sotopia.generation_utils import (
LLM_Name,
@@ -37,11 +39,15 @@ def add_env_profiles(
def check_existing_envs(
env_profile: dict[str, Any], existing_envs: pd.DataFrame
) -> bool:
if (
env_profile["scenario"] in existing_envs["scenario"].to_list()
and str(env_profile["agent_goals"])
in existing_envs["agent_goals"].to_list()
):
try:
if (
env_profile["scenario"] in existing_envs["scenario"].to_list()
and str(env_profile["agent_goals"])
in existing_envs["agent_goals"].to_list()
):
return False
except KeyError:
print(env_profile)
return False
return True

@@ -50,7 +56,7 @@ def generate_newenv_profile(
num: int,
gen_model: LLM_Name = "gpt-4-turbo",
temperature: float = 0.5,
type: str = "mutual_friend",
type: str = "craigslist_bargains",
) -> pd.DataFrame:
env_profile_list = [] # type: ignore
existing_envs = pd.read_csv(
@@ -70,6 +76,22 @@ }
}
if check_existing_envs(env_profile, existing_envs):
env_profile_list.append(env_profile)
elif type == "craigslist_bargains":
while len(env_profile_list) < num:
scenario, social_goals = asyncio.run(
generate_craigslist_bargains_envs()
)
env_profile = {
"codename": f"craigslist_bargains_{len(env_profile_list)+10}",
"scenario": scenario,
"agent_goals": social_goals,
"relationship": RelationshipType.stranger,
"age_constraint": "[(18, 80), (18, 80)]",
"occupation_constraint": None,
"source": "craigslist_bargains",
}
if check_existing_envs(env_profile, existing_envs):
env_profile_list.append(env_profile)
else:
raise NotImplementedError("Only mutual_friend is supported for now")
return pd.DataFrame(env_profile_list)
@@ -116,5 +138,59 @@ def auto_generate_scenarios(
Migrator().run()


@app.command()
def clean_env_wo_combos() -> None:
"""
Function to clean up env-agent combos in the database
"""
env_agent_combos = list(EnvAgentComboStorage.all_pks())
envs_id_in_combos = set(
[
EnvAgentComboStorage.get(env_agent_combo).env_id
for env_agent_combo in env_agent_combos
]
)
envs = list(EnvironmentProfile.all_pks())
for env in envs:
if env not in envs_id_in_combos:
EnvironmentProfile.delete(env)


@app.command()
def upload_env_profiles(
filepath: str = "./data/all_environment_profile.json",
) -> None:
"""
Function to upload environment profiles from json file
The json file format is a direct dump from the database
"""
env_profile_list = []
existing_envs = pd.read_csv(
"./data/env_profiles_v1.csv"
) # TODO: find a better way to deal with this
current_envs = json.load(open(filepath, "r"))
for key in current_envs:
env_profile = current_envs[key]
if env_profile and check_existing_envs(env_profile, existing_envs):
del env_profile["pk"]
env_profile_list.append(env_profile)
env_profiles = add_env_profiles(env_profile_list)
print("New env profiles added to database:")
print(len(env_profiles))

count = 0
for env_profile in env_profiles:
assert env_profile.pk is not None
try:
_sample_env_agent_combo_and_push_to_db(env_profile.pk)
count += 1
except:
EnvironmentProfile.delete(env_profile.pk)
pass
print(f"New env-agent combo added to database: {count}")

Migrator().run()


if __name__ == "__main__":
app()
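For reference, the two new Typer commands added above can be invoked from the repository root roughly as follows. This is a sketch only: it assumes Typer's default underscore-to-hyphen command naming and the default `--filepath` value shown in the function signature.

# Upload environment profiles from a JSON dump of the database, then sample env-agent combos for them
python examples/generate_scenarios.py upload-env-profiles --filepath ./data/all_environment_profile.json

# Delete environment profiles that no longer appear in any env-agent combo
python examples/generate_scenarios.py clean-env-wo-combos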
16 changes: 16 additions & 0 deletions scripts/evaluate_finetuned_MF.sh
@@ -0,0 +1,16 @@
MODEL_NAME_1=gpt-3.5-turbo-ft-MF
MODEL_NAME_2=gpt-3.5-turbo

python examples/experiment_eval.py \
--gin_file sotopia_conf/generation_utils_conf/generate.gin \
--gin_file sotopia_conf/server_conf/server.gin \
--gin_file sotopia_conf/run_async_server_in_batch.gin \
"--gin.ENV_IDS=['01H7VFHPKA2GGPPNVJWV967HZC', '01H7VFHPHWA2CYG7BC82NS4XH1', '01H7VFHPH567HKQRE0C745KH9C', '01H7VFHPMS6AJY0PFGGCFFK5GX', '01H7VFHPJKR16MD1KC71V4ZRCF', '01H7VFHPQ1712DHGTMPQFTXH02', '01H7VFHPP9SPQ8W6583JFZ7HZC', '01H7VFHPM3NVVKSGCCB4S10465', '01H7VFHPGABSWQXTACCC8C3X2F', '01H7VFHPNHZ2YYRHP0GXARD550']" \
"--gin.AGENT1_MODEL=\"${MODEL_NAME_1}\"" \
"--gin.AGENT2_MODEL=\"${MODEL_NAME_2}\"" \
'--gin.BATCH_SIZE=1' \
'--gin.TAG="finetuned_gpt3.5_gpt3.5ft_MF"' \
'--gin.TAG_TO_CHECK_EXISTING_EPISODES="finetuned_gpt3.5_gpt3.5ft_MF"' \
'--gin.PUSH_TO_DB=True' \
'--gin.VERBOSE=False' \
'--gin.LITE=False' \
15 changes: 15 additions & 0 deletions scripts/evaluate_finetuned_full.sh
@@ -0,0 +1,15 @@
MODEL_NAME=gpt-3.5-turbo-finetuned

python examples/experiment_eval.py \
--gin_file sotopia_conf/generation_utils_conf/generate.gin \
--gin_file sotopia_conf/server_conf/server.gin \
--gin_file sotopia_conf/run_async_server_in_batch.gin \
'--gin.ENV_IDS=[]' \
"--gin.AGENT1_MODEL=\"${MODEL_NAME}\"" \
"--gin.AGENT2_MODEL=\"${MODEL_NAME}\"" \
'--gin.BATCH_SIZE=5' \
'--gin.TAG="finetuned_gpt3.5"' \
'--gin.TAG_TO_CHECK_EXISTING_EPISODES="finetuned_gpt3.5"' \
'--gin.PUSH_TO_DB=True' \
'--gin.VERBOSE=False' \
'--gin.LITE=False' \
7 changes: 7 additions & 0 deletions exp_scripts/exp_instruction.md → scripts/exp_instruction.md
@@ -1,5 +1,12 @@
# Agent vs Storyteller Scripts

### Basic Scripts
Here are some of the scripts for running {gpt-3.5-turbo, mixtral-7b-moe} under the {normal interaction, omniscient interaction, script generation} modes in the {normal, lite} settings.
To run all interaction modes, use `run_all.sh`: `./run_all.sh <model_name> <tag_base> <lite>`, for example `./run_all.sh gpt-3.5-turbo exp0128 True`. Valid model names are listed in `LLM_Name`; currently we use `mistralai/Mixtral-8x7B-Instruct-v0.1` and `gpt-3.5-turbo`.
To run a single mode separately, use `run_interaction.sh` or `run_script_full.sh`.
After running the scripts above, you can specify tags and fix erroneous episodes with `./fix_missing_episodes_with_tag.sh`.
The current `fix_missing_episodes_with_tag.py` first detects erroneous episodes, then deletes and regenerates them.

### Fine-tuning

* `evaluate_finetuned_full.sh`: evaluates the fine-tuned model (GPT-3.5 fine-tuned on the full dataset) in the sotopia lite setting.
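Putting the basic and fine-tuning scripts together, a typical run might look like the sketch below. Tag values are illustrative, and the working directory and argument handling for each script are assumed rather than documented here.

# Run all interaction modes for one model with tag base "exp0128" in the lite setting
./run_all.sh gpt-3.5-turbo exp0128 True

# Detect, delete, and regenerate erroneous episodes for the chosen tags
./fix_missing_episodes_with_tag.sh

# Evaluate the GPT-3.5 checkpoint fine-tuned on the full dataset (run from the repository root)
bash scripts/evaluate_finetuned_full.sh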
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
11 changes: 6 additions & 5 deletions sotopia/generation_utils/generate.py
@@ -50,6 +50,8 @@
"togethercomputer/llama-2-70b-chat",
"togethercomputer/mpt-30b-chat",
"gpt-3.5-turbo",
"gpt-3.5-turbo-finetuned",
"gpt-3.5-turbo-ft-MF",
"text-davinci-003",
"gpt-4",
"gpt-4-turbo",
@@ -290,11 +292,11 @@ def _type(self) -> str:
return "str"


def _return_fixed_model_version(
model_name: Literal["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo"]
) -> str:
def _return_fixed_model_version(model_name: LLM_Name) -> str:
return {
"gpt-3.5-turbo": "gpt-3.5-turbo-0613",
"gpt-3.5-turbo-finetuned": "ft:gpt-3.5-turbo-0613:academicscmu::8nY2zgdt",
"gpt-3.5-turbo-ft-MF": "ft:gpt-3.5-turbo-0613:academicscmu::8nuER4bO",
"gpt-4": "gpt-4-0613",
"gpt-4-turbo": "gpt-4-1106-preview",
}[model_name]
@@ -313,7 +315,7 @@ def obtain_chain(
Using langchain to sample profiles for participants
"""
match model_name:
case "gpt-3.5-turbo" | "gpt-4" | "gpt-4-turbo":
case "gpt-3.5-turbo" | "gpt-4" | "gpt-4-turbo" | "gpt-3.5-turbo-finetuned" | "gpt-3.5-turbo-ft-MF":
human_message_prompt = HumanMessagePromptTemplate(
prompt=PromptTemplate(
template=template,
@@ -781,7 +783,6 @@ async def agenerate_action(
Your action should follow the given format:
{format_instructions}
"""

return await agenerate(
model_name=model_name,
template=template,
