Skip to content

Commit

Permalink
move function again, add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
joeyballentine committed Jun 15, 2024
1 parent 3d38e87 commit ecba60e
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 36 deletions.
38 changes: 2 additions & 36 deletions backend/src/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from chain.input import EdgeInput, Input, InputMap
from events import EventConsumer, InputsDict
from progress_controller import Aborted, ProgressController, ProgressToken
from util import timed_supplier
from util import combine_sets, timed_supplier

Output = List[object]

Expand Down Expand Up @@ -848,40 +848,6 @@ async def __iterate_generator_nodes(self, generator_nodes: list[GeneratorNode]):
error_string = "- " + "\n- ".join(deferred_errors)
raise Exception(f"Errors occurred during iteration:\n{error_string}")

def __combine_sets(self, set_list: list[set[NodeId]]) -> list[set[NodeId]]:
"""
Combines sets in a list which have at least one intersecting value
Example:
in: [{gen1, gen2}, {gen1, gen4}, {gen3}]
out: [{gen1, gen2, gen4}, {gen3}]
Note:
This code was written by ChatGPT. I tried to make my own algorithm for this, as well as
find resources to help online, and was unsuccessful. From all my testing, this implementation
seems to be both correct and performant. However, if you are familiar with this problem
and you know a better way to do this, please submit a PR with a modification.
"""
sets = [set(x) for x in set_list]
combined = True
while combined:
combined = False
new_sets = []
# Process each set in the input list
while sets:
current = sets.pop()
merged = False
# Compare the current set with the remaining sets
for i, s in enumerate(sets):
if current & s: # Check for intersection
sets[i] = current | s # Union of sets with common elements
merged = True
combined = True # Indicates that a merge occurred
break
if not merged:
new_sets.append(current) # No merge, add current set to new_sets
sets = new_sets # Update sets with the remaining sets
return sets

async def __process_nodes(self):
self.__send_chain_start()

Expand Down Expand Up @@ -910,7 +876,7 @@ async def __process_nodes(self):
gens_by_outs[out_node.id] = {node.id}

groups: list[set[NodeId]] = list(gens_by_outs.values())
combined_groups = self.__combine_sets(groups)
combined_groups = combine_sets(groups)

# TODO: Look for a way to avoid duplicating this work
for group in combined_groups:
Expand Down
38 changes: 38 additions & 0 deletions backend/src/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,41 @@ def wrapper():
return result, duration

return wrapper


T = TypeVar("T")


def combine_sets(set_list: list[set[T]]) -> list[set[T]]:
"""
Combines sets in a list which have at least one intersecting value
Example:
in: [{1, 2}, {1, 4}, {3}, {3, 5}, {6}]
out: [{1, 2, 4}, {3, 5}, {6}]
Note:
This code was written by ChatGPT. I tried to make my own algorithm for this, as well as
find resources to help online, and was unsuccessful. From all my testing, this implementation
seems to be both correct and performant. However, if you are familiar with this problem
and you know a better way to do this, please submit a PR with a modification.
"""
sets = [set(x) for x in set_list]
combined = True
while combined:
combined = False
new_sets = []
# Process each set in the input list
while sets:
current = sets.pop()
merged = False
# Compare the current set with the remaining sets
for i, s in enumerate(sets):
if current & s: # Check for intersection
sets[i] = current | s # Union of sets with common elements
merged = True
combined = True # Indicates that a merge occurred
break
if not merged:
new_sets.append(current) # No merge, add current set to new_sets
sets = new_sets # Update sets with the remaining sets
return sets
94 changes: 94 additions & 0 deletions backend/tests/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from util import combine_sets


def test_combine_sets_one():
test_sets = [{1, 2, 3}, {1, 4}, {5, 6}, {7, 8}]
test_output = [{1, 2, 3, 4}, {5, 6}, {7, 8}]
result = combine_sets(test_sets)
assert result == test_output


def test_combine_sets_two():
test_sets = [
{1, 2, 3},
{1, 4},
{5, 6},
{7, 8},
{6, 7, 8, 9, 10},
{50, 51, 52},
{60, 61},
]
test_output = [{1, 2, 3, 4}, {5, 6, 7, 8, 9, 10}, {50, 51, 52}, {60, 61}]
result = combine_sets(test_sets)
assert result == test_output


def test_combine_sets_three():
test_sets = [
{60, 90, 34},
{1, 2, 3},
{1, 4, 87},
{5, 6},
{7, 8},
{6, 7, 8, 9, 10},
{50, 51, 52},
{60, 61},
]
test_output = [
{34, 60, 61, 90},
{1, 2, 3, 4, 87},
{5, 6, 7, 8, 9, 10},
{50, 51, 52},
]
result = combine_sets(test_sets)
assert result == test_output


def test_combine_sets_four():
test_sets = [
{"a", "b", "c"},
{"d", "e"},
{"f", "g", "h"},
{"h", "i", "j", "k"},
{"x", "y", "z"},
]
test_output = [
{"a", "b", "c"},
{"d", "e"},
{"f", "g", "h", "i", "j", "k"},
{"x", "y", "z"},
]
result = combine_sets(test_sets)
assert result == test_output


def test_combine_sets_five():
test_sets = [
{"a", "b", "c"},
{"c", "d", "e", "f"},
{"f", "g", "h"},
{"h", "i", "j", "k"},
{"x", "y", "z"},
{"k", "lmnopqrstuvw", "x", "y"},
]
test_output = [
{
"a",
"b",
"c",
"d",
"e",
"f",
"g",
"h",
"i",
"j",
"k",
"lmnopqrstuvw",
"x",
"y",
"z",
},
]
result = combine_sets(test_sets)
assert result == test_output

0 comments on commit ecba60e

Please sign in to comment.