Skip to content

Commit

Permalink
Added edge summarizer endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
jrichardson97 committed Oct 31, 2023
1 parent 6cac3c2 commit 45cbaf8
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 34 deletions.
32 changes: 23 additions & 9 deletions kg_summarizer/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

from reasoner_pydantic import Response as PDResponse

from kg_summarizer.trapi import GraphContainer
from kg_summarizer.trapi import GraphContainer, parse_edge
from kg_summarizer.ai import generate_response

class LLMParameters(BaseModel):
gpt_model: str
temperature: Optional[float] = 0.0
system_prompt: Optional[str] = ''

class TrapiParameters(BaseModel):
result_idx: Optional[int] = 0
Expand All @@ -26,6 +27,9 @@ class ResponseItem(BaseModel):
response: PDResponse
parameters: Parameters

class EdgeItem(BaseModel):
edge: dict
parameters: Parameters

KG_SUM_VERSION = '0.1'

Expand All @@ -49,13 +53,23 @@ async def summarize_abstract_handler(item: AbstractItem):
)
return summary

@app.post("/summarize/edges")
async def summarize_edges_handler(item: ResponseItem):
@app.post("/summarize/edge")
async def summarize_edge_handler(item: EdgeItem):
edge = parse_edge(item)

g = GraphContainer(
item.response,
verbose=False,
result_idx=item.parameters.trapi.result_idx
)
spo_sentence = f"{edge['subject']} {edge['predicate']} {edge['object']}."

if item.parameters.llm.system_prompt:
system_prompt = item.parameters.llm.system_prompt
else:
system_prompt = f"""
Summarize the following edge publication abstracts listed in the knowledge graph. Make sure the summary supports the statement '{spo_sentence}'. Only use information explicitly stated in the publication abstracts. I repeat, do not make up any information.
"""

return 999
summary = generate_response(
system_prompt,
str(edge),
item.parameters.llm.gpt_model,
item.parameters.llm.temperature,
)
return summary
61 changes: 36 additions & 25 deletions kg_summarizer/trapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,30 +170,6 @@ def parse_node_attributes(attr_list_of_dicts, node_norm_name):
}

def get_edge_info(self, fetch_pubs=True):
def parse_edge_attributes(attr_list_of_dicts, fetch_pubs=True):
edge_attr_data = {
'publications': [],
}

for attr_dict in attr_list_of_dicts:
atid = attr_dict['attribute_type_id']

if atid == 'biolink:publications':
# Sometimes there are multiple publication attributes so extend list then fetch non-duplicates at the end
publication_ids = attr_dict['value']
edge_attr_data['publications'].extend(publication_ids)

if atid == 'biolink:support_graphs':
edge_attr_data['support_graphs'] = attr_dict['value']

# Remove duplicates
edge_attr_data['publications'] = list(set(edge_attr_data['publications']))

if fetch_pubs:
edge_attr_data['publications'] = get_publications(edge_attr_data['publications'])

return edge_attr_data

self.edges = []
edge_list = []

Expand Down Expand Up @@ -378,4 +354,39 @@ def get_publications(pub_id_list):
if abstract is not None:
pub_list.append({pubid: abstract})

return pub_list
return pub_list

def parse_edge_attributes(attr_list_of_dicts, fetch_pubs=True):
edge_attr_data = {
'publications': [],
}

for attr_dict in attr_list_of_dicts:
atid = attr_dict['attribute_type_id']

if atid == 'biolink:publications':
# Sometimes there are multiple publication attributes so extend list then fetch non-duplicates at the end
publication_ids = attr_dict['value']
edge_attr_data['publications'].extend(publication_ids)

if atid == 'biolink:support_graphs':
edge_attr_data['support_graphs'] = attr_dict['value']

# Remove duplicates
edge_attr_data['publications'] = list(set(edge_attr_data['publications']))

if fetch_pubs:
edge_attr_data['publications'] = get_publications(edge_attr_data['publications'])

return edge_attr_data

def parse_edge(edge_data):
parsed_data = dict(
subject=edge_data['source']['name'],
object=edge_data['target']['name'],
predicate=edge_data['predicate'].split(':')[1]
)

edge_attr_data = parse_edge_attributes(edge_data['attributes'])

return {**parsed_data, **edge_attr_data}

0 comments on commit 45cbaf8

Please sign in to comment.