Skip to content

Commit

Permalink
changes
Browse files Browse the repository at this point in the history
  • Loading branch information
parkererickson-tg committed Jun 5, 2024
1 parent 133878a commit a061ffa
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 5 deletions.
9 changes: 8 additions & 1 deletion app/agent/agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,14 @@ def create_graph(self):
}
)
else:
self.workflow.add_edge("generate_function", "generate_answer")
self.workflow.add_conditional_edges(
"generate_function",
self.check_state_for_generation_error,
{
"error": "rewrite_question",
"success": "generate_answer"
}
)
if self.supportai_enabled:
self.workflow.add_conditional_edges(
"generate_answer",
Expand Down
2 changes: 1 addition & 1 deletion app/py_schemas/tool_io_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class AgentOutput(BaseModel):


class MapAttributeToAttributeResponse(BaseModel):
attr_map: Dict[str, str] = Field(
attr_map: Optional[Dict[str, str]] = Field(
description="The dictionary of the form {'source_attribute': 'output_attribute'}"
)

Expand Down
2 changes: 1 addition & 1 deletion app/tools/map_question_to_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _run(self, query: str) -> str:
)[0]["text"]
parsed_map = attr_parser.invoke(map_attr).attr_map
parsed_q.target_vertex_attributes[vertex] = [
parsed_map[x] for x in list(parsed_q.target_vertex_attributes[vertex])
parsed_map.get(x) for x in list(parsed_q.target_vertex_attributes[vertex])
]

logger.debug(f"request_id={req_id_cv.get()} MapVertexAttributes applied")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_huggingface_phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def setUpClass(cls) -> None:

def test_config_read(self):
resp = self.client.get("/")
self.assertEqual(resp.json()["config"], "microsoft/Phi-3-mini-4k-instruct")
self.assertEqual(resp.json()["config"], "phi3")


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion tests/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def query_registration_test_generator(
def test(self):
resp = self.client.post(
"/" + dataset + "/upsert_docs",
json=json.loads({"id": "", "query_info": query_json}),
json={"id": "", "query_info": query_json},
auth=(username, password),
)
self.assertEqual(resp.status_code, 200)
Expand Down

0 comments on commit a061ffa

Please sign in to comment.