Skip to content

Commit

Permalink
Add results collection script (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
zakharova-anastasiia authored Dec 22, 2024
1 parent 5880b84 commit f8ca1a7
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 6 deletions.
98 changes: 92 additions & 6 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

102 changes: 102 additions & 0 deletions protollm/ensembles_ma/collect_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from enum import Enum
import uuid
import requests
from websockets import ConnectionClosed
import websockets.sync.client as wsclient
import json
import click
import pandas as pd


class AnswerType(str, Enum):
    """Message kinds distinguished by the ``name`` field of a websocket reply.

    Members are also plain ``str`` instances, so they compare equal to the
    raw string values found in decoded JSON payloads.
    """

    # Message whose ``result`` is stored in the ``docs_*`` output columns.
    RETRIEVAL = "retrieval"
    # Message whose ``result`` is stored in the ``answer_*`` output columns.
    ANSWER = "answer"
    # Message reporting a failure; its ``detail`` field describes it.
    ERROR = "error"

def parse_ws_response(response):
    """Decode one websocket message and classify it by its ``name`` field.

    Args:
        response: Raw JSON document received from the socket (anything
            ``json.loads`` accepts).

    Returns:
        ``(name, result)`` for retrieval and answer messages, where ``name``
        is the raw string taken from the payload and ``result`` is the
        payload's ``result`` field. ``(None, None)`` for every other message,
        including valid JSON that is not an object.

    Raises:
        RuntimeError: If the message is an error message; the exception
            carries the payload's ``detail`` value (``None`` when absent).
        KeyError: If a retrieval/answer message has no ``result`` field.
        json.JSONDecodeError: If ``response`` is not valid JSON.
    """
    response_body = json.loads(response)

    # A frame that decodes to a list, string, number or null has no ``name``
    # to dispatch on; ignore it like any other unrecognised message instead
    # of failing with AttributeError on ``.get``.
    if not isinstance(response_body, dict):
        return None, None

    name = response_body.get('name')
    if name == AnswerType.ERROR:
        # RuntimeError is still caught by callers handling ``Exception``.
        raise RuntimeError(response_body.get('detail'))
    if name in (AnswerType.RETRIEVAL, AnswerType.ANSWER):
        return name, response_body['result']
    return None, None


# Upper bound, in seconds, for the HTTP request that lists the agents, so an
# unreachable server fails the run instead of blocking it forever.
_HTTP_TIMEOUT_SECONDS = 30


def _collect_stream(url, payload, label):
    """Send one request over a websocket and gather the streamed results.

    Args:
        url: Websocket URL to connect to.
        payload: JSON-serialisable request body, sent once after connecting.
        label: Suffix used in the produced column names.

    Returns:
        A dict containing ``docs_<label>`` and/or ``answer_<label>`` for the
        retrieval and answer messages received. A later message of the same
        type overwrites an earlier one.

    Raises:
        Exception: Whatever ``parse_ws_response`` raises for a message,
            e.g. when the server reports an error.
    """
    columns = {}
    with wsclient.connect(url) as ws:
        ws.send(json.dumps(payload))
        try:
            while True:
                response_type, response_data = parse_ws_response(ws.recv())
                if response_type == AnswerType.RETRIEVAL:
                    columns[f'docs_{label}'] = response_data
                elif response_type == AnswerType.ANSWER:
                    columns[f'answer_{label}'] = response_data
        except ConnectionClosed:
            # The receive loop has no other exit: a closed connection is
            # treated as the end of the response stream.
            pass
    return columns


@click.command()
@click.option("--basepath", type=str, default="0.0.0.0:8080")
@click.option("--output", type=str, default="agents_responses.csv")
def main(basepath, output):
    """Ask every RAG agent, the router and the ensemble a fixed set of
    questions and save the collected documents and answers as CSV.

    BASEPATH is the host:port of the running service; OUTPUT is the path of
    the CSV file to write (one row per question).
    """
    run_uid = str(uuid.uuid4())
    questions = [
        "Какой объем финансирования программы политики защиты окружающей среды",
        "Какая разница в объеме финансирования программ защиты окружающей среды и образования?",
        "Кто ответственный исполнитель программы политики защиты окружающей среды?",
        "Какие целевые показатели госпрограмм по образованию и защите окружающей среды?",
        "Сколько подпрограмм в госполитике по защите окружающей среды?",
        "Какие приоритеты политики в плане обращений с твердыми коммунальными отходами?",
        "какой объем финансирования программы образования в 2017 году?",
    ]

    # Explicit checks instead of ``assert``: assertions are removed when
    # Python runs with -O, which would let a failed request slip through.
    http_response = requests.get(
        f"http://{basepath}/",
        params={"agent_type": "streaming"},
        timeout=_HTTP_TIMEOUT_SECONDS,
    )
    if http_response.status_code != 200:
        raise click.ClickException("Failed to get agents")
    agents = http_response.json()
    if not agents:
        raise click.ClickException("No agents found")

    # Only agents whose name contains "rag" are queried individually.
    agents_ids = {
        agent['agent_id']: agent['name']
        for agent in agents
        if "rag" in agent['name']
    }

    agents_responses = []
    for question in questions:
        question_columns = {"question": question}

        for agent_id, agent_name in agents_ids.items():
            click.echo(f"Collecting response from {agent_name}, {question=}")
            request_payload = {
                "dialogue_id": run_uid,
                "agent_id": agent_id,
                "chat_history": [],
                "query": question,
                "run_params": {},
            }
            question_columns.update(
                _collect_stream(f"ws://{basepath}/agent", request_payload, agent_name)
            )
        click.echo("Finished collecting RAG agents responses")

        for endpoint in ('router', 'ensemble'):
            click.echo(f"Collecting response from {endpoint}, {question=}")
            request_payload = {
                "dialogue_id": run_uid,
                "chat_history": [],
                "query": question,
            }
            question_columns.update(
                _collect_stream(f"ws://{basepath}/{endpoint}", request_payload, endpoint)
            )
        click.echo("Finished collecting router and ensemble responses")

        click.echo(f"Collected question_columns: {question_columns}")
        agents_responses.append(question_columns)
    click.echo("Finished collecting agents responses")

    # ``from_records`` is a classmethod; no throwaway instance is needed.
    df = pd.DataFrame.from_records(agents_responses)
    df.to_csv(output, index=False)

if __name__ == "__main__":
    # Run the click command only when the file is executed as a script.
    main()
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ tornado = "^6.4.1"
[tool.poetry.group.dev.dependencies]
pytest = "^8.2.2"
pytest-asyncio = "^0.24.0"
pandas = "^2.2.3"

[build-system]
requires = ["poetry-core"]
Expand Down

0 comments on commit f8ca1a7

Please sign in to comment.