diff --git a/app/agent/agent_graph.py b/app/agent/agent_graph.py index 7a54f487..c964ca89 100644 --- a/app/agent/agent_graph.py +++ b/app/agent/agent_graph.py @@ -298,7 +298,14 @@ def create_graph(self): } ) else: - self.workflow.add_edge("generate_function", "generate_answer") + self.workflow.add_conditional_edges( + "generate_function", + self.check_state_for_generation_error, + { + "error": "rewrite_question", + "success": "generate_answer" + } + ) if self.supportai_enabled: self.workflow.add_conditional_edges( "generate_answer", diff --git a/app/py_schemas/tool_io_schemas.py b/app/py_schemas/tool_io_schemas.py index 027e6388..8f2a51b0 100644 --- a/app/py_schemas/tool_io_schemas.py +++ b/app/py_schemas/tool_io_schemas.py @@ -34,7 +34,7 @@ class AgentOutput(BaseModel): class MapAttributeToAttributeResponse(BaseModel): - attr_map: Dict[str, str] = Field( + attr_map: Optional[Dict[str, str]] = Field( description="The dictionary of the form {'source_attribute': 'output_attribute'}" ) diff --git a/app/tools/map_question_to_schema.py b/app/tools/map_question_to_schema.py index 4f7dfbae..80d24dd8 100644 --- a/app/tools/map_question_to_schema.py +++ b/app/tools/map_question_to_schema.py @@ -133,7 +133,7 @@ def _run(self, query: str) -> str: )[0]["text"] parsed_map = attr_parser.invoke(map_attr).attr_map parsed_q.target_vertex_attributes[vertex] = [ - parsed_map[x] for x in list(parsed_q.target_vertex_attributes[vertex]) + parsed_map.get(x) for x in list(parsed_q.target_vertex_attributes[vertex]) ] logger.debug(f"request_id={req_id_cv.get()} MapVertexAttributes applied") diff --git a/tests/test_huggingface_phi3.py b/tests/test_huggingface_phi3.py index 943b2084..2515c1af 100644 --- a/tests/test_huggingface_phi3.py +++ b/tests/test_huggingface_phi3.py @@ -24,7 +24,7 @@ def setUpClass(cls) -> None: def test_config_read(self): resp = self.client.get("/") - self.assertEqual(resp.json()["config"], "microsoft/Phi-3-mini-4k-instruct") + self.assertEqual(resp.json()["config"], "phi3") if __name__ == "__main__": diff --git a/tests/test_service.py b/tests/test_service.py index 29fb7df7..1c16cd00 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -169,7 +169,7 @@ def query_registration_test_generator( def test(self): resp = self.client.post( "/" + dataset + "/upsert_docs", - json=json.loads({"id": "", "query_info": query_json}), + json={"id": "", "query_info": query_json}, auth=(username, password), ) self.assertEqual(resp.status_code, 200)