
Commit dc6738a

address PR comments + minor cleanup
evan-danswer committed Mar 6, 2025
1 parent a4bb37d commit dc6738a
Showing 9 changed files with 70 additions and 29 deletions.
@@ -91,7 +91,7 @@ def retrieve_documents(
     retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]

     if AGENT_RETRIEVAL_STATS:
-        pre_rerank_docs = callback_container[0]
+        pre_rerank_docs = callback_container[0] if callback_container else []
         fit_scores = get_fit_scores(
             pre_rerank_docs,
             retrieved_docs,
@@ -44,7 +44,9 @@ def call_tool(
     tool = tool_choice.tool
     tool_args = tool_choice.tool_args
     tool_id = tool_choice.id
-    tool_runner = ToolRunner(tool, tool_args)
+    tool_runner = ToolRunner(
+        tool, tool_args, override_kwargs=tool_choice.search_tool_override_kwargs
+    )
     tool_kickoff = tool_runner.kickoff()

     emit_packet(tool_kickoff, writer)
@@ -10,16 +10,14 @@
 from onyx.agents.agent_search.orchestration.states import ToolChoice
 from onyx.agents.agent_search.orchestration.states import ToolChoiceState
 from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
-from onyx.agents.agent_search.shared_graph_utils.constants import EMBEDDING_KEY
-from onyx.agents.agent_search.shared_graph_utils.constants import IS_KEYWORD_KEY
-from onyx.agents.agent_search.shared_graph_utils.constants import KEYWORDS_KEY
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
 from onyx.chat.tool_handling.tool_response_handler import (
     get_tool_call_for_non_tool_calling_llm_impl,
 )
 from onyx.context.search.preprocessing.preprocessing import query_analysis
 from onyx.context.search.retrieval.search_runner import get_query_embedding
+from onyx.tools.models import SearchToolOverrideKwargs
 from onyx.tools.tool import Tool
 from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from onyx.utils.logger import setup_logger
@@ -54,13 +52,15 @@ def choose_tool(

     embedding_thread: TimeoutThread[Embedding] | None = None
     keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
+    override_kwargs: SearchToolOverrideKwargs | None = None
     if (
         not agent_config.behavior.use_agentic_search
         and agent_config.tooling.search_tool is not None
         and (
             not force_use_tool.force_use or force_use_tool.tool_name == SearchTool.name
         )
     ):
+        override_kwargs = SearchToolOverrideKwargs()
         # Run in a background thread to avoid blocking the main thread
         embedding_thread = run_in_background(
             get_query_embedding,
@@ -108,16 +108,19 @@ def choose_tool(
         if embedding_thread and tool.name == SearchTool._NAME:
             # Wait for the embedding thread to finish
             embedding = wait_on_background(embedding_thread)
-            tool_args[EMBEDDING_KEY] = embedding
+            assert override_kwargs is not None, "must have override kwargs"
+            override_kwargs.precomputed_query_embedding = embedding
         if keyword_thread and tool.name == SearchTool._NAME:
             is_keyword, keywords = wait_on_background(keyword_thread)
-            tool_args[IS_KEYWORD_KEY] = is_keyword
-            tool_args[KEYWORDS_KEY] = keywords
+            assert override_kwargs is not None, "must have override kwargs"
+            override_kwargs.precomputed_is_keyword = is_keyword
+            override_kwargs.precomputed_keywords = keywords
         return ToolChoiceUpdate(
             tool_choice=ToolChoice(
                 tool=tool,
                 tool_args=tool_args,
                 id=str(uuid4()),
+                search_tool_override_kwargs=override_kwargs,
             ),
         )

@@ -190,16 +193,19 @@ def choose_tool(
     if embedding_thread and selected_tool.name == SearchTool._NAME:
         # Wait for the embedding thread to finish
         embedding = wait_on_background(embedding_thread)
-        selected_tool_call_request["args"][EMBEDDING_KEY] = embedding
+        assert override_kwargs is not None, "must have override kwargs"
+        override_kwargs.precomputed_query_embedding = embedding
     if keyword_thread and selected_tool.name == SearchTool._NAME:
         is_keyword, keywords = wait_on_background(keyword_thread)
-        selected_tool_call_request["args"][IS_KEYWORD_KEY] = is_keyword
-        selected_tool_call_request["args"][KEYWORDS_KEY] = keywords
+        assert override_kwargs is not None, "must have override kwargs"
+        override_kwargs.precomputed_is_keyword = is_keyword
+        override_kwargs.precomputed_keywords = keywords

     return ToolChoiceUpdate(
         tool_choice=ToolChoice(
             tool=selected_tool,
             tool_args=selected_tool_call_request["args"],
             id=selected_tool_call_request["id"],
+            search_tool_override_kwargs=override_kwargs,
         ),
     )
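
Note on the pattern above: choose_tool now starts the query embedding and keyword analysis in background threads while tool selection runs, and hands the results to the search tool through SearchToolOverrideKwargs instead of injecting extra keys into tool_args. Below is a minimal sketch of the same idea using only the standard library; compute_embedding, analyze_keywords, and pick_tool are hypothetical stand-ins for the real Onyx helpers, not actual Onyx functions.

from concurrent.futures import ThreadPoolExecutor


# Hypothetical stand-ins for the real embedding / query-analysis / LLM calls.
def compute_embedding(query: str) -> list[float]:
    return [0.0] * 3


def analyze_keywords(query: str) -> tuple[bool, list[str]]:
    return False, query.split()


def pick_tool(query: str) -> str:
    return "search"


def choose_and_prepare(query: str) -> tuple[str, dict]:
    # Start the heavy work in the background while tool selection runs.
    with ThreadPoolExecutor(max_workers=2) as pool:
        embedding_future = pool.submit(compute_embedding, query)
        keyword_future = pool.submit(analyze_keywords, query)

        tool_name = pick_tool(query)

        overrides: dict = {}
        if tool_name == "search":
            # Only block on the precomputed results if the search tool was chosen.
            overrides["precomputed_query_embedding"] = embedding_future.result()
            is_keyword, keywords = keyword_future.result()
            overrides["precomputed_is_keyword"] = is_keyword
            overrides["precomputed_keywords"] = keywords
    return tool_name, overrides


print(choose_and_prepare("onyx search latency"))

The payoff is that the expensive model calls overlap with the (typically slow) tool-selection step and are only awaited when the search tool is actually chosen.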
2 changes: 2 additions & 0 deletions backend/onyx/agents/agent_search/orchestration/states.py
@@ -2,6 +2,7 @@

 from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
 from onyx.tools.message import ToolCallSummary
+from onyx.tools.models import SearchToolOverrideKwargs
 from onyx.tools.models import ToolCallFinalResult
 from onyx.tools.models import ToolCallKickoff
 from onyx.tools.models import ToolResponse
@@ -35,6 +36,7 @@ class ToolChoice(BaseModel):
     tool: Tool
     tool_args: dict
     id: str | None
+    search_tool_override_kwargs: SearchToolOverrideKwargs | None = None

     class Config:
         arbitrary_types_allowed = True
2 changes: 2 additions & 0 deletions backend/onyx/chat/llm_response_handler.py
@@ -15,6 +15,8 @@
 from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler


+# This is Legacy code that is not used anymore.
+# It is kept here for reference.
 class LLMResponseHandlerManager:
     """
     This class is responsible for postprocessing the LLM response stream.
2 changes: 2 additions & 0 deletions backend/onyx/context/search/preprocessing/preprocessing.py
@@ -117,6 +117,8 @@ def retrieval_preprocessing(
         else None
     )

+    # Sometimes this is pre-computed in parallel with other heavy tasks to improve
+    # latency, and in that case we don't need to run the model again
     run_query_analysis = (
         None
        if (skip_query_analysis or search_request.precomputed_is_keyword is not None)
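
The added comment documents a short-circuit: when query analysis was already run in parallel upstream (precomputed_is_keyword is set), retrieval_preprocessing should not schedule the analysis model again. A simplified sketch of just that gating condition, assuming the two inputs shown in the diff; this is not the real function body.

def should_run_query_analysis(
    skip_query_analysis: bool, precomputed_is_keyword: bool | None
) -> bool:
    # If the caller already knows whether the query is keyword-style,
    # rerunning the analysis model only adds latency for the same answer.
    return not (skip_query_analysis or precomputed_is_keyword is not None)


assert should_run_query_analysis(False, None) is True
assert should_run_query_analysis(False, True) is False
assert should_run_query_analysis(True, None) is False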
13 changes: 9 additions & 4 deletions backend/onyx/tools/models.py
@@ -9,6 +9,7 @@
 from onyx.context.search.enums import SearchType
 from onyx.context.search.models import IndexFilters
 from onyx.context.search.models import InferenceSection
+from shared_configs.model_server_models import Embedding


 class ToolResponse(BaseModel):
@@ -60,11 +61,15 @@ class SearchQueryInfo(BaseModel):
     recency_bias_multiplier: float


+# None indicates that the default value should be used
 class SearchToolOverrideKwargs(BaseModel):
-    force_no_rerank: bool
-    alternate_db_session: Session | None
-    retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None
-    skip_query_analysis: bool
+    force_no_rerank: bool | None = None
+    alternate_db_session: Session | None = None
+    retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None = None
+    skip_query_analysis: bool | None = None
+    precomputed_query_embedding: Embedding | None = None
+    precomputed_is_keyword: bool | None = None
+    precomputed_keywords: list[str] | None = None

     class Config:
         arbitrary_types_allowed = True
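
With every field now optional and defaulting to None, callers construct SearchToolOverrideKwargs with only the values they actually precomputed; None means "use the tool's default behavior" rather than "false" or "empty". A small illustrative example (the field values are made up, and the import assumes the onyx package from this repo):

from onyx.tools.models import SearchToolOverrideKwargs

# Only override what was actually precomputed; everything else stays None
# and the search tool falls back to its defaults.
overrides = SearchToolOverrideKwargs(
    precomputed_query_embedding=[0.12, -0.43, 0.88],  # illustrative embedding
    precomputed_is_keyword=False,
    precomputed_keywords=["onyx", "search"],
)
assert overrides.force_no_rerank is None  # default behavior, not "False"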
34 changes: 23 additions & 11 deletions backend/onyx/tools/tool_implementations/search/search_tool.py
@@ -3,12 +3,10 @@
 from collections.abc import Generator
 from typing import Any
 from typing import cast
+from typing import TypeVar

 from sqlalchemy.orm import Session

-from onyx.agents.agent_search.shared_graph_utils.constants import EMBEDDING_KEY
-from onyx.agents.agent_search.shared_graph_utils.constants import IS_KEYWORD_KEY
-from onyx.agents.agent_search.shared_graph_utils.constants import KEYWORDS_KEY
 from onyx.chat.chat_utils import llm_doc_from_inference_section
 from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import ContextualPruningConfig
@@ -56,7 +54,6 @@
 )
 from onyx.utils.logger import setup_logger
 from onyx.utils.special_types import JSON_ro
-from shared_configs.model_server_models import Embedding

 logger = setup_logger()

@@ -287,20 +284,23 @@ def run(
         self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
     ) -> Generator[ToolResponse, None, None]:
         query = cast(str, llm_kwargs[QUERY_FIELD])
-        precomputed_query_embedding = cast(
-            Embedding | None, llm_kwargs.get(EMBEDDING_KEY)
-        )
-        precomputed_is_keyword = cast(bool | None, llm_kwargs.get(IS_KEYWORD_KEY))
-        precomputed_keywords = cast(list[str] | None, llm_kwargs.get(KEYWORDS_KEY))
+        precomputed_query_embedding = None
+        precomputed_is_keyword = None
+        precomputed_keywords = None
         force_no_rerank = False
         alternate_db_session = None
         retrieved_sections_callback = None
         skip_query_analysis = False
         if override_kwargs:
-            force_no_rerank = override_kwargs.force_no_rerank
+            force_no_rerank = use_alt_not_None(override_kwargs.force_no_rerank, False)
             alternate_db_session = override_kwargs.alternate_db_session
             retrieved_sections_callback = override_kwargs.retrieved_sections_callback
-            skip_query_analysis = override_kwargs.skip_query_analysis
+            skip_query_analysis = use_alt_not_None(
+                override_kwargs.skip_query_analysis, False
+            )
+            precomputed_query_embedding = override_kwargs.precomputed_query_embedding
+            precomputed_is_keyword = override_kwargs.precomputed_is_keyword
+            precomputed_keywords = override_kwargs.precomputed_keywords
         if self.selected_sections:
             yield from self._build_response_for_specified_sections(query)
             return
@@ -397,6 +397,11 @@ def build_next_prompt(
 # SearchTool passed in to allow for access to SearchTool properties.
 # We can't just call SearchTool methods in the graph because we're operating on
 # the retrieved docs (reranking, deduping, etc.) after the SearchTool has run.
+#
+# The various inference sections are passed in as functions to allow for lazy
+# evaluation. The SearchPipeline object properties that they correspond to are
+# actually functions defined with @property decorators, and passing them into
+# this function causes them to get evaluated immediately which is undesirable.
 def yield_search_responses(
     query: str,
     get_retrieved_sections: Callable[[], list[InferenceSection]],
@@ -449,3 +454,10 @@ def yield_search_responses(
     llm_docs = [llm_doc_from_inference_section(section) for section in pruned_sections]

     yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
+
+
+T = TypeVar("T")
+
+
+def use_alt_not_None(value: T | None, alt: T) -> T:
+    return value if value is not None else alt
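
The new comment block above explains why yield_search_responses receives zero-argument callables rather than the section lists themselves: reading a @property on the pipeline object triggers the expensive computation immediately, while wrapping the read in a lambda defers it until (and unless) the consumer calls it. A toy illustration of the difference, with a made-up Pipeline class standing in for SearchPipeline:

from collections.abc import Callable


class Pipeline:
    @property
    def retrieved_sections(self) -> list[str]:
        print("expensive retrieval runs now")
        return ["section-1", "section-2"]


def eager(sections: list[str]) -> None:
    pass  # pipeline.retrieved_sections was already evaluated at the call site


def lazy(get_sections: Callable[[], list[str]]) -> None:
    pass  # nothing runs unless get_sections() is actually called


pipeline = Pipeline()
eager(pipeline.retrieved_sections)         # prints: expensive retrieval runs now
lazy(lambda: pipeline.retrieved_sections)  # prints nothing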
16 changes: 13 additions & 3 deletions backend/onyx/tools/tool_runner.py
@@ -1,6 +1,8 @@
 from collections.abc import Callable
 from collections.abc import Generator
 from typing import Any
+from typing import Generic
+from typing import TypeVar

 from onyx.llm.interfaces import LLM
 from onyx.llm.models import PreviousMessage
@@ -11,10 +13,16 @@
 from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel


-class ToolRunner:
-    def __init__(self, tool: Tool, args: dict[str, Any]):
+R = TypeVar("R")
+
+
+class ToolRunner(Generic[R]):
+    def __init__(
+        self, tool: Tool[R], args: dict[str, Any], override_kwargs: R | None = None
+    ):
         self.tool = tool
         self.args = args
+        self.override_kwargs = override_kwargs

         self._tool_responses: list[ToolResponse] | None = None
@@ -27,7 +35,9 @@ def tool_responses(self) -> Generator[ToolResponse, None, None]:
             return

         tool_responses: list[ToolResponse] = []
-        for tool_response in self.tool.run(**self.args):
+        for tool_response in self.tool.run(
+            override_kwargs=self.override_kwargs, **self.args
+        ):
             yield tool_response
             tool_responses.append(tool_response)

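
End to end, the override object now flows as a single typed value: choose_tool builds it, ToolChoice.search_tool_override_kwargs carries it, call_tool passes it to ToolRunner, and ToolRunner forwards it to Tool.run alongside the LLM-provided arguments. A rough, self-contained sketch of that call path follows; these are simplified stand-in classes, not the actual Onyx signatures.

from collections.abc import Generator
from dataclasses import dataclass
from typing import Any, Generic, TypeVar

R = TypeVar("R")


class Tool(Generic[R]):
    name = "search"

    def run(
        self, override_kwargs: R | None = None, **llm_kwargs: Any
    ) -> Generator[dict, None, None]:
        # A real tool would do work here; this one just echoes what it received.
        yield {"args": llm_kwargs, "overrides": override_kwargs}


@dataclass
class ToolChoice(Generic[R]):
    tool: Tool[R]
    tool_args: dict
    search_tool_override_kwargs: R | None = None


class ToolRunner(Generic[R]):
    def __init__(
        self, tool: Tool[R], args: dict[str, Any], override_kwargs: R | None = None
    ):
        self.tool = tool
        self.args = args
        self.override_kwargs = override_kwargs

    def run(self) -> Generator[dict, None, None]:
        # The typed override object travels beside the LLM-provided args.
        yield from self.tool.run(override_kwargs=self.override_kwargs, **self.args)


choice = ToolChoice(Tool(), {"query": "onyx"}, {"precomputed_is_keyword": True})
runner = ToolRunner(choice.tool, choice.tool_args, choice.search_tool_override_kwargs)
for response in runner.run():
    print(response)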
