Skip to content

Commit

Permalink
Update docs and update _arun method as well
Browse files: browse the repository at this point in the history
  • Loading branch information
sjrl committed Apr 9, 2024
1 parent 1182a6e commit b6fd02e
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions haystack/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,12 @@ def run( # type: ignore
return node_output

def _combine_node_outputs(self, existing_input: Dict[str, Any], node_output: Dict[str, Any]) -> Dict[str, Any]:
"""
Combines the outputs of two nodes into a single input for a downstream node. This is useful for join nodes.
:param existing_input: The output of the first node.
:param node_output: The output of the second node.
"""
additional_input = {}
# Pass keys that appear in both inputs that have the same values
shared_items = {
Expand Down Expand Up @@ -749,16 +755,8 @@ async def _arun( # noqa: C901,PLR0912 type: ignore
**existing_input.get("_debug", {}),
**node_output.get("_debug", {}),
}
# Pass keys that appear in both inputs that have the same values
shared_items = {
k: existing_input[k]
for k in existing_input
if k in node_output and existing_input[k] == node_output[k]
}
for key in shared_items:
if key != "inputs" or key != "params" or key != "_debug":
updated_input[key] = shared_items[key]
# TODO Auto pass on keys that only appear once
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **updated_input}
if query and "query" not in updated_input:
updated_input["query"] = query
if file_paths and "file_paths" not in updated_input:
Expand All @@ -770,9 +768,9 @@ async def _arun( # noqa: C901,PLR0912 type: ignore
if meta and "meta" not in updated_input:
updated_input["meta"] = meta
else:
# TODO Would need to redo the shared items here to work with more than 2 streams joining
existing_input["inputs"].append(node_output)
updated_input = existing_input
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **existing_input}
queue[n] = updated_input
else:
queue[n] = node_output
Expand Down

0 comments on commit b6fd02e

Please sign in to comment.