Skip to content

Commit

Permalink
chore(wren-ai-service): eval data curation app updates (#565)
Browse files Browse the repository at this point in the history
* fix contexts

* save documents based on contexts

* fix test

* remove deepcopy
  • Loading branch information
cyyeh authored Aug 1, 2024
1 parent d14a17d commit efe90ff
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 11 deletions.
10 changes: 10 additions & 0 deletions wren-ai-service/eval/data_curation/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
DATA_SOURCES,
get_contexts_from_sqls,
get_data_from_wren_engine_with_sqls,
get_documents_given_contexts,
get_eval_dataset_in_toml_string,
get_openai_client,
get_question_sql_pairs,
Expand Down Expand Up @@ -54,6 +55,7 @@
# widget callbacks
def on_change_upload_eval_dataset():
doc = tomlkit.parse(st.session_state.uploaded_eval_file.getvalue().decode("utf-8"))

assert (
doc["mdl"] == st.session_state["mdl_json"]
), "The model in the uploaded dataset is different from the deployed model"
Expand Down Expand Up @@ -130,18 +132,23 @@ def on_change_sql(i: int, key: str):
new_context = asyncio.run(
get_contexts_from_sqls([sql], st.session_state["mdl_json"])
)[0]
document = get_documents_given_contexts(
[new_context], st.session_state["mdl_json"]
)
if i != -1:
st.session_state["llm_question_sql_pairs"][i]["sql"] = sql
st.session_state["llm_question_sql_pairs"][i]["is_valid"] = valid
st.session_state["llm_question_sql_pairs"][i]["error"] = error
if valid:
st.session_state["llm_question_sql_pairs"][i]["context"] = new_context
st.session_state["llm_question_sql_pairs"][i]["document"] = document
else:
st.session_state["user_question_sql_pair"]["sql"] = sql
st.session_state["user_question_sql_pair"]["is_valid"] = valid
st.session_state["user_question_sql_pair"]["error"] = error
if valid:
st.session_state["user_question_sql_pair"]["context"] = new_context
st.session_state["user_question_sql_pair"]["document"] = document


def on_click_add_candidate_dataset(i: int, categories: list):
Expand All @@ -151,13 +158,15 @@ def on_click_add_candidate_dataset(i: int, categories: list):
"question": st.session_state["llm_question_sql_pairs"][i]["question"],
"context": st.session_state["llm_question_sql_pairs"][i]["context"],
"sql": st.session_state["llm_question_sql_pairs"][i]["sql"],
"document": st.session_state["llm_question_sql_pairs"][i]["document"],
}
else:
dataset_to_add = {
"categories": categories,
"question": st.session_state["user_question_sql_pair"]["question"],
"context": st.session_state["user_question_sql_pair"]["context"],
"sql": st.session_state["user_question_sql_pair"]["sql"],
"document": st.session_state["user_question_sql_pair"]["document"],
}

