Skip to content

Commit

Permalink
Feature/select agent env (#45)
Browse files Browse the repository at this point in the history
* changed ui to include scenario and agent info

* UI layout correct; need to fix logic

* half-way through; need to fix record reading and agent pair filtering logic

* fixed deletion of app.py

* debugging gradio change

* before debug

* finished UI features

* added retry (up to 5 attempts)

* finished merging
  • Loading branch information
Jasonqi146 authored Apr 18, 2024
1 parent c423c55 commit c3a4051
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 31 deletions.
75 changes: 48 additions & 27 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
os.environ["OPENAI_API_KEY"] = f.read().strip()

DEPLOYED = os.getenv("DEPLOYED", "true").lower() == "true"
DEFAULT_MODEL_SELECTION = "cmu-lti/sotopia-pi-mistral-7b-BC_SR" # "mistralai/Mistral-7B-Instruct-v0.1"
DEFAULT_MODEL_SELECTION = "gpt-3.5-turbo" # "mistralai/Mistral-7B-Instruct-v0.1"
TEMPERATURE = 0.7
TOP_P = 1
MAX_TOKENS = 1024
Expand Down Expand Up @@ -100,6 +100,7 @@ def create_bot_agent_dropdown(environment_id, user_agent_id):
environment, user_agent = environment_dict[environment_id], agent_dict[user_agent_id]

bot_agent_list = []
# import pdb; pdb.set_trace()
for neighbor_id in relationship_dict[environment.relationship][user_agent.agent_id]:
bot_agent_list.append((agent_dict[neighbor_id].name, neighbor_id))

Expand All @@ -109,46 +110,62 @@ def create_environment_info(environment_dropdown):
_, environment_dict, _, _ = get_sotopia_profiles()
environment = environment_dict[environment_dropdown]
text = environment.scenario
return gr.Textbox(label="Scenario Information", lines=4, value=text)
return gr.Textbox(label="Scenario", lines=1, value=text)

def create_user_info(environment_dropdown, user_agent_dropdown):
_, environment_dict, agent_dict, _ = get_sotopia_profiles()
environment, user_agent = environment_dict[environment_dropdown], agent_dict[user_agent_dropdown]
text = f"{user_agent.background} {user_agent.personality} \n {environment.agent_goals[0]}"
def create_user_info(user_agent_dropdown):
_, _, agent_dict, _ = get_sotopia_profiles()
user_agent = agent_dict[user_agent_dropdown]
text = f"{user_agent.background} {user_agent.personality}"
return gr.Textbox(label="User Agent Profile", lines=4, value=text)

def create_bot_info(environment_dropdown, bot_agent_dropdown):
_, environment_dict, agent_dict, _ = get_sotopia_profiles()
environment, bot_agent = environment_dict[environment_dropdown], agent_dict[bot_agent_dropdown]
text = f"{bot_agent.background} {bot_agent.personality} \n {environment.agent_goals[1]}"
def create_bot_info(bot_agent_dropdown):
_, _, agent_dict, _ = get_sotopia_profiles()
# import pdb; pdb.set_trace()
bot_agent = agent_dict[bot_agent_dropdown]
text = f"{bot_agent.background} {bot_agent.personality}"
return gr.Textbox(label="Bot Agent Profile", lines=4, value=text)

def create_user_goal(environment_dropdown):
    """Build the textbox that shows the user agent's goal for a scenario.

    ``environment_dropdown`` is used as a key into the environment mapping
    (second item returned by ``get_sotopia_profiles()``); the first entry of
    that environment's ``agent_goals`` becomes the textbox value.
    """
    _, env_lookup, _, _ = get_sotopia_profiles()
    selected_env = env_lookup[environment_dropdown]
    goal_text = selected_env.agent_goals[0]
    return gr.Textbox(label="User Agent Goal", lines=4, value=goal_text)

def create_bot_goal(environment_dropdown):
    """Build the textbox that shows the bot agent's goal for a scenario.

    ``environment_dropdown`` is used as a key into the environment mapping
    (second item returned by ``get_sotopia_profiles()``); the second entry of
    that environment's ``agent_goals`` becomes the textbox value.
    """
    _, env_lookup, _, _ = get_sotopia_profiles()
    selected_env = env_lookup[environment_dropdown]
    goal_text = selected_env.agent_goals[1]
    return gr.Textbox(label="Bot Agent Goal", lines=4, value=goal_text)

def sotopia_info_accordion(accordion_visible=True):
environments, _, _, _ = get_sotopia_profiles()

with gr.Accordion("Sotopia Information", open=accordion_visible):
with gr.Column():
model_name_dropdown = gr.Dropdown(
choices=["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "mistralai/Mistral-7B-Instruct-v0.1", "gpt-3.5-turbo"],
value="cmu-lti/sotopia-pi-mistral-7b-BC_SR",
interactive=True,
label="Model Selection"
)
with gr.Accordion("Environment Configuration", open=accordion_visible):
with gr.Row():
environments, _, _, _ = get_sotopia_profiles()
environment_dropdown = gr.Dropdown(
choices=environments,
label="Scenario Selection",
value=environments[0][1] if environments else None,
interactive=True,
)
print(environment_dropdown.value)
model_name_dropdown = gr.Dropdown(
choices=["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "mistralai/Mistral-7B-Instruct-v0.1", "gpt-3.5-turbo", "gpt-4-turbo"],
value=DEFAULT_MODEL_SELECTION,
interactive=True,
label="Model Selection"
)

scenario_info_display = create_environment_info(environment_dropdown.value)

with gr.Row():
bot_goal_display = create_bot_goal(environment_dropdown.value)
user_goal_display = create_user_goal(environment_dropdown.value)

with gr.Row():
user_agent_dropdown = create_user_agent_dropdown(environment_dropdown.value)
bot_agent_dropdown = create_bot_agent_dropdown(environment_dropdown.value, user_agent_dropdown.value)

with gr.Row():
scenario_info_display = create_environment_info(environment_dropdown.value)
user_agent_info_display = create_user_info(environment_dropdown.value, user_agent_dropdown.value)
bot_agent_info_display = create_bot_info(environment_dropdown.value, bot_agent_dropdown.value)
user_agent_info_display = create_user_info(user_agent_dropdown.value)
bot_agent_info_display = create_bot_info(bot_agent_dropdown.value)

# Update user dropdown when scenario changes
environment_dropdown.change(fn=create_user_agent_dropdown, inputs=[environment_dropdown], outputs=[user_agent_dropdown])
Expand All @@ -157,9 +174,13 @@ def sotopia_info_accordion(accordion_visible=True):
# Update scenario information when scenario changes
environment_dropdown.change(fn=create_environment_info, inputs=[environment_dropdown], outputs=[scenario_info_display])
# Update user agent profile when user changes
user_agent_dropdown.change(fn=create_user_info, inputs=[environment_dropdown, user_agent_dropdown], outputs=[user_agent_info_display])
user_agent_dropdown.change(fn=create_user_info, inputs=[user_agent_dropdown], outputs=[user_agent_info_display])
# Update bot agent profile when bot changes
bot_agent_dropdown.change(fn=create_bot_info, inputs=[environment_dropdown, bot_agent_dropdown], outputs=[bot_agent_info_display])
bot_agent_dropdown.change(fn=create_bot_info, inputs=[bot_agent_dropdown], outputs=[bot_agent_info_display])
# Update user goal when scenario changes
environment_dropdown.change(fn=create_user_goal, inputs=[environment_dropdown], outputs=[user_goal_display])
# Update bot goal when scenario changes
environment_dropdown.change(fn=create_bot_goal, inputs=[environment_dropdown], outputs=[bot_goal_display])

return model_name_dropdown, environment_dropdown, user_agent_dropdown, bot_agent_dropdown

Expand Down Expand Up @@ -192,12 +213,12 @@ def run_chat(
user_agent = agent_dict[user_agent_dropdown]
bot_agent = agent_dict[bot_agent_dropdown]

import pdb; pdb.set_trace()
# import pdb; pdb.set_trace()
context = get_context_prompt(bot_agent, user_agent, environment)
dialogue_history, next_turn_idx = dialogue_history_prompt(message, history, user_agent, bot_agent)
prompt_history = f"{context}\n\n{dialogue_history}"
agent_action = generate_action(model_selection, prompt_history, next_turn_idx, ACTION_TYPES, bot_agent.name, TEMPERATURE)
import pdb; pdb.set_trace()
# import pdb; pdb.set_trace()
return agent_action.to_natural_language()

with gr.Column():
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
sotopia
gradio
transformers
torch
Expand Down
6 changes: 3 additions & 3 deletions sotopia_pi_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def obtain_chain_hf(
model, tokenizer = prepare_model(model_name)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=max_tokens, temperature=temperature)
hf = HuggingFacePipeline(pipeline=pipe)
import pdb; pdb.set_trace()
# import pdb; pdb.set_trace()
chain = LLMChain(llm=hf, prompt=chat_prompt_template)
return chain

Expand All @@ -124,7 +124,7 @@ def generate(
output_parser: BaseOutputParser[OutputType],
temperature: float = 0.7,
) -> tuple[OutputType, str]:
import pdb; pdb.set_trace()
# import pdb; pdb.set_trace()
input_variables = re.findall(r"{(.*?)}", template)
assert (
set(input_variables) == set(list(input_values.keys()) + ["format_instructions"])
Expand All @@ -136,7 +136,7 @@ def generate(
if "format_instructions" not in input_values:
input_values["format_instructions"] = output_parser.get_format_instructions()
result = chain.predict([], **input_values)
import pdb; pdb.set_trace()
# import pdb; pdb.set_trace()
try:
parsed_result = output_parser.parse(result)
except KeyboardInterrupt:
Expand Down
2 changes: 1 addition & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def truncate_dialogue_history_to_length(dia_his, surpass_num, tokenizer):


def format_bot_message(bot_message) -> str:
# import pdb; pdb.set_trace()
# # import pdb; pdb.set_trace()
start_idx, end_idx = bot_message.index("{"), bot_message.index("}")
if end_idx == -1:
bot_message += "'}"
Expand Down

0 comments on commit c3a4051

Please sign in to comment.