Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

a more clean style code #2

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 64 additions & 59 deletions service/add_citation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,83 +2,88 @@
from llama_index.retrievers.bm25 import BM25Retriever

def split_sentences(text):
"""Split the input text into sentences."""
return nltk.sent_tokenize(text)

def create_bm25_retriever(retrieved_nodes):
"""Create and return a BM25Retriever instance."""
return BM25Retriever.from_defaults(nodes=retrieved_nodes, similarity_top_k=2)

def process_citation(paper, cited_paper_id_to_cnt, cited_paper_list, cite_cnt):
"""Process a single citation and update citation counts."""
paper_id = paper.node.metadata['id']
paper_title = paper.node.metadata['title']
if paper_id not in cited_paper_id_to_cnt:
cited_paper_id_to_cnt[paper_id] = cite_cnt
cited_paper_list.append((paper_id, paper_title))
cite_cnt += 1
return cited_paper_id_to_cnt[paper_id], cite_cnt

def create_cite_string(paper_id, cite_cnt):
"""Create a citation string for a paper."""
return f"[[{cite_cnt}]](https://arxiv.org/abs/{paper_id})"

def add_citation_to_response(final_response, right, cite_str):
"""Add a citation string to the response."""
return final_response[:right - 1] + cite_str + final_response[right - 1:]

def create_references_list(cited_paper_list):
"""Create a formatted string of references."""
cited_list_str = ""
for cite_idx, (cited_paper_id, cited_paper_title) in enumerate(cited_paper_list, start=1):
cited_list_str += f"""[[{cite_idx}] {cited_paper_title}](https://arxiv.org/abs/{cited_paper_id})\n\n"""
return cited_list_str

def add_citation_with_retrieved_node(retrieved_nodes, final_response):
if retrieved_nodes is None or len(retrieved_nodes) <= 0:
"""Main function to add citations to the response."""
if not retrieved_nodes:
return final_response
bm25_retriever = BM25Retriever.from_defaults(nodes=retrieved_nodes, similarity_top_k=2)

bm25_retriever = create_bm25_retriever(retrieved_nodes)
sentences = [sentence for sentence in split_sentences(final_response) if len(sentence) > 20]
start = 0
cite_cnt = 1
threshold = 13.5
cited_paper_id_to_cnt = {}
cited_paper_list = []

for sentence in sentences:
left = final_response.find(sentence, start)
right = left + len(sentence)
relevant_nodes = bm25_retriever.retrieve(sentence)

if len(relevant_nodes) == 0 or len(sentence.strip()) < 20:
start = right
continue

