Skip to content

Commit

Permalink
Move logic to place new nodes to PyironFlowWidget
Browse files Browse the repository at this point in the history
Use the new traitlet there to place them in the bottom corner.
  • Loading branch information
pmrv committed Mar 9, 2025
1 parent 29d95c0 commit 4681552
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 20 deletions.
31 changes: 31 additions & 0 deletions pyironflow/reactflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
dict_to_node,
dict_to_edge,
create_macro,
NODE_WIDTH
)
from pyiron_workflow.mixin.run import ReadinessError

Expand Down Expand Up @@ -267,9 +268,39 @@ def update_status(self):
def react_flow_widget(self):
return self.gui

def place_new_node(self):
"""Find a suitable location in UI space for the newly added node.
Exact layouting not required as this can be done in UI, but newly added
nodes should be visible to the user and not completely overlap.
FIXME: Probably this is better handled completely in UI by elk.
"""
view = json.loads(self.gui.view)
if view == {}:
position = [0, 0]
else:
position = [
-view['x'] + 0.1 * view['height'],
-view['y'] + 0.9 * view['height'],
]

def blocked():
for node in self.wf.children.values():
if 'position' in dir(node):
print(node.position, position)
if node.position == tuple(position):
return True
return False
while blocked():
position[0] += NODE_WIDTH + 10

return tuple(position)

def add_node(self, node_path, label):
self.wf = self.get_workflow()
node = get_node_from_path(node_path, log=self.log)
node.position = self.place_new_node()
if node is not None:
self.log.append_stdout(f"add_node (reactflow): {node}, {label} \n")
if label in self.wf.child_labels:
Expand Down
30 changes: 10 additions & 20 deletions pyironflow/wf_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import types
import math

NODE_WIDTH = 240

def get_import_path(obj):
module = obj.__module__ if hasattr(obj, "__module__") else obj.__class__.__module__
# name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
Expand Down Expand Up @@ -116,20 +118,15 @@ def get_node_types(node_io):
return node_io_types


def get_node_position(node, max_x, node_width=240, y0=100, x_spacing=20):
def get_node_position(node):
if 'position' in dir(node):
x, y = node.position
# if isinstance(x, str):
# x, y = 0, 0
else:
x = max_x + node_width + x_spacing
y = y0

x, y = 0, 0
return {'x': x, 'y': y}


def get_node_dict(node, max_x, key=None):
node_width = 240
def get_node_dict(node, key=None):
n_inputs = len(list(node.inputs.channel_dict.keys()))
n_outputs = len(list(node.outputs.channel_dict.keys()))
if n_outputs > n_inputs:
Expand All @@ -155,13 +152,13 @@ def get_node_dict(node, max_x, key=None):
'ready': str(node.outputs.ready),
'python_object_id': id(node),
},
'position': get_node_position(node, max_x),
'position': get_node_position(node),
'type': 'customNode',
'style': {'padding': 5,
'background': get_color(node=node, theme='light'),
'borderRadius': '10px',
'width': f'{node_width}px',
'width_unitless': node_width,
'width': f'{NODE_WIDTH}PX',
'width_unitless': NODE_WIDTH,
'height': f'{node_height}px',
'height_unitless': node_height},
'targetPosition': 'left',
Expand All @@ -171,15 +168,8 @@ def get_node_dict(node, max_x, key=None):

def get_nodes(wf):
nodes = []
x_coords = []
max_x = 0
for i, (k, v) in enumerate(wf.children.items()):
if 'position' in dir(v):
x_coords.append(v.position[0])
if len(x_coords) > 0:
max_x = max(x_coords)
for i, (k, v) in enumerate(wf.children.items()):
nodes.append(get_node_dict(v, max_x, key=k))
for k, v in wf.children.items():
nodes.append(get_node_dict(v, key=k))
return nodes


Expand Down

0 comments on commit 4681552

Please sign in to comment.