From a4bb37deef026c4dd1c22fadb66d0401e4b5c10a Mon Sep 17 00:00:00 2001
From: Evan Lohn
Date: Tue, 4 Mar 2025 17:09:36 -0800
Subject: [PATCH] improved basic search latency

---
 .../nodes/generate_initial_answer.py          |   5 +-
 .../nodes/generate_validate_refined_answer.py |   5 +-
 .../deep_search/main/operations.py            |   9 +-
 .../nodes/format_results.py                   |   5 +-
 .../orchestration/nodes/choose_tool.py        |  53 +++++-
 .../orchestration/nodes/use_tool_response.py  |  21 ++-
 .../shared_graph_utils/constants.py           |   5 +
 .../stream_processing/citation_processing.py  | 158 +++++++++---------
 backend/onyx/context/search/models.py         |   8 +-
 backend/onyx/context/search/pipeline.py       |  10 +-
 .../search/preprocessing/preprocessing.py     |  16 +-
 .../context/search/retrieval/search_runner.py |  30 ++--
 .../search/search_tool.py                     |  39 +++--
 .../search/search_utils.py                    |  10 ++
 backend/onyx/utils/threadpool_concurrency.py  |  33 +++-
 .../onyx/utils/test_threadpool_concurrency.py |  89 ++++++++++
 .../onyx/utils/test_threadpool_contextvars.py |  38 +++++
 17 files changed, 402 insertions(+), 132 deletions(-)

diff --git a/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/generate_initial_answer.py b/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/generate_initial_answer.py
index 1269b3dd484..7a7c8ffc2da 100644
--- a/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/generate_initial_answer.py
+++ b/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/generate_initial_answer.py
@@ -153,8 +153,9 @@ def generate_initial_answer(
         )
         for tool_response in yield_search_responses(
             query=question,
-            reranked_sections=answer_generation_documents.streaming_documents,
-            final_context_sections=answer_generation_documents.context_documents,
+            get_retrieved_sections=lambda: answer_generation_documents.context_documents,
+            get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
+            get_final_context_sections=lambda: answer_generation_documents.context_documents,
             search_query_info=query_info,
             get_section_relevance=lambda: relevance_list,
             search_tool=graph_config.tooling.search_tool,
diff --git a/backend/onyx/agents/agent_search/deep_search/main/nodes/generate_validate_refined_answer.py b/backend/onyx/agents/agent_search/deep_search/main/nodes/generate_validate_refined_answer.py
index 9782c1340e5..b17c39a6dfa 100644
--- a/backend/onyx/agents/agent_search/deep_search/main/nodes/generate_validate_refined_answer.py
+++ b/backend/onyx/agents/agent_search/deep_search/main/nodes/generate_validate_refined_answer.py
@@ -179,8 +179,9 @@ def generate_validate_refined_answer(
         )
         for tool_response in yield_search_responses(
             query=question,
-            reranked_sections=answer_generation_documents.streaming_documents,
-            final_context_sections=answer_generation_documents.context_documents,
+            get_retrieved_sections=lambda: answer_generation_documents.context_documents,
+            get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
+            get_final_context_sections=lambda: answer_generation_documents.context_documents,
             search_query_info=query_info,
             get_section_relevance=lambda: relevance_list,
             search_tool=graph_config.tooling.search_tool,
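
Note: both hunks above switch yield_search_responses from eagerly materialized section lists to zero-argument callables, so each section set is computed only when the corresponding response is actually built. A minimal sketch of the deferred-callable contract, with illustrative names rather than the real onyx types:

    from typing import Callable

    def consume(get_sections: Callable[[], list[str]]) -> None:
        # nothing is computed until the callable is invoked here
        print(get_sections())

    def expensive_sections() -> list[str]:
        print("computing sections...")
        return ["section-a", "section-b"]

    consume(lambda: expensive_sections())  # "computing" happens only now
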
diff --git a/backend/onyx/agents/agent_search/deep_search/main/operations.py b/backend/onyx/agents/agent_search/deep_search/main/operations.py
index 152581e1029..46d41d4773b 100644
--- a/backend/onyx/agents/agent_search/deep_search/main/operations.py
+++ b/backend/onyx/agents/agent_search/deep_search/main/operations.py
@@ -13,7 +13,6 @@
 from onyx.chat.models import StreamStopReason
 from onyx.chat.models import StreamType
 from onyx.chat.models import SubQuestionPiece
-from onyx.context.search.models import IndexFilters
 from onyx.tools.models import SearchQueryInfo
 from onyx.utils.logger import setup_logger
 
@@ -144,8 +143,6 @@ def get_query_info(results: list[QueryRetrievalResult]) -> SearchQueryInfo:
         if result.query_info is not None:
             query_info = result.query_info
             break
-    return query_info or SearchQueryInfo(
-        predicted_search=None,
-        final_filters=IndexFilters(access_control_list=None),
-        recency_bias_multiplier=1.0,
-    )
+
+    assert query_info is not None, "must have query info"
+    return query_info
diff --git a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/format_results.py b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/format_results.py
index 5683f4c70b6..272f02e4a02 100644
--- a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/format_results.py
+++ b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/format_results.py
@@ -56,8 +56,9 @@ def format_results(
     relevance_list = relevance_from_docs(reranked_documents)
     for tool_response in yield_search_responses(
         query=state.question,
-        reranked_sections=state.retrieved_documents,
-        final_context_sections=reranked_documents,
+        get_retrieved_sections=lambda: reranked_documents,
+        get_reranked_sections=lambda: state.retrieved_documents,
+        get_final_context_sections=lambda: reranked_documents,
         search_query_info=query_info,
         get_section_relevance=lambda: relevance_list,
         search_tool=graph_config.tooling.search_tool,
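
Note: get_query_info now fails fast with an assertion instead of synthesizing a default SearchQueryInfo with empty filters, so a missing query_info surfaces immediately rather than silently searching without access controls. Schematic, stand-alone form of the pattern (a hypothetical example, not onyx code):

    def first_query_info(results: list[str | None]) -> str:
        found = next((r for r in results if r is not None), None)
        assert found is not None, "must have query info"
        return found

    print(first_query_info([None, "info-from-second-result"]))
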
diff --git a/backend/onyx/agents/agent_search/orchestration/nodes/choose_tool.py b/backend/onyx/agents/agent_search/orchestration/nodes/choose_tool.py
index f7fdd71e50c..b8b15d5c953 100644
--- a/backend/onyx/agents/agent_search/orchestration/nodes/choose_tool.py
+++ b/backend/onyx/agents/agent_search/orchestration/nodes/choose_tool.py
@@ -10,13 +10,24 @@
 from onyx.agents.agent_search.orchestration.states import ToolChoice
 from onyx.agents.agent_search.orchestration.states import ToolChoiceState
 from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
+from onyx.agents.agent_search.shared_graph_utils.constants import EMBEDDING_KEY
+from onyx.agents.agent_search.shared_graph_utils.constants import IS_KEYWORD_KEY
+from onyx.agents.agent_search.shared_graph_utils.constants import KEYWORDS_KEY
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
 from onyx.chat.tool_handling.tool_response_handler import (
     get_tool_call_for_non_tool_calling_llm_impl,
 )
+from onyx.context.search.preprocessing.preprocessing import query_analysis
+from onyx.context.search.retrieval.search_runner import get_query_embedding
 from onyx.tools.tool import Tool
+from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from onyx.utils.logger import setup_logger
+from onyx.utils.threadpool_concurrency import run_in_background
+from onyx.utils.threadpool_concurrency import TimeoutThread
+from onyx.utils.threadpool_concurrency import wait_on_background
+from onyx.utils.timing import log_function_time
+from shared_configs.model_server_models import Embedding
 
 
 logger = setup_logger()
@@ -25,6 +36,7 @@
 # and a function that handles extracting the necessary fields
 # from the state and config
 # TODO: fan-out to multiple tool call nodes? Make this configurable?
+@log_function_time(print_only=True)
 def choose_tool(
     state: ToolChoiceState,
     config: RunnableConfig,
@@ -37,6 +49,29 @@ def choose_tool(
     should_stream_answer = state.should_stream_answer
 
     agent_config = cast(GraphConfig, config["metadata"]["config"])
+
+    force_use_tool = agent_config.tooling.force_use_tool
+
+    embedding_thread: TimeoutThread[Embedding] | None = None
+    keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
+    if (
+        not agent_config.behavior.use_agentic_search
+        and agent_config.tooling.search_tool is not None
+        and (
+            not force_use_tool.force_use or force_use_tool.tool_name == SearchTool._NAME
+        )
+    ):
+        # Run in a background thread to avoid blocking the main thread
+        embedding_thread = run_in_background(
+            get_query_embedding,
+            agent_config.inputs.search_request.query,
+            agent_config.persistence.db_session,
+        )
+        keyword_thread = run_in_background(
+            query_analysis,
+            agent_config.inputs.search_request.query,
+        )
+
     using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
     prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
 
@@ -47,7 +82,6 @@ def choose_tool(
     tools = [
         tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
     ]
-    force_use_tool = agent_config.tooling.force_use_tool
 
     tool, tool_args = None, None
     if force_use_tool.force_use and force_use_tool.args is not None:
@@ -71,6 +105,14 @@ def choose_tool(
     # If we have a tool and tool args, we are ready to request a tool call.
     # This only happens if the tool call was forced or we are using a non-tool calling LLM.
     if tool and tool_args:
+        if embedding_thread and tool.name == SearchTool._NAME:
+            # Wait for the embedding thread to finish
+            embedding = wait_on_background(embedding_thread)
+            tool_args[EMBEDDING_KEY] = embedding
+        if keyword_thread and tool.name == SearchTool._NAME:
+            is_keyword, keywords = wait_on_background(keyword_thread)
+            tool_args[IS_KEYWORD_KEY] = is_keyword
+            tool_args[KEYWORDS_KEY] = keywords
         return ToolChoiceUpdate(
             tool_choice=ToolChoice(
                 tool=tool,
@@ -145,6 +187,15 @@ def choose_tool(
     logger.debug(f"Selected tool: {selected_tool.name}")
     logger.debug(f"Selected tool call request: {selected_tool_call_request}")
 
+    if embedding_thread and selected_tool.name == SearchTool._NAME:
+        # Wait for the embedding thread to finish
+        embedding = wait_on_background(embedding_thread)
+        selected_tool_call_request["args"][EMBEDDING_KEY] = embedding
+    if keyword_thread and selected_tool.name == SearchTool._NAME:
+        is_keyword, keywords = wait_on_background(keyword_thread)
+        selected_tool_call_request["args"][IS_KEYWORD_KEY] = is_keyword
+        selected_tool_call_request["args"][KEYWORDS_KEY] = keywords
+
     return ToolChoiceUpdate(
         tool_choice=ToolChoice(
             tool=selected_tool,
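
Note: this is the heart of the latency change. While the LLM is deciding which tool to call, the query embedding and keyword analysis (both I/O-bound round trips) already run in background threads; they are joined only when the search tool is actually chosen. A stdlib-only sketch of the overlap, with a stand-in for the model-server call (the real code uses the run_in_background/wait_on_background helpers added later in this patch):

    import threading
    import time

    def embed(query: str, out: dict) -> None:
        time.sleep(0.2)  # stands in for the embedding-server round trip
        out["embedding"] = [0.1, 0.2]

    results: dict = {}
    worker = threading.Thread(target=embed, args=("user query", results))
    worker.start()
    # ... LLM tool choice happens here, overlapping with the embedding call ...
    worker.join()  # only needed once (and if) the search tool is selected
    print(results["embedding"])
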
diff --git a/backend/onyx/agents/agent_search/orchestration/nodes/use_tool_response.py b/backend/onyx/agents/agent_search/orchestration/nodes/use_tool_response.py
index 6874aae9795..4bec51bc51a 100644
--- a/backend/onyx/agents/agent_search/orchestration/nodes/use_tool_response.py
+++ b/backend/onyx/agents/agent_search/orchestration/nodes/use_tool_response.py
@@ -9,18 +9,23 @@
 from onyx.agents.agent_search.basic.utils import process_llm_stream
 from onyx.agents.agent_search.models import GraphConfig
 from onyx.chat.models import LlmDoc
-from onyx.chat.models import OnyxContexts
 from onyx.tools.tool_implementations.search.search_tool import (
-    SEARCH_DOC_CONTENT_ID,
+    SEARCH_RESPONSE_SUMMARY_ID,
+)
+from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
+from onyx.tools.tool_implementations.search.search_utils import (
+    context_from_inference_section,
 )
 from onyx.tools.tool_implementations.search_like_tool_utils import (
     FINAL_CONTEXT_DOCUMENTS_ID,
 )
 from onyx.utils.logger import setup_logger
+from onyx.utils.timing import log_function_time
 
 logger = setup_logger()
 
 
+@log_function_time(print_only=True)
 def basic_use_tool_response(
     state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
 ) -> BasicOutput:
@@ -50,11 +55,13 @@ def basic_use_tool_response(
     for yield_item in tool_call_responses:
         if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
             final_search_results = cast(list[LlmDoc], yield_item.response)
-        elif yield_item.id == SEARCH_DOC_CONTENT_ID:
-            search_contexts = cast(OnyxContexts, yield_item.response).contexts
-            for doc in search_contexts:
-                if doc.document_id not in initial_search_results:
-                    initial_search_results.append(doc)
+        elif yield_item.id == SEARCH_RESPONSE_SUMMARY_ID:
+            search_response_summary = cast(SearchResponseSummary, yield_item.response)
+            for section in search_response_summary.top_sections:
+                if section.center_chunk.document_id not in initial_search_results:
+                    initial_search_results.append(
+                        context_from_inference_section(section)
+                    )
 
     new_tool_call_chunk = AIMessageChunk(content="")
     if not agent_config.behavior.skip_gen_ai_answer_generation:
diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/constants.py b/backend/onyx/agents/agent_search/shared_graph_utils/constants.py
index 79ebcf33853..ca7828cd8ad 100644
--- a/backend/onyx/agents/agent_search/shared_graph_utils/constants.py
+++ b/backend/onyx/agents/agent_search/shared_graph_utils/constants.py
@@ -13,6 +13,11 @@
 AGENT_ANSWER_SEPARATOR = "Answer:"
 
 
+EMBEDDING_KEY = "embedding"
+IS_KEYWORD_KEY = "is_keyword"
+KEYWORDS_KEY = "keywords"
+
+
 class AgentLLMErrorType(str, Enum):
     TIMEOUT = "timeout"
     RATE_LIMIT = "rate_limit"
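
Note: basic_use_tool_response now rebuilds the initial results from the response summary's top sections, keeping the first entry per document id. The intent of that loop in isolation (plain dicts stand in for the onyx InferenceSection/OnyxContext types):

    def dedupe_by_document_id(sections: list[dict]) -> list[dict]:
        seen: set[str] = set()
        deduped: list[dict] = []
        for section in sections:
            if section["document_id"] not in seen:
                seen.add(section["document_id"])
                deduped.append(section)
        return deduped

    print(dedupe_by_document_id(
        [{"document_id": "a"}, {"document_id": "a"}, {"document_id": "b"}]
    ))
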
diff --git a/backend/onyx/chat/stream_processing/citation_processing.py b/backend/onyx/chat/stream_processing/citation_processing.py
index 071b28c3457..342472f5f13 100644
--- a/backend/onyx/chat/stream_processing/citation_processing.py
+++ b/backend/onyx/chat/stream_processing/citation_processing.py
@@ -90,97 +90,97 @@ def process_token(
                 next(group for group in citation.groups() if group is not None)
             )
 
-            if 1 <= numerical_value <= self.max_citation_num:
-                context_llm_doc = self.context_docs[numerical_value - 1]
-                final_citation_num = self.final_order_mapping[
-                    context_llm_doc.document_id
-                ]
+            if not (1 <= numerical_value <= self.max_citation_num):
+                continue
+
+            context_llm_doc = self.context_docs[numerical_value - 1]
+            final_citation_num = self.final_order_mapping[
+                context_llm_doc.document_id
+            ]
 
-                if final_citation_num not in self.citation_order:
-                    self.citation_order.append(final_citation_num)
+            if final_citation_num not in self.citation_order:
+                self.citation_order.append(final_citation_num)
 
-                citation_order_idx = (
-                    self.citation_order.index(final_citation_num) + 1
+            citation_order_idx = self.citation_order.index(final_citation_num) + 1
+
+            # get the value that was displayed to user, should always
+            # be in the display_doc_order_dict. But check anyways
+            if context_llm_doc.document_id in self.display_order_mapping:
+                displayed_citation_num = self.display_order_mapping[
+                    context_llm_doc.document_id
+                ]
+            else:
+                displayed_citation_num = final_citation_num
+                logger.warning(
+                    f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
                 )
 
-                # get the value that was displayed to user, should always
-                # be in the display_doc_order_dict. But check anyways
-                if context_llm_doc.document_id in self.display_order_mapping:
-                    displayed_citation_num = self.display_order_mapping[
-                        context_llm_doc.document_id
-                    ]
+            # Skip consecutive citations of the same work
+            if final_citation_num in self.current_citations:
+                start, end = citation.span()
+                real_start = length_to_add + start
+                diff = end - start
+                self.curr_segment = (
+                    self.curr_segment[: length_to_add + start]
+                    + self.curr_segment[real_start + diff :]
+                )
+                length_to_add -= diff
+                continue
+
+            # Handle edge case where LLM outputs citation itself
+            if self.curr_segment.startswith("[["):
+                match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
+                if match:
+                    try:
+                        doc_id = int(match.group(1))
+                        context_llm_doc = self.context_docs[doc_id - 1]
+                        yield CitationInfo(
+                            # citation_num is now the number post initial ranking, i.e. as displayed to user
+                            citation_num=displayed_citation_num,
+                            document_id=context_llm_doc.document_id,
+                        )
+                    except Exception as e:
+                        logger.warning(
+                            f"Manual LLM citation didn't properly cite documents {e}"
+                        )
                 else:
-                    displayed_citation_num = final_citation_num
                     logger.warning(
-                        f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
+                        "Manual LLM citation wasn't able to close brackets"
                     )
+                continue
 
-                # Skip consecutive citations of the same work
-                if final_citation_num in self.current_citations:
-                    start, end = citation.span()
-                    real_start = length_to_add + start
-                    diff = end - start
-                    self.curr_segment = (
-                        self.curr_segment[: length_to_add + start]
-                        + self.curr_segment[real_start + diff :]
-                    )
-                    length_to_add -= diff
-                    continue
-
-                # Handle edge case where LLM outputs citation itself
-                if self.curr_segment.startswith("[["):
-                    match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
-                    if match:
-                        try:
-                            doc_id = int(match.group(1))
-                            context_llm_doc = self.context_docs[doc_id - 1]
-                            yield CitationInfo(
-                                # citation_num is now the number post initial ranking, i.e. as displayed to user
-                                citation_num=displayed_citation_num,
-                                document_id=context_llm_doc.document_id,
-                            )
-                        except Exception as e:
-                            logger.warning(
-                                f"Manual LLM citation didn't properly cite documents {e}"
-                            )
-                    else:
-                        logger.warning(
-                            "Manual LLM citation wasn't able to close brackets"
-                        )
-                    continue
-
-                link = context_llm_doc.link
+            link = context_llm_doc.link
 
-                self.past_cite_count = len(self.llm_out)
-                self.current_citations.append(final_citation_num)
+            self.past_cite_count = len(self.llm_out)
+            self.current_citations.append(final_citation_num)
 
-                if citation_order_idx not in self.cited_inds:
-                    self.cited_inds.add(citation_order_idx)
-                    yield CitationInfo(
-                        # citation number is now the one that was displayed to user
-                        citation_num=displayed_citation_num,
-                        document_id=context_llm_doc.document_id,
-                    )
+            if citation_order_idx not in self.cited_inds:
+                self.cited_inds.add(citation_order_idx)
+                yield CitationInfo(
+                    # citation number is now the one that was displayed to user
+                    citation_num=displayed_citation_num,
+                    document_id=context_llm_doc.document_id,
+                )
 
-                start, end = citation.span()
-                if link:
-                    prev_length = len(self.curr_segment)
-                    self.curr_segment = (
-                        self.curr_segment[: start + length_to_add]
-                        + f"[[{displayed_citation_num}]]({link})"  # use the value that was displayed to user
-                        + self.curr_segment[end + length_to_add :]
-                    )
-                    length_to_add += len(self.curr_segment) - prev_length
-                else:
-                    prev_length = len(self.curr_segment)
-                    self.curr_segment = (
-                        self.curr_segment[: start + length_to_add]
-                        + f"[[{displayed_citation_num}]]()"  # use the value that was displayed to user
-                        + self.curr_segment[end + length_to_add :]
-                    )
-                    length_to_add += len(self.curr_segment) - prev_length
+            start, end = citation.span()
+            if link:
+                prev_length = len(self.curr_segment)
+                self.curr_segment = (
+                    self.curr_segment[: start + length_to_add]
+                    + f"[[{displayed_citation_num}]]({link})"  # use the value that was displayed to user
+                    + self.curr_segment[end + length_to_add :]
+                )
+                length_to_add += len(self.curr_segment) - prev_length
+            else:
+                prev_length = len(self.curr_segment)
+                self.curr_segment = (
+                    self.curr_segment[: start + length_to_add]
+                    + f"[[{displayed_citation_num}]]()"  # use the value that was displayed to user
+                    + self.curr_segment[end + length_to_add :]
+                )
+                length_to_add += len(self.curr_segment) - prev_length
 
-                last_citation_end = end + length_to_add
+            last_citation_end = end + length_to_add
 
         if last_citation_end > 0:
             result += self.curr_segment[:last_citation_end]
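
Note: the citation hunk above is behavior-preserving; the enclosing range check becomes an early continue so the long body drops one indentation level. Schematically, with toy values:

    values = [0, 3, 12, 7]
    max_citation_num = 10
    for numerical_value in values:
        if not (1 <= numerical_value <= max_citation_num):
            continue  # replaces nesting the whole body under the range check
        print(numerical_value)  # the real body is the citation bookkeeping
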
diff --git a/backend/onyx/context/search/models.py b/backend/onyx/context/search/models.py
index 7eeb3568695..4d7a12b857d 100644
--- a/backend/onyx/context/search/models.py
+++ b/backend/onyx/context/search/models.py
@@ -16,7 +16,7 @@
 from onyx.indexing.models import BaseChunk
 from onyx.indexing.models import IndexingSetting
 from shared_configs.enums import RerankerProvider
-
+from shared_configs.model_server_models import Embedding
 
 MAX_METRICS_CONTENT = (
     200  # Just need enough characters to identify where in the doc the chunk is
@@ -147,6 +147,10 @@ class SearchRequest(ChunkContext):
     evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
+    precomputed_query_embedding: Embedding | None = None
+    precomputed_is_keyword: bool | None = None
+    precomputed_keywords: list[str] | None = None
+
 
 class SearchQuery(ChunkContext):
     "Processed Request that is directly passed to the SearchPipeline"
@@ -171,6 +175,8 @@ class SearchQuery(ChunkContext):
     offset: int = 0
     model_config = ConfigDict(frozen=True)
 
+    precomputed_query_embedding: Embedding | None = None
+
 
 class RetrievalDetails(ChunkContext):
     # Use LLM to determine whether to do a retrieval or only rely on existing history
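
Note: the new precomputed_* fields let upstream callers hand the pipeline work it has already done; when they are None, behavior is unchanged. A minimal pydantic sketch of the pattern (Embedding is a vector type in shared_configs; a plain list of floats stands in here):

    from pydantic import BaseModel

    class MiniSearchRequest(BaseModel):
        query: str
        precomputed_query_embedding: list[float] | None = None

    req = MiniSearchRequest(query="q", precomputed_query_embedding=[0.1, 0.2])
    embedding = (
        req.precomputed_query_embedding
        if req.precomputed_query_embedding is not None
        else [0.0, 0.0]  # stand-in for the embedding-model call
    )
    print(embedding)
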
diff --git a/backend/onyx/context/search/pipeline.py b/backend/onyx/context/search/pipeline.py
index faf7a898892..c810b708963 100644
--- a/backend/onyx/context/search/pipeline.py
+++ b/backend/onyx/context/search/pipeline.py
@@ -331,6 +331,14 @@ def _get_sections(self) -> list[InferenceSection]:
         self._retrieved_sections = expanded_inference_sections
         return expanded_inference_sections
 
+    @property
+    def retrieved_sections(self) -> list[InferenceSection]:
+        if self._retrieved_sections is not None:
+            return self._retrieved_sections
+
+        self._retrieved_sections = self._get_sections()
+        return self._retrieved_sections
+
     @property
     def reranked_sections(self) -> list[InferenceSection]:
         """Reranking is always done at the chunk level since section merging could create arbitrarily
@@ -343,7 +351,7 @@ def reranked_sections(self) -> list[InferenceSection]:
         if self._reranked_sections is not None:
             return self._reranked_sections
 
-        retrieved_sections = self._get_sections()
+        retrieved_sections = self.retrieved_sections
 
         if self.retrieved_sections_callback is not None:
             self.retrieved_sections_callback(retrieved_sections)
diff --git a/backend/onyx/context/search/preprocessing/preprocessing.py b/backend/onyx/context/search/preprocessing/preprocessing.py
index d18ddd32be1..7f18b3f34c2 100644
--- a/backend/onyx/context/search/preprocessing/preprocessing.py
+++ b/backend/onyx/context/search/preprocessing/preprocessing.py
@@ -118,7 +118,9 @@ def retrieval_preprocessing(
     )
 
     run_query_analysis = (
-        None if skip_query_analysis else FunctionCall(query_analysis, (query,), {})
+        None
+        if (skip_query_analysis or search_request.precomputed_is_keyword is not None)
+        else FunctionCall(query_analysis, (query,), {})
     )
 
     functions_to_run = [
@@ -143,11 +145,12 @@ def retrieval_preprocessing(
 
     # The extracted keywords right now are not very reliable, not using for now
     # Can maybe use for highlighting
-    is_keyword, extracted_keywords = (
-        parallel_results[run_query_analysis.result_id]
-        if run_query_analysis
-        else (False, None)
-    )
+    is_keyword, _extracted_keywords = False, None
+    if search_request.precomputed_is_keyword is not None:
+        is_keyword = search_request.precomputed_is_keyword
+        _extracted_keywords = search_request.precomputed_keywords
+    elif run_query_analysis:
+        is_keyword, _extracted_keywords = parallel_results[run_query_analysis.result_id]
 
     all_query_terms = query.split()
     processed_keywords = (
@@ -247,4 +250,5 @@ def retrieval_preprocessing(
         chunks_above=chunks_above,
         chunks_below=chunks_below,
         full_doc=search_request.full_doc,
+        precomputed_query_embedding=search_request.precomputed_query_embedding,
     )
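
Note: retrieved_sections memoizes _get_sections, so however many consumers ask for the retrieved set, the expansion runs at most once. functools.cached_property is the stdlib spelling of the same idea (illustrative class; the real pipeline caches by hand because _retrieved_sections is also written inside _get_sections):

    from functools import cached_property

    class MiniPipeline:
        @cached_property
        def retrieved_sections(self) -> list[str]:
            print("expanding sections (runs once)")
            return ["s1", "s2"]

    pipeline = MiniPipeline()
    pipeline.retrieved_sections
    pipeline.retrieved_sections  # second access is served from the cache
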
diff --git a/backend/onyx/context/search/retrieval/search_runner.py b/backend/onyx/context/search/retrieval/search_runner.py
index 64491a20a02..6c77167adca 100644
--- a/backend/onyx/context/search/retrieval/search_runner.py
+++ b/backend/onyx/context/search/retrieval/search_runner.py
@@ -31,7 +31,7 @@
 from shared_configs.configs import MODEL_SERVER_HOST
 from shared_configs.configs import MODEL_SERVER_PORT
 from shared_configs.enums import EmbedTextType
-
+from shared_configs.model_server_models import Embedding
 
 logger = setup_logger()
 
@@ -109,6 +109,20 @@ def combine_retrieval_results(
     return sorted_chunks
 
 
+def get_query_embedding(query: str, db_session: Session) -> Embedding:
+    search_settings = get_current_search_settings(db_session)
+
+    model = EmbeddingModel.from_db_model(
+        search_settings=search_settings,
+        # The below are globally set, this flow always uses the indexing one
+        server_host=MODEL_SERVER_HOST,
+        server_port=MODEL_SERVER_PORT,
+    )
+
+    query_embedding = model.encode([query], text_type=EmbedTextType.QUERY)[0]
+    return query_embedding
+
+
 @log_function_time(print_only=True)
 def doc_index_retrieval(
@@ -121,17 +135,10 @@ def doc_index_retrieval(
     from the large chunks to the referenced chunks, dedupes the chunks,
     and cleans the chunks.
     """
-    search_settings = get_current_search_settings(db_session)
-
-    model = EmbeddingModel.from_db_model(
-        search_settings=search_settings,
-        # The below are globally set, this flow always uses the indexing one
-        server_host=MODEL_SERVER_HOST,
-        server_port=MODEL_SERVER_PORT,
+    query_embedding = query.precomputed_query_embedding or get_query_embedding(
+        query.query, db_session
     )
-
-    query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0]
 
     top_chunks = document_index.hybrid_retrieval(
         query=query.query,
         query_embedding=query_embedding,
@@ -250,6 +257,9 @@ def retrieve_chunks(
             simplified_queries.add(simplified_rephrase)
 
             q_copy = query.copy(update={"query": rephrase}, deep=True)
+            q_copy.precomputed_query_embedding = (
+                None  # need to recompute for each rephrase
+            )
             run_queries.append(
                 (
                     doc_index_retrieval,
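
Note: doc_index_retrieval takes the precomputed embedding when present and falls back to the now factored-out get_query_embedding otherwise; retrieve_chunks clears the precomputed value on rephrased query copies because each rephrase needs its own embedding. The `or` fallback relies on truthiness, which is safe here since a real embedding vector is never empty. Sketch with stand-in functions:

    def embed_expensively(query: str) -> list[float]:
        print("hitting the model server for:", query)
        return [0.5, 0.5]

    precomputed: list[float] | None = None
    embedding = precomputed or embed_expensively("q")  # computes
    precomputed = [0.1, 0.2]
    embedding = precomputed or embed_expensively("q")  # reuses, no model call
    print(embedding)
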
diff --git a/backend/onyx/tools/tool_implementations/search/search_tool.py b/backend/onyx/tools/tool_implementations/search/search_tool.py
index 4b556e47115..19024a525c1 100644
--- a/backend/onyx/tools/tool_implementations/search/search_tool.py
+++ b/backend/onyx/tools/tool_implementations/search/search_tool.py
@@ -6,12 +6,14 @@
 
 from sqlalchemy.orm import Session
 
+from onyx.agents.agent_search.shared_graph_utils.constants import EMBEDDING_KEY
+from onyx.agents.agent_search.shared_graph_utils.constants import IS_KEYWORD_KEY
+from onyx.agents.agent_search.shared_graph_utils.constants import KEYWORDS_KEY
 from onyx.chat.chat_utils import llm_doc_from_inference_section
 from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import ContextualPruningConfig
 from onyx.chat.models import DocumentPruningConfig
 from onyx.chat.models import LlmDoc
-from onyx.chat.models import OnyxContext
 from onyx.chat.models import OnyxContexts
 from onyx.chat.models import PromptConfig
 from onyx.chat.models import SectionRelevancePiece
@@ -42,6 +44,9 @@
 from onyx.tools.models import SearchToolOverrideKwargs
 from onyx.tools.models import ToolResponse
 from onyx.tools.tool import Tool
+from onyx.tools.tool_implementations.search.search_utils import (
+    context_from_inference_section,
+)
 from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
 from onyx.tools.tool_implementations.search_like_tool_utils import (
     build_next_prompt_for_search_like_tool,
@@ -51,6 +56,7 @@
 )
 from onyx.utils.logger import setup_logger
 from onyx.utils.special_types import JSON_ro
+from shared_configs.model_server_models import Embedding
 
 logger = setup_logger()
@@ -281,6 +287,11 @@ def run(
         self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
     ) -> Generator[ToolResponse, None, None]:
         query = cast(str, llm_kwargs[QUERY_FIELD])
+        precomputed_query_embedding = cast(
+            Embedding | None, llm_kwargs.get(EMBEDDING_KEY)
+        )
+        precomputed_is_keyword = cast(bool | None, llm_kwargs.get(IS_KEYWORD_KEY))
+        precomputed_keywords = cast(list[str] | None, llm_kwargs.get(KEYWORDS_KEY))
         force_no_rerank = False
         alternate_db_session = None
         retrieved_sections_callback = None
@@ -290,7 +301,6 @@ def run(
             force_no_rerank = override_kwargs.force_no_rerank
             alternate_db_session = override_kwargs.alternate_db_session
             retrieved_sections_callback = override_kwargs.retrieved_sections_callback
             skip_query_analysis = override_kwargs.skip_query_analysis
-
         if self.selected_sections:
             yield from self._build_response_for_specified_sections(query)
             return
@@ -327,6 +337,9 @@ def run(
                     if self.retrieval_options
                     else None
                 ),
+                precomputed_query_embedding=precomputed_query_embedding,
+                precomputed_is_keyword=precomputed_is_keyword,
+                precomputed_keywords=precomputed_keywords,
             ),
             user=self.user,
             llm=self.llm,
@@ -345,8 +358,9 @@ def run(
         )
         yield from yield_search_responses(
             query,
-            search_pipeline.reranked_sections,
-            search_pipeline.final_context_sections,
+            lambda: search_pipeline.retrieved_sections,
+            lambda: search_pipeline.reranked_sections,
+            lambda: search_pipeline.final_context_sections,
             search_query_info,
             lambda: search_pipeline.section_relevance,
             self,
@@ -385,8 +399,9 @@
 # the retrieved docs (reranking, deduping, etc.) after the SearchTool has run.
 def yield_search_responses(
     query: str,
-    reranked_sections: list[InferenceSection],
-    final_context_sections: list[InferenceSection],
+    get_retrieved_sections: Callable[[], list[InferenceSection]],
+    get_reranked_sections: Callable[[], list[InferenceSection]],
+    get_final_context_sections: Callable[[], list[InferenceSection]],
     search_query_info: SearchQueryInfo,
     get_section_relevance: Callable[[], list[SectionRelevancePiece] | None],
     search_tool: SearchTool,
@@ -395,7 +410,7 @@ def yield_search_responses(
         id=SEARCH_RESPONSE_SUMMARY_ID,
         response=SearchResponseSummary(
             rephrased_query=query,
-            top_sections=final_context_sections,
+            top_sections=get_retrieved_sections(),
             predicted_flow=QueryFlow.QUESTION_ANSWER,
             predicted_search=search_query_info.predicted_search,
             final_filters=search_query_info.final_filters,
@@ -407,13 +422,8 @@ def yield_search_responses(
         id=SEARCH_DOC_CONTENT_ID,
         response=OnyxContexts(
             contexts=[
-                OnyxContext(
-                    content=section.combined_content,
-                    document_id=section.center_chunk.document_id,
-                    semantic_identifier=section.center_chunk.semantic_identifier,
-                    blurb=section.center_chunk.blurb,
-                )
-                for section in reranked_sections
+                context_from_inference_section(section)
+                for section in get_reranked_sections()
             ]
         ),
     )
@@ -424,6 +434,7 @@ def yield_search_responses(
         response=section_relevance,
     )
 
+    final_context_sections = get_final_context_sections()
    pruned_sections = prune_sections(
         sections=final_context_sections,
         section_relevance_list=section_relevance_list_impl(
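
Note: this hunk shows why the signature moved to getters: the first yielded response now calls get_retrieved_sections(), so the document summary streams before reranking or pruning has run, and the more expensive section sets are forced only when their responses are built. A minimal sketch, with prints standing in for the pipeline stages:

    from typing import Callable, Iterator

    def yield_responses(
        get_retrieved: Callable[[], list[str]],
        get_reranked: Callable[[], list[str]],
    ) -> Iterator[list[str]]:
        yield get_retrieved()  # cheap stage, emitted immediately
        yield get_reranked()   # expensive stage, forced only when consumed

    def retrieve() -> list[str]:
        return ["doc-b", "doc-a"]

    def rerank() -> list[str]:
        print("reranking...")  # stands in for the slow cross-encoder call
        return sorted(retrieve())

    for response in yield_responses(retrieve, rerank):
        print(response)
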
diff --git a/backend/onyx/tools/tool_implementations/search/search_utils.py b/backend/onyx/tools/tool_implementations/search/search_utils.py
index dd44ca0338d..7b6c6383e43 100644
--- a/backend/onyx/tools/tool_implementations/search/search_utils.py
+++ b/backend/onyx/tools/tool_implementations/search/search_utils.py
@@ -1,4 +1,5 @@
 from onyx.chat.models import LlmDoc
+from onyx.chat.models import OnyxContext
 from onyx.context.search.models import InferenceSection
 from onyx.prompts.prompt_utils import clean_up_source
 
@@ -29,3 +30,12 @@ def section_to_dict(section: InferenceSection, section_num: int) -> dict:
             "%B %d, %Y %H:%M"
         )
     return doc_dict
+
+
+def context_from_inference_section(section: InferenceSection) -> OnyxContext:
+    return OnyxContext(
+        content=section.combined_content,
+        document_id=section.center_chunk.document_id,
+        semantic_identifier=section.center_chunk.semantic_identifier,
+        blurb=section.center_chunk.blurb,
+    )
diff --git a/backend/onyx/utils/threadpool_concurrency.py b/backend/onyx/utils/threadpool_concurrency.py
index 4ef87348f1c..fd8b70174a4 100644
--- a/backend/onyx/utils/threadpool_concurrency.py
+++ b/backend/onyx/utils/threadpool_concurrency.py
@@ -118,7 +118,7 @@ def run_functions_in_parallel(
     return results
 
 
-class TimeoutThread(threading.Thread):
+class TimeoutThread(threading.Thread, Generic[R]):
     def __init__(
         self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
     ):
@@ -159,3 +159,34 @@ def run_with_timeout(
         task.end()
 
     return task.result
+
+
+# NOTE: this function should really only be used when run_functions_tuples_in_parallel is
+# difficult to use. It's up to the programmer to call wait_on_background on the returned
+# thread once the work you wanted to overlap with it is finished. As with all Python
+# thread parallelism, this is only useful for I/O-bound tasks.
+def run_in_background(
+    func: Callable[..., R], *args: Any, **kwargs: Any
+) -> TimeoutThread[R]:
+    """
+    Runs a function in a background thread. Returns a TimeoutThread object that can be used
+    to wait for the function to finish with wait_on_background.
+    """
+    context = contextvars.copy_context()
+    # Timeout not used in the non-blocking case
+    task = TimeoutThread(-1, context.run, func, *args, **kwargs)
+    task.start()
+    return task
+
+
+def wait_on_background(task: TimeoutThread[R]) -> R:
+    """
+    Used in conjunction with run_in_background. Blocks until the task is finished,
+    then returns the result of the task.
+    """
+    task.join()
+
+    if task.exception is not None:
+        raise task.exception
+
+    return task.result
diff --git a/backend/tests/unit/onyx/utils/test_threadpool_concurrency.py b/backend/tests/unit/onyx/utils/test_threadpool_concurrency.py
index 74399e4d3ad..8b9505bbc2c 100644
--- a/backend/tests/unit/onyx/utils/test_threadpool_concurrency.py
+++ b/backend/tests/unit/onyx/utils/test_threadpool_concurrency.py
@@ -1,8 +1,14 @@
+import contextvars
 import time
 
 import pytest
 
+from onyx.utils.threadpool_concurrency import run_in_background
 from onyx.utils.threadpool_concurrency import run_with_timeout
+from onyx.utils.threadpool_concurrency import wait_on_background
+
+# Create a context variable for testing
+test_context_var = contextvars.ContextVar("test_var", default="default")
 
 
 def test_run_with_timeout_completes() -> None:
@@ -59,3 +65,86 @@ def complex_function(x: int, y: int, multiply: bool = False) -> int:
     # Test with positional and keyword args
     result2 = run_with_timeout(1.0, complex_function, x=5, y=3, multiply=True)
     assert result2 == 15
+
+
+def test_run_in_background_and_wait_success() -> None:
+    """Test that run_in_background and wait_on_background work correctly for successful execution"""
+
+    def background_function(x: int) -> int:
+        time.sleep(0.1)  # Small delay to ensure it's actually running in background
+        return x * 2
+
+    # Start the background task
+    task = run_in_background(background_function, 21)
+
+    # Verify we can do other work while task is running
+    start_time = time.time()
+    result = wait_on_background(task)
+    elapsed = time.time() - start_time
+
+    assert result == 42
+    assert elapsed >= 0.1  # Verify we actually waited for the sleep
+
+
+@pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")
+def test_run_in_background_propagates_exceptions() -> None:
+    """Test that exceptions in background tasks are properly propagated"""
+
+    def error_function() -> None:
+        time.sleep(0.1)  # Small delay to ensure it's actually running in background
+        raise ValueError("Test background error")
+
+    task = run_in_background(error_function)
+
+    with pytest.raises(ValueError) as exc_info:
+        wait_on_background(task)
+
+    assert "Test background error" in str(exc_info.value)
+
+
+def test_run_in_background_with_args_and_kwargs() -> None:
+    """Test that args and kwargs are properly passed to the background function"""
+
+    def complex_function(x: int, y: int, multiply: bool = False) -> int:
+        time.sleep(0.1)  # Small delay to ensure it's actually running in background
+        if multiply:
+            return x * y
+        return x + y
+
+    # Test with args
+    task1 = run_in_background(complex_function, 5, 3)
+    result1 = wait_on_background(task1)
+    assert result1 == 8
+
+    # Test with args and kwargs
+    task2 = run_in_background(complex_function, 5, 3, multiply=True)
+    result2 = wait_on_background(task2)
+    assert result2 == 15
+
+
+def test_multiple_background_tasks() -> None:
+    """Test running multiple background tasks concurrently"""
+
+    def slow_add(x: int, y: int) -> int:
+        time.sleep(0.2)  # Make each task take some time
+        return x + y
+
+    # Start multiple tasks
+    start_time = time.time()
+    task1 = run_in_background(slow_add, 1, 2)
+    task2 = run_in_background(slow_add, 3, 4)
+    task3 = run_in_background(slow_add, 5, 6)
+
+    # Wait for all results
+    result1 = wait_on_background(task1)
+    result2 = wait_on_background(task2)
+    result3 = wait_on_background(task3)
+    elapsed = time.time() - start_time
+
+    # Verify results
+    assert result1 == 3
+    assert result2 == 7
+    assert result3 == 11
+
+    # Verify tasks ran in parallel (total time should be ~0.2s, not ~0.6s)
+    assert 0.2 <= elapsed < 0.4  # Allow some buffer for test environment variations
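
Note: run_in_background wraps the target in contextvars.copy_context().run, so the worker thread sees a snapshot of the caller's context variables while its own writes stay isolated, and TimeoutThread is now Generic[R] so wait_on_background is typed to return the function's result. The contextvars test in the next diff pins this behavior down; the stdlib mechanics in isolation:

    import contextvars
    import threading

    var = contextvars.ContextVar("var", default="default")
    var.set("from-main")

    def worker() -> None:
        assert var.get() == "from-main"  # inherited snapshot
        var.set("from-worker")           # visible only in the copied context

    ctx = contextvars.copy_context()
    thread = threading.Thread(target=ctx.run, args=(worker,))
    thread.start()
    thread.join()
    assert var.get() == "from-main"      # caller's context is untouched
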
diff --git a/backend/tests/unit/onyx/utils/test_threadpool_contextvars.py b/backend/tests/unit/onyx/utils/test_threadpool_contextvars.py
index 4d6d9a6a345..ab92b4e5557 100644
--- a/backend/tests/unit/onyx/utils/test_threadpool_contextvars.py
+++ b/backend/tests/unit/onyx/utils/test_threadpool_contextvars.py
@@ -4,7 +4,9 @@
 from onyx.utils.threadpool_concurrency import FunctionCall
 from onyx.utils.threadpool_concurrency import run_functions_in_parallel
 from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
+from onyx.utils.threadpool_concurrency import run_in_background
 from onyx.utils.threadpool_concurrency import run_with_timeout
+from onyx.utils.threadpool_concurrency import wait_on_background
 
 # Create a test contextvar
 test_var = contextvars.ContextVar("test_var", default="default")
@@ -129,3 +131,39 @@ def set_and_return_contextvar(value: str) -> str:
 
     # Verify second run results
     assert all(result in ["thread3", "thread4"] for result in second_results)
+
+
+def test_run_in_background_preserves_contextvar() -> None:
+    """Test that run_in_background preserves contextvar values and modifications are isolated"""
+
+    def modify_and_sleep() -> tuple[str, str]:
+        """Modifies the contextvar, sleeps, and returns the original and final values"""
+        original = test_var.get()
+        test_var.set("modified_in_background")
+        time.sleep(0.1)  # Ensure we can check main thread during execution
+        final = test_var.get()
+        return original, final
+
+    # Set initial value in main thread
+    token = test_var.set("initial_value")
+    try:
+        # Start background task
+        task = run_in_background(modify_and_sleep)
+
+        # Verify main thread value remains unchanged while task runs
+        assert test_var.get() == "initial_value"
+
+        # Get results from background thread
+        original, modified = wait_on_background(task)
+
+        # Verify the background thread:
+        # 1. Saw the initial value
+        assert original == "initial_value"
+        # 2. Successfully modified its own copy
+        assert modified == "modified_in_background"
+
+        # Verify main thread value is still unchanged after task completion
+        assert test_var.get() == "initial_value"
+    finally:
+        # Clean up
+        test_var.reset(token)