if len(relevant_nodes) == 1 or relevant_nodes[0].node.metadata['id'] == relevant_nodes[1].node.metadata['id']:
paper1 = relevant_nodes[0]
paper1_id = paper1.node.metadata['id']
paper1_title = paper1.node.metadata['title']
if paper1.score > threshold:
if paper1_id not in cited_paper_id_to_cnt:
cited_paper_id_to_cnt[paper1_id] = cite_cnt
cited_paper_list.append((paper1_id, paper1_title))
cite_cnt += 1
paper1_cite_cnt = cited_paper_id_to_cnt[paper1_id]
cite_str = f"[[{paper1_cite_cnt}]](https://arxiv.org/abs/{paper1_id})"
final_response = final_response[:right - 1] + cite_str + final_response[right - 1:]
paper = relevant_nodes[0]
if paper.score > threshold:
cite_cnt, cite_cnt = process_citation(paper, cited_paper_id_to_cnt, cited_paper_list, cite_cnt)
cite_str = create_cite_string(paper.node.metadata['id'], cite_cnt - 1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two variables with the same name. If the paper ID is already in the dictionary, the numbering in cite_str will be incorrect.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The rest of the code is fine. Thank you for your suggested improvement!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hey sure I will fix it dw

final_response = add_citation_to_response(final_response, right, cite_str)
start = right + len(cite_str)
continue
paper1 = relevant_nodes[0]
paper2 = relevant_nodes[1]
paper1_id = paper1.node.metadata['id']
paper1_title = paper1.node.metadata['title']
paper2_id = paper2.node.metadata['id']
paper2_title = paper2.node.metadata['title']
if paper1.score > threshold and paper2.score > threshold:
if paper1_id not in cited_paper_id_to_cnt:
cited_paper_id_to_cnt[paper1_id] = cite_cnt
cited_paper_list.append((paper1_id, paper1_title))
cite_cnt += 1
if paper2_id not in cited_paper_id_to_cnt:
cited_paper_id_to_cnt[paper2_id] = cite_cnt
cited_paper_list.append((paper2_id, paper2_title))
cite_cnt += 1
paper1_cite_cnt = cited_paper_id_to_cnt[paper1_id]
paper2_cite_cnt = cited_paper_id_to_cnt[paper2_id]
if paper1_cite_cnt > paper2_cite_cnt:
paper1_cite_cnt, paper2_cite_cnt = paper2_cite_cnt, paper1_cite_cnt
paper1_id, paper2_id = paper2_id, paper1_id
cite_str = f"[[{paper1_cite_cnt}]](https://arxiv.org/abs/{paper1_id})[[{paper2_cite_cnt}]](https://arxiv.org/abs/{paper2_id})"
final_response = final_response[:right - 1] + cite_str + final_response[right - 1:]
start = right + len(cite_str)
elif paper1.score > threshold:
if paper1_id not in cited_paper_id_to_cnt:
cited_paper_id_to_cnt[paper1_id] = cite_cnt
cited_paper_list.append((paper1_id, paper1_title))
cite_cnt += 1
paper1_cite_cnt = cited_paper_id_to_cnt[paper1_id]
cite_str = f"[[{paper1_cite_cnt}]](https://arxiv.org/abs/{paper1_id})"
final_response = final_response[:right - 1] + cite_str + final_response[right - 1:]
start = right + len(cite_str)
elif paper2.score > threshold:
if paper2_id not in cited_paper_id_to_cnt:
cited_paper_id_to_cnt[paper2_id] = cite_cnt
cited_paper_list.append((paper2_id, paper2_title))
cite_cnt += 1
paper2_cite_cnt = cited_paper_id_to_cnt[paper2_id]
cite_str = f"[[{paper2_cite_cnt}]](https://arxiv.org/abs/{paper2_id})"
final_response = final_response[:right - 1] + cite_str + final_response[right - 1:]
start = right + len(cite_str)
cited_list_str = ""
for cite_idx, (cited_paper_id, cited_paper_title) in enumerate(cited_paper_list, start=1):
cited_list_str += f"""[[{cite_idx}] {cited_paper_title}](https://arxiv.org/abs/{cited_paper_id})\n\n"""
if len(cited_list_str) > 0:
else:
paper1, paper2 = relevant_nodes[0], relevant_nodes[1]
if paper1.score > threshold and paper2.score > threshold:
cite_cnt1, cite_cnt = process_citation(paper1, cited_paper_id_to_cnt, cited_paper_list, cite_cnt)
cite_cnt2, cite_cnt = process_citation(paper2, cited_paper_id_to_cnt, cited_paper_list, cite_cnt)
cite_str = create_cite_string(paper1.node.metadata['id'], cite_cnt1) + create_cite_string(paper2.node.metadata['id'], cite_cnt2)
final_response = add_citation_to_response(final_response, right, cite_str)
start = right + len(cite_str)
elif paper1.score > threshold:
cite_cnt1, cite_cnt = process_citation(paper1, cited_paper_id_to_cnt, cited_paper_list, cite_cnt)
cite_str = create_cite_string(paper1.node.metadata['id'], cite_cnt1)
final_response = add_citation_to_response(final_response, right, cite_str)
start = right + len(cite_str)
elif paper2.score > threshold:
cite_cnt2, cite_cnt = process_citation(paper2, cited_paper_id_to_cnt, cited_paper_list, cite_cnt)
cite_str = create_cite_string(paper2.node.metadata['id'], cite_cnt2)
final_response = add_citation_to_response(final_response, right, cite_str)
start = right + len(cite_str)

cited_list_str = create_references_list(cited_paper_list)
if cited_list_str:
final_response += "\n\n**REFERENCES**\n\n" + cited_list_str

return final_response
70 changes: 37 additions & 33 deletions service/hybrid_retriever.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,68 @@
from llama_index.core.retrievers import BaseRetriever
from service.field_selector import field_selector
from service.date_selector import date_selector
from config import category_list, qdrant_month_list, elastic_seach_month_list
from config import category_list, qdrant_month_list, elastic_search_month_list
import concurrent.futures

def retrieve_bm25(elastic_search_retriever, query, need_categories, need_months):
"""Retrieve nodes using BM25 algorithm."""
return elastic_search_retriever.custom_retrieve(query, need_categories, need_months)

def retrieve_vector(qdrant_retriever, query, need_categories, need_months):
"""Retrieve nodes using vector search."""
return qdrant_retriever.custom_retrieve(query, need_categories, need_months)

class HybridRetriever(BaseRetriever):
def __init__(self, qdrant_retriever,
elastic_search_retriever,
node_postprocessors):
def __init__(self, qdrant_retriever, elastic_search_retriever, node_postprocessors):
super().__init__()
self.qdrant_retriever = qdrant_retriever
self.elastic_search_retriever = elastic_search_retriever
self.node_postprocessors = node_postprocessors

def _retrieve(self, query):
need_categories = field_selector(query)
if len(need_categories) == 0:
need_categories = category_list
qdrant_need_months, es_need_month = date_selector(query, range_type="all")
if qdrant_need_months is None:
qdrant_need_months = qdrant_month_list
if es_need_month is None:
es_need_month = elastic_seach_month_list
"""Retrieve nodes using both BM25 and vector search concurrently."""
need_categories = self._get_categories(query)
qdrant_need_months, es_need_month = self._get_date_ranges(query)

with concurrent.futures.ThreadPoolExecutor() as executor:
future_bm25 = executor.submit(retrieve_bm25, self.elastic_search_retriever, query, need_categories, es_need_month)
future_vector = executor.submit(retrieve_vector, self.qdrant_retriever, query, need_categories, qdrant_need_months)

bm25_nodes = future_bm25.result()
vector_nodes = future_vector.result()

nodes = bm25_nodes + vector_nodes
for postprocessor in self.node_postprocessors:
nodes = postprocessor.postprocess_nodes(nodes, query)
return nodes

return self._postprocess_nodes(nodes, query)

def custom_retrieve_vector(self, query: str):
need_categories = field_selector(query)
if len(need_categories) == 0:
need_categories = category_list
qdrant_need_months = date_selector(query, range_type="qdrant")
if qdrant_need_months is None:
qdrant_need_months = qdrant_month_list
"""Retrieve nodes using only vector search."""
need_categories = self._get_categories(query)
qdrant_need_months = date_selector(query, range_type="qdrant") or qdrant_month_list
nodes = retrieve_vector(self.qdrant_retriever, query, need_categories, qdrant_need_months)
for postprocessor in self.node_postprocessors:
nodes = postprocessor.postprocess_nodes(nodes, query)
return nodes

return self._postprocess_nodes(nodes, query)

def custom_retrieve_bm25(self, query: str):
need_categories = field_selector(query)
if len(need_categories) == 0:
need_categories = category_list
es_need_month = date_selector(query, range_type="elastic search")
if es_need_month is None:
es_need_month = elastic_seach_month_list
"""Retrieve nodes using only BM25 search."""
need_categories = self._get_categories(query)
es_need_month = date_selector(query, range_type="elastic search") or elastic_search_month_list
nodes = retrieve_bm25(self.elastic_search_retriever, query, need_categories, es_need_month)
return self._postprocess_nodes(nodes, query)

def _get_categories(self, query):
"""Get categories for the query."""
need_categories = field_selector(query)
return need_categories if need_categories else category_list

def _get_date_ranges(self, query):
"""Get date ranges for Qdrant and Elasticsearch."""
qdrant_need_months, es_need_month = date_selector(query, range_type="all")
return (
qdrant_need_months or qdrant_month_list,
es_need_month or elastic_search_month_list
)

def _postprocess_nodes(self, nodes, query):
"""Apply postprocessors to the retrieved nodes."""
for postprocessor in self.node_postprocessors:
nodes = postprocessor.postprocess_nodes(nodes, query)
return nodes
return nodes