diff --git a/haystack/pipelines/base.py b/haystack/pipelines/base.py index 1fdf506227..538d207500 100644 --- a/haystack/pipelines/base.py +++ b/haystack/pipelines/base.py @@ -623,6 +623,12 @@ def run( # type: ignore return node_output def _combine_node_outputs(self, existing_input: Dict[str, Any], node_output: Dict[str, Any]) -> Dict[str, Any]: + """ + Combines the outputs of two nodes into a single input for a downstream node. This is useful for join nodes. + + :param existing_input: The output of the first node. + :param node_output: The output of the second node. + """ additional_input = {} # Pass keys that appear in both inputs that have the same values shared_items = { @@ -749,16 +755,8 @@ async def _arun( # noqa: C901,PLR0912 type: ignore **existing_input.get("_debug", {}), **node_output.get("_debug", {}), } - # Pass keys that appear in both inputs that have the same values - shared_items = { - k: existing_input[k] - for k in existing_input - if k in node_output and existing_input[k] == node_output[k] - } - for key in shared_items: - if key != "inputs" or key != "params" or key != "_debug": - updated_input[key] = shared_items[key] - # TODO Auto pass on keys that only appear once + additional_input = self._combine_node_outputs(existing_input, node_output) + updated_input = {**additional_input, **updated_input} if query and "query" not in updated_input: updated_input["query"] = query if file_paths and "file_paths" not in updated_input: @@ -770,9 +768,9 @@ async def _arun( # noqa: C901,PLR0912 type: ignore if meta and "meta" not in updated_input: updated_input["meta"] = meta else: - # TODO Would need to redo the shared items here to work with more than 2 streams joining existing_input["inputs"].append(node_output) - updated_input = existing_input + additional_input = self._combine_node_outputs(existing_input, node_output) + updated_input = {**additional_input, **existing_input} queue[n] = updated_input else: queue[n] = node_output