# reset input for user question sql pair
Expand All @@ -180,6 +189,7 @@ def on_change_user_question():
st.session_state["user_question_sql_pair"] = {
"question": st.session_state["user_question"],
"context": [],
"document": [],
"sql": "",
"is_valid": False,
"error": "",
Expand Down
124 changes: 121 additions & 3 deletions wren-ai-service/eval/data_curation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,118 @@ async def get_contexts_from_sqls(
return results


def get_documents_given_contexts(
    contexts_list: list[list[str]], mdl_json: dict
) -> list[list[dict]]:
    """Build one list of DDL documents per list of contexts.

    Each context is a ``"table.column"`` string. For every contexts list a
    partial MDL JSON is assembled containing only the referenced models and
    columns — plus any relationship whose two condition sides both appear in
    the contexts — and then converted to DDL command strings via
    ``ddl_converter``.

    Args:
        contexts_list: one list of "table.column" context strings per SQL.
        mdl_json: the deployed MDL model as a dict (must contain "models").

    Returns:
        For each contexts list, a list of dicts with keys "id" and "meta"
        (both carrying the positional index as a string) and "content"
        (the DDL command text).
    """

    def _build_partial_mdl_json(
        contexts_list: list[list[str]], mdl_json: dict
    ) -> list[dict]:
        # Index models by name; split each model's columns into plain columns
        # and relationship columns for O(1) lookups below.
        model_lookup_table = {
            model["name"]: {
                **model,
                "column_lookup": {
                    column["name"]: column
                    for column in model["columns"]
                    if "relationship" not in column
                },
                "relationship_lookup": {
                    column["relationship"]: column
                    for column in model["columns"]
                    if "relationship" in column
                },
            }
            for model in mdl_json["models"]
        }

        new_mdl_jsons = []
        for contexts in contexts_list:
            model_candidates = {}
            relationship_candidates = []
            for context in contexts:
                table_name, column_name = context.split(".")
                model = model_lookup_table.get(table_name)
                if model:
                    if table_name not in model_candidates:
                        model_candidates[table_name] = {
                            "name": model["name"],
                            "properties": model["properties"],
                            "tableReference": model["tableReference"],
                            "primaryKey": model["primaryKey"],
                            "columns": [],
                        }

                    # add column info (reuse the model dict already fetched
                    # above instead of a second lookup into the table)
                    column = model["column_lookup"][column_name]
                    model_candidates[table_name]["columns"].append(column)

            contexts_in_set = set(contexts)
            # fix: an MDL without a "relationships" section previously raised
            # KeyError here; treat it the same as an empty relationships list
            for relationship in mdl_json.get("relationships", []):
                relationship_name = relationship["name"]
                # remove all whitespace so "a.b = c.d" splits cleanly on "="
                condition_str = "".join(relationship["condition"].split())
                conditions = condition_str.split("=")
                if (
                    conditions[0] in contexts_in_set
                    and conditions[1] in contexts_in_set
                ):
                    table_name_first_condition = conditions[0].split(".")[0]
                    table_name_second_condition = conditions[1].split(".")[0]
                    # add the relationship column to whichever side of the
                    # join condition actually declares it
                    if (
                        relationship_column := model_lookup_table.get(
                            table_name_first_condition, {}
                        )
                        .get("relationship_lookup", {})
                        .get(relationship_name, {})
                    ):
                        model_candidates[table_name_first_condition]["columns"].append(
                            relationship_column
                        )
                    elif (
                        relationship_column := model_lookup_table.get(
                            table_name_second_condition, {}
                        )
                        .get("relationship_lookup", {})
                        .get(relationship_name, {})
                    ):
                        model_candidates[table_name_second_condition]["columns"].append(
                            relationship_column
                        )

                    # add relationship info
                    relationship_candidates.append(relationship)

            new_mdl_jsons.append(
                {
                    "models": list(model_candidates.values()),
                    "relationships": relationship_candidates,
                    "views": [],
                    "metrics": [],
                }
            )

        return new_mdl_jsons

    new_mdl_jsons = _build_partial_mdl_json(contexts_list, mdl_json)

    # One document per DDL command; id/meta.id carry the positional index so
    # documents can be matched back to their DDL ordering.
    return [
        [
            {
                "id": str(i),
                "meta": {"id": str(i)},
                "content": ddl_command,
            }
            for i, ddl_command in enumerate(
                ddl_converter.get_ddl_commands(new_mdl_json)
            )
        ]
        for new_mdl_json in new_mdl_jsons
    ]


async def get_question_sql_pairs(
llm_client: AsyncClient,
llm_model: str,
Expand Down Expand Up @@ -186,13 +298,19 @@ async def get_question_sql_pairs(
)
sqls = [question_sql_pair["sql"] for question_sql_pair in question_sql_pairs]
contexts = await get_contexts_from_sqls(sqls, mdl_json)
documents = get_documents_given_contexts(contexts, mdl_json)
sqls_data = await get_data_from_wren_engine_with_sqls(
sqls, data_source, mdl_json, connection_info
)
return [
{**quesiton_sql_pair, "context": context, "data": sql_data}
for quesiton_sql_pair, context, sql_data in zip(
question_sql_pairs, contexts, sqls_data
{
**quesiton_sql_pair,
"context": context,
"data": sql_data,
"document": document,
}
for quesiton_sql_pair, context, sql_data, document in zip(
question_sql_pairs, contexts, sqls_data, documents
)
]
except Exception as e:
Expand Down
28 changes: 22 additions & 6 deletions wren-ai-service/eval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def get_contexts_from_sql(
def _get_contexts_from_sql_analysis_results(sql_analysis_results: list[dict]):
def _compose_contexts_of_select_type(select_items: list[dict]):
return [
f'{expr_source['sourceDataset']}.{expr_source['expression']}'
f'{expr_source['sourceDataset']}.{expr_source['sourceColumn']}'
for select_item in select_items
for expr_source in select_item["exprSources"]
]
Expand All @@ -62,7 +62,7 @@ def _compose_contexts_of_filter_type(filter: dict):
contexts = []
if filter["type"] == "EXPR":
contexts += [
f'{expr_source["sourceDataset"]}.{expr_source["expression"]}'
f'{expr_source["sourceDataset"]}.{expr_source["sourceColumn"]}'
for expr_source in filter["exprSources"]
]
elif filter["type"] in ("AND", "OR"):
Expand All @@ -75,19 +75,32 @@ def _compose_contexts_of_groupby_type(groupby_keys: list[list[dict]]):
contexts = []
for groupby_key_list in groupby_keys:
contexts += [
f'{expr_source["sourceDataset"]}.{expr_source["expression"]}'
f'{expr_source["sourceDataset"]}.{expr_source["sourceColumn"]}'
for groupby_key in groupby_key_list
for expr_source in groupby_key["exprSources"]
]
return contexts

def _compose_contexts_of_sorting_type(sortings: list[dict]):
return [
f'{expr_source["sourceDataset"]}.{expr_source["expression"]}'
f'{expr_source["sourceDataset"]}.{expr_source["sourceColumn"]}'
for sorting in sortings
for expr_source in sorting["exprSources"]
]

def _compose_contexts_of_relation_type(relation: dict):
contexts = []
if relation["type"] != "TABLE":
contexts += [
f'{expr_source["sourceDataset"]}.{expr_source["sourceColumn"]}'
for expr_source in relation["exprSources"]
]

contexts += _compose_contexts_of_relation_type(relation["left"])
contexts += _compose_contexts_of_relation_type(relation["right"])

return contexts

contexts = []
for result in sql_analysis_results:
if "selectItems" in result:
Expand All @@ -98,6 +111,8 @@ def _compose_contexts_of_sorting_type(sortings: list[dict]):
contexts += _compose_contexts_of_groupby_type(result["groupByKeys"])
if "sortings" in result:
contexts += _compose_contexts_of_sorting_type(result["sortings"])
if "relation" in result:
contexts += _compose_contexts_of_relation_type(result["relation"])

return sorted(set(contexts))

Expand All @@ -119,8 +134,9 @@ async def _get_sql_analysis(
) as response:
return await response.json()

contexts = await _get_sql_analysis(sql, mdl_json, api_endpoint, timeout)
return _get_contexts_from_sql_analysis_results(contexts)
sql_analysis_results = await _get_sql_analysis(sql, mdl_json, api_endpoint, timeout)
contexts = _get_contexts_from_sql_analysis_results(sql_analysis_results)
return contexts


def parse_toml(path: str) -> Dict[str, Any]:
Expand Down
4 changes: 2 additions & 2 deletions wren-ai-service/tests/pytest/eval/test_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def _success_analysis_sql(m, engine_config, repeat=1):
"selectItems": [
{
"exprSources": [
{"sourceDataset": "t", "expression": "foo"},
{"sourceDataset": "t", "expression": "boo"},
{"sourceDataset": "t", "sourceColumn": "foo"},
{"sourceDataset": "t", "sourceColumn": "boo"},
]
}
]
Expand Down

0 comments on commit efe90ff

Please sign in to comment.