From bc8ca997e705a200ef2773b71f1841a8c6613831 Mon Sep 17 00:00:00 2001 From: San <99511815+sanowl@users.noreply.github.com> Date: Wed, 14 Aug 2024 16:45:49 +0300 Subject: [PATCH] . --- service/add_citation.py | 123 +++++++++++++++++++----------------- service/hybrid_retriever.py | 70 ++++++++++---------- 2 files changed, 101 insertions(+), 92 deletions(-) diff --git a/service/add_citation.py b/service/add_citation.py index a5899ab..eda5e62 100644 --- a/service/add_citation.py +++ b/service/add_citation.py @@ -2,83 +2,88 @@ from llama_index.retrievers.bm25 import BM25Retriever def split_sentences(text): + """Split the input text into sentences.""" return nltk.sent_tokenize(text) +def create_bm25_retriever(retrieved_nodes): + """Create and return a BM25Retriever instance.""" + return BM25Retriever.from_defaults(nodes=retrieved_nodes, similarity_top_k=2) + +def process_citation(paper, cited_paper_id_to_cnt, cited_paper_list, cite_cnt): + """Process a single citation and update citation counts.""" + paper_id = paper.node.metadata['id'] + paper_title = paper.node.metadata['title'] + if paper_id not in cited_paper_id_to_cnt: + cited_paper_id_to_cnt[paper_id] = cite_cnt + cited_paper_list.append((paper_id, paper_title)) + cite_cnt += 1 + return cited_paper_id_to_cnt[paper_id], cite_cnt + +def create_cite_string(paper_id, cite_cnt): + """Create a citation string for a paper.""" + return f"[[{cite_cnt}]](https://arxiv.org/abs/{paper_id})" + +def add_citation_to_response(final_response, right, cite_str): + """Add a citation string to the response.""" + return final_response[:right - 1] + cite_str + final_response[right - 1:] + +def create_references_list(cited_paper_list): + """Create a formatted string of references.""" + cited_list_str = "" + for cite_idx, (cited_paper_id, cited_paper_title) in enumerate(cited_paper_list, start=1): + cited_list_str += f"""[[{cite_idx}] {cited_paper_title}](https://arxiv.org/abs/{cited_paper_id})\n\n""" + return cited_list_str + def add_citation_with_retrieved_node(retrieved_nodes, final_response): - if retrieved_nodes is None or len(retrieved_nodes) <= 0: + """Main function to add citations to the response.""" + if not retrieved_nodes: return final_response - bm25_retriever = BM25Retriever.from_defaults(nodes=retrieved_nodes, similarity_top_k=2) + + bm25_retriever = create_bm25_retriever(retrieved_nodes) sentences = [sentence for sentence in split_sentences(final_response) if len(sentence) > 20] start = 0 cite_cnt = 1 threshold = 13.5 cited_paper_id_to_cnt = {} cited_paper_list = [] + for sentence in sentences: left = final_response.find(sentence, start) right = left + len(sentence) relevant_nodes = bm25_retriever.retrieve(sentence) + if len(relevant_nodes) == 0 or len(sentence.strip()) < 20: start = right continue + if len(relevant_nodes) == 1 or relevant_nodes[0].node.metadata['id'] == relevant_nodes[1].node.metadata['id']: - paper1 = relevant_nodes[0] - paper1_id = paper1.node.metadata['id'] - paper1_title = paper1.node.metadata['title'] - if paper1.score > threshold: - if paper1_id not in cited_paper_id_to_cnt: - cited_paper_id_to_cnt[paper1_id] = cite_cnt - cited_paper_list.append((paper1_id, paper1_title)) - cite_cnt += 1 - paper1_cite_cnt = cited_paper_id_to_cnt[paper1_id] - cite_str = f"[[{paper1_cite_cnt}]](https://arxiv.org/abs/{paper1_id})" - final_response = final_response[:right - 1] + cite_str + final_response[right - 1:] + paper = relevant_nodes[0] + if paper.score > threshold: + cite_cnt, cite_cnt = process_citation(paper, cited_paper_id_to_cnt, cited_paper_list, cite_cnt) + cite_str = create_cite_string(paper.node.metadata['id'], cite_cnt - 1) + final_response = add_citation_to_response(final_response, right, cite_str) start = right + len(cite_str) - continue - paper1 = relevant_nodes[0] - paper2 = relevant_nodes[1] - paper1_id = paper1.node.metadata['id'] - paper1_title = paper1.node.metadata['title'] - paper2_id = paper2.node.metadata['id'] - paper2_title = paper2.node.metadata['title'] - if paper1.score > threshold and paper2.score > threshold: - if paper1_id not in cited_paper_id_to_cnt: - cited_paper_id_to_cnt[paper1_id] = cite_cnt - cited_paper_list.append((paper1_id, paper1_title)) - cite_cnt += 1 - if paper2_id not in cited_paper_id_to_cnt: - cited_paper_id_to_cnt[paper2_id] = cite_cnt - cited_paper_list.append((paper2_id, paper2_title)) - cite_cnt += 1 - paper1_cite_cnt = cited_paper_id_to_cnt[paper1_id] - paper2_cite_cnt = cited_paper_id_to_cnt[paper2_id] - if paper1_cite_cnt > paper2_cite_cnt: - paper1_cite_cnt, paper2_cite_cnt = paper2_cite_cnt, paper1_cite_cnt - paper1_id, paper2_id = paper2_id, paper1_id - cite_str = f"[[{paper1_cite_cnt}]](https://arxiv.org/abs/{paper1_id})[[{paper2_cite_cnt}]](https://arxiv.org/abs/{paper2_id})" - final_response = final_response[:right - 1] + cite_str + final_response[right - 1:] - start = right + len(cite_str) - elif paper1.score > threshold: - if paper1_id not in cited_paper_id_to_cnt: - cited_paper_id_to_cnt[paper1_id] = cite_cnt - cited_paper_list.append((paper1_id, paper1_title)) - cite_cnt += 1 - paper1_cite_cnt = cited_paper_id_to_cnt[paper1_id] - cite_str = f"[[{paper1_cite_cnt}]](https://arxiv.org/abs/{paper1_id})" - final_response = final_response[:right - 1] + cite_str + final_response[right - 1:] - start = right + len(cite_str) - elif paper2.score > threshold: - if paper2_id not in cited_paper_id_to_cnt: - cited_paper_id_to_cnt[paper2_id] = cite_cnt - cited_paper_list.append((paper2_id, paper2_title)) - cite_cnt += 1 - paper2_cite_cnt = cited_paper_id_to_cnt[paper2_id] - cite_str = f"[[{paper2_cite_cnt}]](https://arxiv.org/abs/{paper2_id})" - final_response = final_response[:right - 1] + cite_str + final_response[right - 1:] - start = right + len(cite_str) - cited_list_str = "" - for cite_idx, (cited_paper_id, cited_paper_title) in enumerate(cited_paper_list, start=1): - cited_list_str += f"""[[{cite_idx}] {cited_paper_title}](https://arxiv.org/abs/{cited_paper_id})\n\n""" - if len(cited_list_str) > 0: + else: + paper1, paper2 = relevant_nodes[0], relevant_nodes[1] + if paper1.score > threshold and paper2.score > threshold: + cite_cnt1, cite_cnt = process_citation(paper1, cited_paper_id_to_cnt, cited_paper_list, cite_cnt) + cite_cnt2, cite_cnt = process_citation(paper2, cited_paper_id_to_cnt, cited_paper_list, cite_cnt) + cite_str = create_cite_string(paper1.node.metadata['id'], cite_cnt1) + create_cite_string(paper2.node.metadata['id'], cite_cnt2) + final_response = add_citation_to_response(final_response, right, cite_str) + start = right + len(cite_str) + elif paper1.score > threshold: + cite_cnt1, cite_cnt = process_citation(paper1, cited_paper_id_to_cnt, cited_paper_list, cite_cnt) + cite_str = create_cite_string(paper1.node.metadata['id'], cite_cnt1) + final_response = add_citation_to_response(final_response, right, cite_str) + start = right + len(cite_str) + elif paper2.score > threshold: + cite_cnt2, cite_cnt = process_citation(paper2, cited_paper_id_to_cnt, cited_paper_list, cite_cnt) + cite_str = create_cite_string(paper2.node.metadata['id'], cite_cnt2) + final_response = add_citation_to_response(final_response, right, cite_str) + start = right + len(cite_str) + + cited_list_str = create_references_list(cited_paper_list) + if cited_list_str: final_response += "\n\n**REFERENCES**\n\n" + cited_list_str + return final_response \ No newline at end of file diff --git a/service/hybrid_retriever.py b/service/hybrid_retriever.py index dd240f6..efab548 100644 --- a/service/hybrid_retriever.py +++ b/service/hybrid_retriever.py @@ -1,64 +1,68 @@ from llama_index.core.retrievers import BaseRetriever from service.field_selector import field_selector from service.date_selector import date_selector -from config import category_list, qdrant_month_list, elastic_seach_month_list +from config import category_list, qdrant_month_list, elastic_search_month_list import concurrent.futures def retrieve_bm25(elastic_search_retriever, query, need_categories, need_months): + """Retrieve nodes using BM25 algorithm.""" return elastic_search_retriever.custom_retrieve(query, need_categories, need_months) def retrieve_vector(qdrant_retriever, query, need_categories, need_months): + """Retrieve nodes using vector search.""" return qdrant_retriever.custom_retrieve(query, need_categories, need_months) class HybridRetriever(BaseRetriever): - def __init__(self, qdrant_retriever, - elastic_search_retriever, - node_postprocessors): + def __init__(self, qdrant_retriever, elastic_search_retriever, node_postprocessors): super().__init__() self.qdrant_retriever = qdrant_retriever self.elastic_search_retriever = elastic_search_retriever self.node_postprocessors = node_postprocessors def _retrieve(self, query): - need_categories = field_selector(query) - if len(need_categories) == 0: - need_categories = category_list - qdrant_need_months, es_need_month = date_selector(query, range_type="all") - if qdrant_need_months is None: - qdrant_need_months = qdrant_month_list - if es_need_month is None: - es_need_month = elastic_seach_month_list + """Retrieve nodes using both BM25 and vector search concurrently.""" + need_categories = self._get_categories(query) + qdrant_need_months, es_need_month = self._get_date_ranges(query) + with concurrent.futures.ThreadPoolExecutor() as executor: future_bm25 = executor.submit(retrieve_bm25, self.elastic_search_retriever, query, need_categories, es_need_month) future_vector = executor.submit(retrieve_vector, self.qdrant_retriever, query, need_categories, qdrant_need_months) bm25_nodes = future_bm25.result() vector_nodes = future_vector.result() + nodes = bm25_nodes + vector_nodes - for postprocessor in self.node_postprocessors: - nodes = postprocessor.postprocess_nodes(nodes, query) - return nodes - + return self._postprocess_nodes(nodes, query) + def custom_retrieve_vector(self, query: str): - need_categories = field_selector(query) - if len(need_categories) == 0: - need_categories = category_list - qdrant_need_months = date_selector(query, range_type="qdrant") - if qdrant_need_months is None: - qdrant_need_months = qdrant_month_list + """Retrieve nodes using only vector search.""" + need_categories = self._get_categories(query) + qdrant_need_months = date_selector(query, range_type="qdrant") or qdrant_month_list nodes = retrieve_vector(self.qdrant_retriever, query, need_categories, qdrant_need_months) - for postprocessor in self.node_postprocessors: - nodes = postprocessor.postprocess_nodes(nodes, query) - return nodes - + return self._postprocess_nodes(nodes, query) + def custom_retrieve_bm25(self, query: str): - need_categories = field_selector(query) - if len(need_categories) == 0: - need_categories = category_list - es_need_month = date_selector(query, range_type="elastic search") - if es_need_month is None: - es_need_month = elastic_seach_month_list + """Retrieve nodes using only BM25 search.""" + need_categories = self._get_categories(query) + es_need_month = date_selector(query, range_type="elastic search") or elastic_search_month_list nodes = retrieve_bm25(self.elastic_search_retriever, query, need_categories, es_need_month) + return self._postprocess_nodes(nodes, query) + + def _get_categories(self, query): + """Get categories for the query.""" + need_categories = field_selector(query) + return need_categories if need_categories else category_list + + def _get_date_ranges(self, query): + """Get date ranges for Qdrant and Elasticsearch.""" + qdrant_need_months, es_need_month = date_selector(query, range_type="all") + return ( + qdrant_need_months or qdrant_month_list, + es_need_month or elastic_search_month_list + ) + + def _postprocess_nodes(self, nodes, query): + """Apply postprocessors to the retrieved nodes.""" for postprocessor in self.node_postprocessors: nodes = postprocessor.postprocess_nodes(nodes, query) - return nodes + return nodes \ No newline at end of file