Skip to content

Commit

Permalink
feat(map schema to question): add attribute mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
parkererickson-tg committed Jan 22, 2024
1 parent c92b392 commit 5f8c2ad
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
5 changes: 4 additions & 1 deletion app/schemas/tool_io_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,7 @@ class MapQuestionToSchemaResponse(BaseModel):

class AgentOutput(BaseModel):
answer: str = Field(description="Natural language answer generated")
function_call: str = Field(description="Function call used to generate answer")
function_call: str = Field(description="Function call used to generate answer")

class MapAttributeToAttributeResponse(BaseModel):
attr_map: Dict[str, str] = Field(description="The dictionary of the form {'source_attribute': 'output_attribute'}")
29 changes: 28 additions & 1 deletion app/tools/map_schema_to_question.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from langchain.output_parsers import PydanticOutputParser
from pyTigerGraph import TigerGraphConnection
from langchain.pydantic_v1 import BaseModel, Field, validator
from app.schemas import MapQuestionToSchemaResponse
from app.schemas import MapQuestionToSchemaResponse, MapAttributeToAttributeResponse
from typing import List, Dict
from .validation_utils import validate_schema, MapQuestionToSchemaException
import re
Expand Down Expand Up @@ -69,6 +69,33 @@ def _run(self, query: str) -> str:

logger.debug_pii(f"request_id={req_id_cv.get()} MapQuestionToSchema parsed for question={query} into normalized_form={parsed_q}")

attr_prompt = """For the following source attributes: {parsed_attrs}, map them to the corresponding output attribute in this list: {real_attrs}.
Format the response way explained below:
{format_instructions}"""

attr_parser = PydanticOutputParser(pydantic_object=MapAttributeToAttributeResponse)

ATTR_MAP_PROMPT = PromptTemplate(
template = attr_prompt,
input_variables=["parsed_attrs", "real_attrs"],
partial_variables = {"format_instructions": attr_parser.get_format_instructions()}
)

attr_map_chain = LLMChain(llm=self.llm, prompt=ATTR_MAP_PROMPT)
for vertex in parsed_q.target_vertex_attributes.keys():
map_attr = attr_map_chain.apply([{"parsed_attrs": parsed_q.target_vertex_attributes[vertex], "real_attrs": self.conn.getVertexAttrs(vertex)}])[0]["text"]
parsed_map = attr_parser.invoke(map_attr).attr_map
parsed_q.target_vertex_attributes[vertex] = [parsed_map[x] for x in list(parsed_q.target_vertex_attributes[vertex])]

logger.debug(f"request_id={req_id_cv.get()} MapVertexAttributes applied")

for edge in parsed_q.target_edge_attributes.keys():
map_attr = attr_map_chain.apply([{"parsed_attrs": parsed_q.target_edge_attributes[edge], "real_attrs": self.conn.getEdgeAttrs(edge)}])[0]["text"]
parsed_map = attr_parser.invoke(map_attr).attr_map
parsed_q.target_edge_attributes[edge] = [parsed_map[x] for x in list(parsed_q.target_edge_attributes[edge])]

logger.debug(f"request_id={req_id_cv.get()} MapEdgeAttributes applied")

try:
validate_schema(self.conn,
parsed_q.target_vertex_types,
Expand Down

0 comments on commit 5f8c2ad

Please sign in to comment.