Skip to content

Commit

Permalink
[Evals API] eval client demo scripts (#103)
Browse files Browse the repository at this point in the history
* scoring client

* test dataset resource

* eval client

* example eval client
  • Loading branch information
yanxi0830 authored Oct 31, 2024
1 parent 2c86055 commit 13f4e83
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 0 deletions.
88 changes: 88 additions & 0 deletions examples/eval/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import asyncio
import base64
import mimetypes
import os

import fire

from llama_stack_client import LlamaStackClient
from termcolor import cprint


def data_url_from_file(file_path: str) -> str:
    """Encode the file at *file_path* as an RFC 2397 ``data:`` URL.

    Args:
        file_path: Path to an existing file on disk.

    Returns:
        A ``data:<mime>;base64,<payload>`` URL containing the file contents.

    Raises:
        FileNotFoundError: If *file_path* does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, "rb") as file:
        file_content = file.read()

    base64_content = base64.b64encode(file_content).decode("utf-8")

    # guess_type returns None for unrecognized extensions, which would have
    # produced the literal string "data:None;base64,..." — fall back to a
    # generic binary MIME type instead.
    mime_type, _ = mimetypes.guess_type(file_path)
    if mime_type is None:
        mime_type = "application/octet-stream"

    data_url = f"data:{mime_type};base64,{base64_content}"

    return data_url


async def run_main(host: str, port: int, file_path: str):
    """Register *file_path* as an eval dataset and evaluate a few of its rows.

    NOTE(review): declared async for the fire entry point, but every client
    call below is synchronous — nothing here is awaited.
    """
    client = LlamaStackClient(base_url=f"http://{host}:{port}")

    dataset_url = data_url_from_file(file_path)
    providers = client.providers.list()

    # Schema columns must line up with what the scoring functions read.
    schema = {
        "expected_answer": {"type": "string"},
        "input_query": {"type": "string"},
        "chat_completion_input": {"type": "string"},
    }
    client.datasets.register(
        dataset_def={
            "identifier": "eval-dataset",
            "provider_id": providers["datasetio"][0].provider_id,
            "url": {"uri": dataset_url},
            "dataset_schema": schema,
        }
    )

    registered = client.datasets.list()
    cprint([x.identifier for x in registered], "cyan")

    # Fetch a small page so the eval below runs on individual rows.
    page = client.datasetio.get_rows_paginated(
        dataset_id="eval-dataset",
        rows_in_page=3,
        page_token=None,
        filter_condition=None,
    )
    print(page)

    # Candidate under evaluation: a model with greedy sampling parameters.
    candidate = {
        "type": "model",
        "model": "Llama3.2-1B-Instruct",
        "sampling_params": {
            "strategy": "greedy",
            "temperature": 0,
            "top_p": 0.95,
            "top_k": 0,
            "max_tokens": 0,
            "repetition_penalty": 1.0,
        },
    }
    results = client.eval.evaluate(
        input_rows=page.rows,
        candidate=candidate,
        scoring_functions=[
            "meta-reference::subset_of",
            "meta-reference::llm_as_judge_8b_correctness",
        ],
    )
    cprint(results, "green")


def main(host: str, port: int, file_path: str):
    """Synchronous entry point: drive the async demo via asyncio.run."""
    asyncio.run(run_main(host=host, port=port, file_path=file_path))


# Script entry point: fire turns main's parameters into CLI flags.
if __name__ == "__main__":
    fire.Fire(main)
84 changes: 84 additions & 0 deletions examples/scoring/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import asyncio
import base64
import mimetypes
import os

import fire

from llama_stack_client import LlamaStackClient
from termcolor import cprint


def data_url_from_file(file_path: str) -> str:
    """Encode the file at *file_path* as an RFC 2397 ``data:`` URL.

    Args:
        file_path: Path to an existing file on disk.

    Returns:
        A ``data:<mime>;base64,<payload>`` URL containing the file contents.

    Raises:
        FileNotFoundError: If *file_path* does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, "rb") as file:
        file_content = file.read()

    base64_content = base64.b64encode(file_content).decode("utf-8")

    # guess_type returns None for unrecognized extensions, which would have
    # produced the literal string "data:None;base64,..." — fall back to a
    # generic binary MIME type instead.
    mime_type, _ = mimetypes.guess_type(file_path)
    if mime_type is None:
        mime_type = "application/octet-stream"

    data_url = f"data:{mime_type};base64,{base64_content}"

    return data_url


async def run_main(host: str, port: int, file_path: str):
    """Register *file_path* as a test dataset and score it row-wise and in batch.

    NOTE(review): declared async for the fire entry point, but every client
    call below is synchronous — nothing here is awaited.
    """
    client = LlamaStackClient(base_url=f"http://{host}:{port}")

    dataset_url = data_url_from_file(file_path)
    providers = client.providers.list()

    # Schema columns must line up with what the scoring functions read.
    schema = {
        "generated_answer": {"type": "string"},
        "expected_answer": {"type": "string"},
        "input_query": {"type": "string"},
    }
    client.datasets.register(
        dataset_def={
            "identifier": "test-dataset",
            "provider_id": providers["datasetio"][0].provider_id,
            "url": {"uri": dataset_url},
            "dataset_schema": schema,
        }
    )

    registered = client.datasets.list()
    cprint([x.identifier for x in registered], "cyan")

    # Fetch a small page so scoring below runs on individual rows.
    page = client.datasetio.get_rows_paginated(
        dataset_id="test-dataset",
        rows_in_page=3,
        page_token=None,
        filter_condition=None,
    )

    # Show which scoring functions the server exposes.
    available_fns = client.scoring_functions.list()
    cprint([x.identifier for x in available_fns], "green")

    row_scores = client.scoring.score(
        input_rows=page.rows,
        scoring_functions=["meta-reference::equality"],
    )
    cprint(f"Score Rows: {row_scores}", "red")

    # Batch mode: score the whole registered dataset by id.
    batch_scores = client.scoring.score_batch(
        dataset_id="test-dataset",
        scoring_functions=[x.identifier for x in available_fns],
        save_results_dataset=False,
    )

    cprint(f"Score Batch: {batch_scores}", "yellow")


def main(host: str, port: int, file_path: str):
    """Synchronous entry point: drive the async demo via asyncio.run."""
    asyncio.run(run_main(host=host, port=port, file_path=file_path))


# Script entry point: fire turns main's parameters into CLI flags.
if __name__ == "__main__":
    fire.Fire(main)
6 changes: 6 additions & 0 deletions examples/scoring/resources/test_dataset.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
input_query,generated_answer,expected_answer,chat_completion_input
What is the capital of France?,London,Paris,"[{'role': 'user', 'content': 'What is the capital of France?'}]"
Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg,"[{'role': 'user', 'content': 'Who is the CEO of Meta?'}]"
What is the largest planet in our solar system?,Jupiter,Jupiter,"[{'role': 'user', 'content': 'What is the largest planet in our solar system?'}]"
What is the smallest country in the world?,China,Vatican City,"[{'role': 'user', 'content': 'What is the smallest country in the world?'}]"
What is the currency of Japan?,Yen,Yen,"[{'role': 'user', 'content': 'What is the currency of Japan?'}]"

0 comments on commit 13f4e83

Please sign in to comment.