Skip to content

Commit

Permalink
Added support for unique keys
Browse files Browse the repository at this point in the history
  • Loading branch information
sjrl committed Apr 9, 2024
1 parent b6fd02e commit 1b5ae7d
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 12 deletions.
16 changes: 13 additions & 3 deletions haystack/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,6 @@ def run( # type: ignore
updated_input["meta"] = meta
else:
existing_input["inputs"].append(node_output)
# TODO This doesn't have an effect until we also pass on keys that only occur once
additional_input = self._combine_node_outputs(existing_input, node_output)
updated_input = {**additional_input, **existing_input}
queue[n] = updated_input
Expand All @@ -630,13 +629,24 @@ def _combine_node_outputs(self, existing_input: Dict[str, Any], node_output: Dic
:param node_output: The output of the second node.
"""
additional_input = {}
# Pass keys that appear in both inputs that have the same values
# TODO Should we support overwriting keys that exist in both? --> first node's value is kept
# Add shared items from existing_input and node_output that have matching values
shared_items = {
k: existing_input[k] for k in existing_input if k in node_output and existing_input[k] == node_output[k]
}
for key in shared_items:
if key != "inputs" or key != "params" or key != "_debug":
if key not in ["inputs", "params", "_debug"]:
additional_input[key] = shared_items[key]
unique_existing_input = {k: v for k, v in existing_input.items() if k not in shared_items}
# Add unique keys from existing_input
for key in unique_existing_input:
if key not in ["inputs", "params", "_debug"]:
additional_input[key] = unique_existing_input[key]
# Add unique keys from node_output
unique_node_output = {k: v for k, v in node_output.items() if k not in shared_items}
for key in unique_node_output:
if key not in ["inputs", "params", "_debug"]:
additional_input[key] = unique_node_output[key]
return additional_input

async def _arun( # noqa: C901,PLR0912 type: ignore
Expand Down
23 changes: 14 additions & 9 deletions test/pipelines/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2143,29 +2143,34 @@ def test_pipeline_execution_using_join_preserves_previous_keys_three_streams():
document_store_3.write_documents(dicts_3)

# Create Shaper to insert "invocation_context" and "test_key" into the node_output
shaper = Shaper(func="rename", inputs={"value": "query"}, outputs=["test_key"])
shaper1 = Shaper(func="rename", inputs={"value": "query"}, outputs=["test_key1"])
shaper2 = Shaper(func="rename", inputs={"value": "query"}, outputs=["test_key2"])

pipeline = Pipeline()
pipeline.add_node(component=shaper, name="Shaper", inputs=["Query"])
pipeline.add_node(component=retriever_1, name="Retriever1", inputs=["Shaper"])
pipeline.add_node(component=retriever_2, name="Retriever2", inputs=["Shaper"])
pipeline.add_node(component=retriever_3, name="Retriever3", inputs=["Shaper"])
pipeline.add_node(component=shaper1, name="Shaper1", inputs=["Query"])
pipeline.add_node(component=shaper2, name="Shaper2", inputs=["Query"])
pipeline.add_node(component=retriever_3, name="Retriever3", inputs=["Shaper2"])
pipeline.add_node(component=retriever_1, name="Retriever1", inputs=["Shaper1"])
pipeline.add_node(component=retriever_2, name="Retriever2", inputs=["Shaper1"])

pipeline.add_node(
component=JoinDocuments(join_mode="concatenate"), name="Join", inputs=["Retriever1", "Retriever2", "Retriever3"]
component=JoinDocuments(join_mode="concatenate"), name="Join", inputs=["Retriever3", "Retriever1", "Retriever2"]
)
res = pipeline.run(query="Alpha Beta Gamma Delta")
assert set(res.keys()) == {
"documents",
"labels",
"root_node",
"params",
"test_key",
"test_key1",
"test_key2",
"invocation_context",
"query",
"node_id",
}
assert res["test_key"] == "Alpha Beta Gamma Delta"
assert res["invocation_context"] == {"query": "Alpha Beta Gamma Delta", "test_key": "Alpha Beta Gamma Delta"}
assert res["test_key1"] == "Alpha Beta Gamma Delta"
assert res["test_key2"] == "Alpha Beta Gamma Delta"
assert res["invocation_context"] == {"query": "Alpha Beta Gamma Delta", "test_key1": "Alpha Beta Gamma Delta"}
assert len(res["documents"]) == 3


Expand Down

0 comments on commit 1b5ae7d

Please sign in to comment.