From 4fed1272f370cdfa497c27c446ac3e98d2800a7e Mon Sep 17 00:00:00 2001 From: Peter Shevcnenko <57573631+MorrisNein@users.noreply.github.com> Date: Thu, 26 Oct 2023 16:59:19 +0300 Subject: [PATCH] Scalable graph vizualization (#214) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * introduce scalable graph visualization & small code refactoring * fix title * fix test_hierarchy_pos * зуз8 * add args `nodes_layout_function`, `node_names_placement` * adjust font size --- golem/core/dag/graph.py | 28 +- golem/visualisation/graph_viz.py | 547 ++++++++++-------- .../visualisation/test_visualisation_utils.py | 11 +- 3 files changed, 343 insertions(+), 243 deletions(-) diff --git a/golem/core/dag/graph.py b/golem/core/dag/graph.py index 298d21d97..239bb23da 100644 --- a/golem/core/dag/graph.py +++ b/golem/core/dag/graph.py @@ -1,7 +1,9 @@ from abc import ABC, abstractmethod from enum import Enum from os import PathLike -from typing import Dict, List, Optional, Sequence, Union, Tuple, TypeVar +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, TypeVar, Union + +import networkx as nx from golem.core.dag.graph_node import GraphNode from golem.visualisation.graph_viz import GraphVisualizer, NodeColorType @@ -208,7 +210,9 @@ def show(self, save_path: Optional[Union[PathLike, str]] = None, engine: Optiona node_size_scale: Optional[float] = None, font_size_scale: Optional[float] = None, edge_curvature_scale: Optional[float] = None, title: Optional[str] = None, - nodes_labels: Dict[int, str] = None, edges_labels: Dict[int, str] = None): + node_names_placement: Optional[Literal['auto', 'nodes', 'legend', 'none']] = None, + nodes_labels: Dict[int, str] = None, edges_labels: Dict[int, str] = None, + nodes_layout_function: Optional[Callable[[nx.DiGraph], Dict[Any, Tuple[float, float]]]] = None): """Visualizes graph or saves its picture to the specified ``path`` Args: @@ -220,14 +224,28 @@ def show(self, save_path: Optional[Union[PathLike, str]] = None, engine: Optiona edge_curvature_scale: use to make edges more or less curved. Supported only for the engine 'matplotlib'. dpi: DPI of the output image. Not supported for the engine 'pyvis'. title: title for plot + node_names_placement: variant of node names displaying. Defaults to ``auto``. + + Possible options: + + - ``auto`` -> empirical rule by node size + + - ``nodes`` -> place node names on top of the nodes + + - ``legend`` -> place node names at the legend + + - ``none`` -> do not show node names + nodes_labels: labels to display near nodes edges_labels: labels to display near edges + nodes_layout_function: any of `Networkx layout functions \ + `_ . """ - GraphVisualizer(graph=self)\ + GraphVisualizer(graph=self) \ .visualise(save_path=save_path, engine=engine, node_color=node_color, dpi=dpi, node_size_scale=node_size_scale, font_size_scale=font_size_scale, - edge_curvature_scale=edge_curvature_scale, - title=title, + edge_curvature_scale=edge_curvature_scale, node_names_placement=node_names_placement, + title=title, nodes_layout_function=nodes_layout_function, nodes_labels=nodes_labels, edges_labels=edges_labels) @property diff --git a/golem/visualisation/graph_viz.py b/golem/visualisation/graph_viz.py index fac8bd78b..09517a86c 100644 --- a/golem/visualisation/graph_viz.py +++ b/golem/visualisation/graph_viz.py @@ -5,7 +5,7 @@ from copy import deepcopy from pathlib import Path from textwrap import wrap -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Sequence, Tuple, Union, List +from typing import Any, Callable, Dict, Iterable, Literal, Optional, Sequence, TYPE_CHECKING, Tuple, Union from uuid import uuid4 import networkx as nx @@ -16,16 +16,18 @@ from pyvis.network import Network from seaborn import color_palette -from golem.core.dag.graph_utils import distance_to_primary_level from golem.core.dag.convert import graph_structure_as_nx_graph +from golem.core.dag.graph_utils import distance_to_primary_level from golem.core.log import default_log from golem.core.paths import default_data_dir if TYPE_CHECKING: from golem.core.dag.graph import Graph from golem.core.optimisers.graph import OptGraph + from golem.core.dag.graph_node import GraphNode GraphType = Union[Graph, OptGraph] + GraphConvertType = Callable[[GraphType], Tuple[nx.DiGraph, Dict[uuid4, GraphNode]]] PathType = Union[os.PathLike, str] @@ -36,106 +38,141 @@ class GraphVisualizer: - def __init__(self, graph: GraphType, visuals_params: Optional[Dict[str, Any]] = None, ): + def __init__(self, graph: GraphType, visuals_params: Optional[Dict[str, Any]] = None, + to_nx_convert_func: GraphConvertType = graph_structure_as_nx_graph): visuals_params = visuals_params or {} default_visuals_params = dict( engine='matplotlib', dpi=100, - node_color=self.__get_colors_by_labels, + node_color=self._get_colors_by_labels, node_size_scale=1.0, font_size_scale=1.0, edge_curvature_scale=1.0, - graph_to_nx_convert_func=graph_structure_as_nx_graph + node_names_placement='auto', + nodes_layout_function=GraphVisualizer._get_hierarchy_pos_by_distance_to_primary_level, + figure_size=(7, 7), + save_path=None, ) default_visuals_params.update(visuals_params) self.visuals_params = default_visuals_params - self.graph = graph - + self.to_nx_convert_func = to_nx_convert_func + self._update_graph(graph) self.log = default_log(self) + def _update_graph(self, graph: GraphType): + self.graph = graph + self.nx_graph, self.nodes_dict = self.to_nx_convert_func(self.graph) + def visualise(self, save_path: Optional[PathType] = None, engine: Optional[str] = None, node_color: Optional[NodeColorType] = None, dpi: Optional[int] = None, - node_size_scale: Optional[float] = None, - font_size_scale: Optional[float] = None, edge_curvature_scale: Optional[float] = None, - title: Optional[str] = None, - nodes_labels: Dict[int, str] = None, edges_labels: Dict[int, str] = None): - engine = engine or self.get_predefined_value('engine') + node_size_scale: Optional[float] = None, font_size_scale: Optional[float] = None, + edge_curvature_scale: Optional[float] = None, figure_size: Optional[Tuple[int, int]] = None, + nodes_labels: Dict[int, str] = None, edges_labels: Dict[int, str] = None, + node_names_placement: Optional[Literal['auto', 'nodes', 'legend', 'none']] = None, + nodes_layout_function: Optional[Callable[[nx.DiGraph], Dict[Any, Tuple[float, float]]]] = None, + title: Optional[str] = None): + engine = engine or self._get_predefined_value('engine') if not self.graph.nodes: raise ValueError('Empty graph can not be visualized.') if engine == 'matplotlib': - self.__draw_with_networkx(save_path=save_path, node_color=node_color, dpi=dpi, - node_size_scale=node_size_scale, font_size_scale=font_size_scale, - edge_curvature_scale=edge_curvature_scale, - title=title, nodes_labels=nodes_labels, edges_labels=edges_labels) + self._draw_with_networkx(save_path=save_path, node_color=node_color, dpi=dpi, + node_size_scale=node_size_scale, font_size_scale=font_size_scale, + edge_curvature_scale=edge_curvature_scale, figure_size=figure_size, + title=title, nodes_labels=nodes_labels, edges_labels=edges_labels, + nodes_layout_function=nodes_layout_function, + node_names_placement=node_names_placement) elif engine == 'pyvis': - self.__draw_with_pyvis(save_path, node_color) + self._draw_with_pyvis(save_path, node_color) elif engine == 'graphviz': - self.__draw_with_graphviz(save_path, node_color, dpi) + self._draw_with_graphviz(save_path, node_color, dpi) else: raise NotImplementedError(f'Unexpected visualization engine: {engine}. ' 'Possible values: matplotlib, pyvis, graphviz.') - @staticmethod - def __get_colors_by_labels(labels: Iterable[str]) -> LabelsColorMapType: - unique_labels = list(set(labels)) - palette = color_palette('tab10', len(unique_labels)) - return {label: palette[unique_labels.index(label)] for label in labels} + def draw_nx_dag( + self, ax: Optional[plt.Axes] = None, node_color: Optional[NodeColorType] = None, + node_size_scale: Optional[float] = None, font_size_scale: Optional[float] = None, + edge_curvature_scale: Optional[float] = None, nodes_labels: Dict[int, str] = None, + edges_labels: Dict[int, str] = None, + nodes_layout_function: Optional[Callable[[nx.DiGraph], Dict[Any, Tuple[float, float]]]] = None, + node_names_placement: Optional[Literal['auto', 'nodes', 'legend', 'none']] = None): + node_color = node_color or self._get_predefined_value('node_color') + node_size_scale = node_size_scale or self._get_predefined_value('node_size_scale') + font_size_scale = font_size_scale or self._get_predefined_value('font_size_scale') + edge_curvature_scale = (edge_curvature_scale if edge_curvature_scale is not None + else self._get_predefined_value('edge_curvature_scale')) + nodes_layout_function = nodes_layout_function or self._get_predefined_value('nodes_layout_function') + node_names_placement = node_names_placement or self._get_predefined_value('node_names_placement') - def __draw_with_graphviz(self, save_path: Optional[PathType] = None, node_color: Optional[NodeColorType] = None, - dpi: Optional[int] = None, graph_to_nx_convert_func: Optional[Callable] = None): - save_path = save_path or self.get_predefined_value('save_path') - node_color = node_color or self.get_predefined_value('node_color') - dpi = dpi or self.get_predefined_value('dpi') - graph_to_nx_convert_func = graph_to_nx_convert_func or self.get_predefined_value('graph_to_nx_convert_func') + nx_graph, nodes = self.nx_graph, self.nodes_dict + + if ax is None: + ax = plt.gca() - nx_graph, nodes = graph_to_nx_convert_func(self.graph) # Define colors if callable(node_color): - colors = node_color([str(node) for node in nodes.values()]) - elif isinstance(node_color, dict): - colors = node_color + node_color = node_color([str(node) for node in nodes.values()]) + if isinstance(node_color, dict): + node_color = [node_color.get(str(node), node_color.get(None)) for node in nodes.values()] else: - colors = {str(node): node_color for node in nodes.values()} - for n, data in nx_graph.nodes(data=True): - label = str(nodes[n]) - data['label'] = label.replace('_', ' ') - data['color'] = to_hex(colors.get(label, colors.get(None))) + node_color = [node_color for _ in nodes] + # Get node positions + if nodes_layout_function == GraphVisualizer._get_hierarchy_pos_by_distance_to_primary_level: + pos = nodes_layout_function(nx_graph, nodes) + else: + pos = nodes_layout_function(nx_graph) - gv_graph = nx.nx_agraph.to_agraph(nx_graph) - kwargs = {'prog': 'dot', 'args': f'-Gnodesep=0.5 -Gdpi={dpi} -Grankdir="LR"'} + node_size = self._get_scaled_node_size(len(nodes), node_size_scale) - if save_path: - gv_graph.draw(save_path, **kwargs) - else: - save_path = Path(default_data_dir(), 'graph_plots', str(uuid4()) + '.png') - save_path.parent.mkdir(exist_ok=True) - gv_graph.draw(save_path, **kwargs) + with_node_names = node_names_placement != 'none' - img = plt.imread(str(save_path)) - plt.imshow(img) - plt.gca().axis('off') - plt.gcf().set_dpi(dpi) - plt.tight_layout() + if node_names_placement in ('auto', 'none'): + node_names_placement = GraphVisualizer._define_node_names_placement(node_size) + + if node_names_placement == 'nodes': + self._draw_nx_big_nodes(ax, pos, nodes, node_color, node_size, font_size_scale, with_node_names) + elif node_names_placement == 'legend': + self._draw_nx_small_nodes(ax, pos, nodes, node_color, node_size, font_size_scale, with_node_names) + self._draw_nx_curved_edges(ax, pos, node_size, edge_curvature_scale) + self._draw_nx_labels(ax, pos, font_size_scale, nodes_labels, edges_labels) + + def _get_predefined_value(self, param: str): + if param not in self.visuals_params: + self.log.warning(f'No default param found: {param}.') + return self.visuals_params.get(param) + + def _draw_with_networkx( + self, save_path: Optional[PathType] = None, node_color: Optional[NodeColorType] = None, + dpi: Optional[int] = None, node_size_scale: Optional[float] = None, + font_size_scale: Optional[float] = None, edge_curvature_scale: Optional[float] = None, + figure_size: Optional[Tuple[int, int]] = None, title: Optional[str] = None, + nodes_labels: Dict[int, str] = None, edges_labels: Dict[int, str] = None, + nodes_layout_function: Optional[Callable[[nx.DiGraph], Dict[Any, Tuple[float, float]]]] = None, + node_names_placement: Optional[Literal['auto', 'nodes', 'legend', 'none']] = None): + save_path = save_path or self._get_predefined_value('save_path') + node_color = node_color or self._get_predefined_value('node_color') + dpi = dpi or self._get_predefined_value('dpi') + figure_size = figure_size or self._get_predefined_value('figure_size') + + ax = GraphVisualizer._setup_matplotlib_figure(figure_size, dpi, title) + self.draw_nx_dag(ax, node_color, node_size_scale, font_size_scale, edge_curvature_scale, + nodes_labels, edges_labels, nodes_layout_function, node_names_placement) + GraphVisualizer._rescale_matplotlib_figure(ax) + if not save_path: plt.show() - remove_old_files_from_dir(save_path.parent) + else: + plt.savefig(save_path, dpi=dpi) + plt.close() - def __draw_with_pyvis(self, save_path: Optional[PathType] = None, node_color: Optional[NodeColorType] = None, - graph_to_nx_convert_func: Optional[Callable] = None): - save_path = save_path or self.get_predefined_value('save_path') - node_color = node_color or self.get_predefined_value('node_color') - graph_to_nx_convert_func = graph_to_nx_convert_func or self.get_predefined_value('graph_to_nx_convert_func') + def _draw_with_pyvis(self, save_path: Optional[PathType] = None, node_color: Optional[NodeColorType] = None): + save_path = save_path or self._get_predefined_value('save_path') + node_color = node_color or self._get_predefined_value('node_color') net = Network('500px', '1000px', directed=True) - nx_graph, nodes = graph_to_nx_convert_func(self.graph) - # Define colors - if callable(node_color): - colors = node_color([str(node) for node in nodes.values()]) - elif isinstance(node_color, dict): - colors = node_color - else: - colors = {str(node): node_color for node in nodes.values()} + nx_graph, nodes = self.nx_graph, self.nodes_dict + node_color = self._define_colors(node_color, nodes) for n, data in nx_graph.nodes(data=True): operation = nodes[n] label = str(operation) @@ -146,7 +183,7 @@ def __draw_with_pyvis(self, save_path: Optional[PathType] = None, node_color: Op params = str(params)[1:-1] data['title'] = params data['level'] = distance_to_primary_level(operation) - data['color'] = to_hex(colors.get(label, colors.get(None))) + data['color'] = to_hex(node_color.get(label, node_color.get(None))) data['font'] = '20px' data['labelHighlightBold'] = True @@ -163,82 +200,140 @@ def __draw_with_pyvis(self, save_path: Optional[PathType] = None, node_color: Op net.show(str(save_path)) remove_old_files_from_dir(save_path.parent) - def __draw_with_networkx(self, save_path: Optional[PathType] = None, - node_color: Optional[NodeColorType] = None, - dpi: Optional[int] = None, node_size_scale: Optional[float] = None, - font_size_scale: Optional[float] = None, edge_curvature_scale: Optional[float] = None, - graph_to_nx_convert_func: Optional[Callable] = None, title: Optional[str] = None, - nodes_labels: Dict[int, str] = None, edges_labels: Dict[int, str] = None): - save_path = save_path or self.get_predefined_value('save_path') - node_color = node_color or self.get_predefined_value('node_color') - dpi = dpi or self.get_predefined_value('dpi') - node_size_scale = node_size_scale or self.get_predefined_value('node_size_scale') - font_size_scale = font_size_scale or self.get_predefined_value('font_size_scale') - edge_curvature_scale = (edge_curvature_scale if edge_curvature_scale is not None - else self.get_predefined_value('edge_curvature_scale')) - graph_to_nx_convert_func = graph_to_nx_convert_func or self.get_predefined_value('graph_to_nx_convert_func') + def _draw_with_graphviz(self, save_path: Optional[PathType] = None, node_color: Optional[NodeColorType] = None, + dpi: Optional[int] = None): + save_path = save_path or self._get_predefined_value('save_path') + node_color = node_color or self._get_predefined_value('node_color') + dpi = dpi or self._get_predefined_value('dpi') - fig, ax = plt.subplots(figsize=(7, 7)) - fig.set_dpi(dpi) + nx_graph, nodes = self.nx_graph, self.nodes_dict + node_color = self._define_colors(node_color, nodes) + for n, data in nx_graph.nodes(data=True): + label = str(nodes[n]) + data['label'] = label.replace('_', ' ') + data['color'] = to_hex(node_color.get(label, node_color.get(None))) - plt.title(title) - self.draw_nx_dag(ax, node_color, node_size_scale, font_size_scale, edge_curvature_scale, - graph_to_nx_convert_func, nodes_labels, edges_labels) - if not save_path: - plt.show() + gv_graph = nx.nx_agraph.to_agraph(nx_graph) + kwargs = {'prog': 'dot', 'args': f'-Gnodesep=0.5 -Gdpi={dpi} -Grankdir="LR"'} + + if save_path: + gv_graph.draw(save_path, **kwargs) else: - plt.savefig(save_path, dpi=dpi) - plt.close() + save_path = Path(default_data_dir(), 'graph_plots', str(uuid4()) + '.png') + save_path.parent.mkdir(exist_ok=True) + gv_graph.draw(save_path, **kwargs) - def draw_nx_dag(self, ax: Optional[plt.Axes] = None, - node_color: Optional[NodeColorType] = None, - node_size_scale: float = 1, font_size_scale: float = 1, edge_curvature_scale: float = 1, - graph_to_nx_convert_func: Callable = graph_structure_as_nx_graph, - nodes_labels: Dict[int, str] = None, edges_labels: Dict[int, str] = None): - - def draw_nx_labels(pos, node_labels, ax, max_sequence_length, font_size_scale=1.0): - def get_scaled_font_size(nodes_amount): - min_size = 2 - max_size = 20 - - size = min_size + int((max_size - min_size) / np.log2(max(nodes_amount, 2))) - return size - - if ax is None: - ax = plt.gca() - for node, (x, y) in pos.items(): - text = '\n'.join(wrap(node_labels[node].replace('_', ' ').replace('-', ' '), 10)) - ax.text(x, y, text, ha='center', va='center', - fontsize=get_scaled_font_size(max_sequence_length) * font_size_scale, - bbox=dict(alpha=0.9, color='w', boxstyle='round')) - - def get_scaled_node_size(nodes_amount): - min_size = 500 - max_size = 5000 - size = min_size + int((max_size - min_size) / np.log2(max(nodes_amount, 2))) - return size + img = plt.imread(str(save_path)) + plt.imshow(img) + plt.gca().axis('off') + plt.gcf().set_dpi(dpi) + plt.tight_layout() + plt.show() + remove_old_files_from_dir(save_path.parent) - if ax is None: - ax = plt.gca() + @staticmethod + def _get_scaled_node_size(nodes_amount: int, size_scale: float) -> float: + min_size = 150 + max_size = 12000 + size = max(max_size * (1 - np.log10(nodes_amount)), min_size) + return size * size_scale - nx_graph, nodes = graph_to_nx_convert_func(self.graph) - # Define colors + @staticmethod + def _get_scaled_font_size(nodes_amount: int, size_scale: float) -> float: + min_size = 14 + max_size = 30 + size = max(max_size * (1 - np.log10(nodes_amount)), min_size) + return size * size_scale + + @staticmethod + def _get_colors_by_labels(labels: Iterable[str]) -> LabelsColorMapType: + unique_labels = list(set(labels)) + palette = color_palette('tab10', len(unique_labels)) + return {label: palette[unique_labels.index(label)] for label in labels} + + @staticmethod + def _define_colors(node_color, nodes): if callable(node_color): - node_color = node_color([str(node) for node in nodes.values()]) - if isinstance(node_color, dict): - node_color = [node_color.get(str(node), node_color.get(None)) for node in nodes.values()] - # Define hierarchy_level - for node_id, node_data in nx_graph.nodes(data=True): - node_data['hierarchy_level'] = distance_to_primary_level(nodes[node_id]) - # Get nodes positions - pos, longest_sequence = get_hierarchy_pos(nx_graph) - node_size = get_scaled_node_size(longest_sequence) * node_size_scale + colors = node_color([str(node) for node in nodes.values()]) + elif isinstance(node_color, dict): + colors = node_color + else: + colors = {str(node): node_color for node in nodes.values()} + return colors + + @staticmethod + def _setup_matplotlib_figure(figure_size: Tuple[float, float], dpi: int, title: Optional[str] = None) -> plt.Axes: + fig, ax = plt.subplots(figsize=figure_size) + fig.set_dpi(dpi) + plt.title(title) + return ax + + @staticmethod + def _rescale_matplotlib_figure(ax): + """Rescale the figure for all nodes to fit in.""" + + x_1, x_2 = ax.get_xlim() + y_1, y_2 = ax.get_ylim() + offset = 0.2 + x_offset = x_2 * offset + y_offset = y_2 * offset + ax.set_xlim(x_1 - x_offset, x_2 + x_offset) + ax.set_ylim(y_1 - y_offset, y_2 + y_offset) + ax.axis('off') + plt.tight_layout() + + def _draw_nx_big_nodes(self, ax, pos, nodes, node_color, node_size, font_size_scale, with_node_names): # Draw the graph's nodes. - nx.draw_networkx_nodes(nx_graph, pos, node_size=node_size, ax=ax, node_color='w', linewidths=3, + nx.draw_networkx_nodes(self.nx_graph, pos, node_size=node_size, ax=ax, node_color='w', linewidths=3, edgecolors=node_color) + if not with_node_names: + return # Draw the graph's node labels. - draw_nx_labels(pos, {node_id: str(node) for node_id, node in nodes.items()}, ax, longest_sequence, - font_size_scale) + node_labels = {node_id: str(node) for node_id, node in nodes.items()} + font_size = GraphVisualizer._get_scaled_font_size(len(nodes), font_size_scale) + for node, (x, y) in pos.items(): + text = '\n'.join(wrap(node_labels[node].replace('_', ' ').replace('-', ' '), 10)) + ax.text(x, y, text, + ha='center', va='center', + fontsize=font_size, + bbox=dict(alpha=0.9, color='w', boxstyle='round')) + + def _draw_nx_small_nodes(self, ax, pos, nodes, node_color, node_size, font_size_scale, with_node_names): + nx_graph = self.nx_graph + markers = 'os^>v len(markers) - 1: + self.log.warning(f'Too much node labels derive the same color: {color}. The markers may repeat.\n' + '\tSpecify the parameter "node_color" to set distinct colors.') + color_count = color_count % len(markers) + marker = markers[color_count] + label_markers[label] = marker + color_counts[color] = color_count + 1 + nx.draw_networkx_nodes(nx_graph, pos, [node_id], ax=ax, node_color=[color], node_size=node_size, + node_shape=marker) + if label in labels_added: + continue + ax.plot([], [], marker=marker, linestyle='None', color=color, label=label) + labels_added.add(label) + if not with_node_names: + return + # @morrisnein took the following code from https://stackoverflow.com/a/27512450 + handles, labels = ax.get_legend_handles_labels() + # Sort both labels and handles by labels + labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0])) + ax.legend(handles, labels, prop={'size': round(20 * font_size_scale)}) + + def _draw_nx_curved_edges(self, ax, pos, node_size, edge_curvature_scale): + nx_graph = self.nx_graph # The ongoing section defines curvature for all edges. # This is 'connection style' for an edge that does not intersect any nodes. connection_style = 'arc3' @@ -261,9 +356,9 @@ def get_scaled_node_size(nodes_amount): continue # The node is adjacent to the edge. p_3 = np.array(pos[node_id]) distance_to_node = abs(np.cross(p_1_2, p_3 - p_1)) / p_1_2_length - if (distance_to_node > min(node_distance_gap, min_distance_found) or # The node is too far. - ((p_3 - p_1) @ p_1_2) < 0 or # There's no perpendicular from the node to the edge. - ((p_3 - p_2) @ -p_1_2) < 0): + if (distance_to_node > min(node_distance_gap, min_distance_found) # The node is too far. + or ((p_3 - p_1) @ p_1_2) < 0 # There's no perpendicular from the node to the edge. + or ((p_3 - p_2) @ -p_1_2) < 0): continue min_distance_found = distance_to_node closest_node_id = node_id @@ -281,148 +376,138 @@ def get_scaled_node_size(nodes_amount): # Then, its ordinate shows on which side of the edge it is, "on the left" or "on the right". rotation_matrix = np.array([[cos_alpha, sin_alpha], [-sin_alpha, cos_alpha]]) p_1_3_rotated = rotation_matrix @ p_1_3 - curvature_direction = (-1) ** (p_1_3_rotated[1] < 0) # +1 is a "boat" \/, -1 is a "cat" /\. + curvature_direction = (-1) ** (p_1_3_rotated[1] < 0) # +1 is a "cup" \/, -1 is a "cat" /\. edge_curvature = curvature_direction * curvature_strength e['connectionstyle'] = connection_style_curved_template.format(edge_curvature) + # Define edge center position for labels. + edge_center_position = np.mean([p_1, p_2], axis=0) + edge_curvature_shift = np.linalg.inv(rotation_matrix) @ [0, -1 * edge_curvature / 4] + edge_center_position += edge_curvature_shift + e['edge_center_position'] = edge_center_position # Draw the graph's edges. arrow_style = ArrowStyle('Simple', head_length=1.5, head_width=0.8) for u, v, e in nx_graph.edges(data=True): nx.draw_networkx_edges(nx_graph, pos, edgelist=[(u, v)], node_size=node_size, ax=ax, arrowsize=10, arrowstyle=arrow_style, connectionstyle=e['connectionstyle']) - if nodes_labels or edges_labels: - self._set_labels(ax, pos, nx_graph, - longest_sequence, longest_sequence, font_size_scale, - nodes_labels, edges_labels) - # Rescale the figure for all nodes to fit in. - x_1, x_2 = ax.get_xlim() - y_1, y_2 = ax.get_ylim() - offset = 0.2 - x_offset = x_2 * offset - y_offset = y_2 * offset - ax.set_xlim(x_1 - x_offset, x_2 + x_offset) - ax.set_ylim(y_1 - y_offset, y_2 + y_offset) - ax.axis('off') - plt.tight_layout() + self._rescale_matplotlib_figure(ax) - def get_predefined_value(self, param: str): - return self.visuals_params.get(param) - - def _set_labels(self, ax: plt.Axes, pos: Any, nx_graph: nx.DiGraph, - longest_sequence: int, longest_y_sequence: int, font_size_scale: float, - nodes_labels: Dict[int, str], edges_labels: Dict[int, str]): + def _draw_nx_labels(self, ax: plt.Axes, pos: Any, font_size_scale: float, + nodes_labels: Dict[int, str], edges_labels: Dict[int, str]): """ Set labels with scores to nodes and edges. """ - def calculate_labels_bias(ax: plt.Axes, longest_y_sequence: int): + def calculate_labels_bias(ax: plt.Axes, y_span: int): y_1, y_2 = ax.get_ylim() y_size = y_2 - y_1 - if longest_y_sequence == 1: + if y_span == 1: bias_scale = 0.25 # Fits between the central line and the upper bound. else: - bias_scale = 1 / longest_y_sequence / 3 * 0.9 # Fits between the narrowest horizontal rows. + bias_scale = 1 / y_span / 3 * 0.5 # Fits between the narrowest horizontal rows. bias = y_size * bias_scale return bias - def _get_scaled_font_size(nodes_amount: int, size_scale: float) -> float: - min_size = 11 - max_size = 25 - size = max(max_size * (1 - np.log10(nodes_amount)), min_size) - return size * size_scale - def match_labels_with_nx_nodes(nx_graph: nx.DiGraph, labels: Dict[int, str]) -> Dict[str, str]: """ Matches index of node in GOLEM graph with networkx node name. """ nx_nodes = list(nx_graph.nodes.keys()) nx_labels = {} - for index in labels: - nx_labels[nx_nodes[index]] = labels[index] + for index, label in labels.items(): + nx_labels[nx_nodes[index]] = label return nx_labels def match_labels_with_nx_edges(nx_graph: nx.DiGraph, labels: Dict[int, str]) \ - -> Dict[Tuple[str, str], List[str]]: + -> Dict[Tuple[str, str], str]: """ Matches index of edge in GOLEM graph with tuple of networkx nodes names. """ nx_nodes = list(nx_graph.nodes.keys()) edges = self.graph.get_edges() nx_labels = {} - for index in labels: + for index, label in labels.items(): edge = edges[index] parent_node_nx = nx_nodes[self.graph.nodes.index(edge[0])] child_node_nx = nx_nodes[self.graph.nodes.index(edge[1])] - nx_labels[(parent_node_nx, child_node_nx)] = labels[index] + nx_labels[(parent_node_nx, child_node_nx)] = label return nx_labels - if not edges_labels and not nodes_labels: - return - - bias = calculate_labels_bias(ax, longest_y_sequence) - if nodes_labels: - # Set labels for nodes + def draw_node_labels(node_labels, ax, bias, font_size, nx_graph, pos): labels_pos = deepcopy(pos) - font_size = _get_scaled_font_size(longest_sequence, font_size_scale * 0.7) - bbox = dict(alpha=0.9, color='w') for value in labels_pos.values(): value[1] += bias + bbox = dict(alpha=0.9, color='w') - nodes_nx_labels = match_labels_with_nx_nodes(nx_graph=nx_graph, labels=nodes_labels) + nodes_nx_labels = match_labels_with_nx_nodes(nx_graph=nx_graph, labels=node_labels) nx.draw_networkx_labels( nx_graph, labels_pos, + ax=ax, labels=nodes_nx_labels, font_color='black', font_size=font_size, bbox=bbox ) - if not edges_labels: + def draw_edge_labels(edge_labels, ax, bias, font_size, nx_graph, pos): + labels_pos_edges = deepcopy(pos) + label_bias_y = 2 / 3 * bias + if len(set([coord[1] for coord in pos.values()])) == 1 and len(list(pos.values())) > 2: + for value in labels_pos_edges.values(): + value[1] += label_bias_y + edges_nx_labels = match_labels_with_nx_edges(nx_graph=nx_graph, labels=edge_labels) + bbox = dict(alpha=0.9, color='w') + # Set labels for edges + for u, v, e in nx_graph.edges(data=True): + if (u, v) not in edges_nx_labels: + continue + current_pos = labels_pos_edges + if 'edge_center_position' in e: + x, y = e['edge_center_position'] + plt.text(x, y, edges_nx_labels[(u, v)], bbox=bbox, fontsize=font_size) + else: + nx.draw_networkx_edge_labels( + nx_graph, current_pos, {(u, v): edges_nx_labels[(u, v)]}, + label_pos=0.5, ax=ax, + font_color='black', + font_size=font_size, + rotate=False, + bbox=bbox + ) + + if not (edges_labels or nodes_labels): return - labels_pos_edges = deepcopy(pos) - label_bias_y = 2 / 3 * bias - if len(set([coord[1] for coord in pos.values()])) == 1 and len(list(pos.values())) > 2: - for value in labels_pos_edges.values(): - value[1] += label_bias_y - - edges_nx_labels = match_labels_with_nx_edges(nx_graph=nx_graph, labels=edges_labels) - # Set labels for edges - for u, v, e in nx_graph.edges(data=True): - if (u, v) not in edges_nx_labels: - continue - current_pos = labels_pos_edges - if 'edge_center_position' in e: - x, y = e['edge_center_position'] - plt.text(x, y, edges_nx_labels[(u, v)], bbox=bbox, fontsize=font_size) - else: - nx.draw_networkx_edge_labels( - nx_graph, current_pos, {(u, v): edges_nx_labels[(u, v)]}, - label_pos=0.5, ax=ax, - font_color='black', - font_size=font_size, - rotate=False, - bbox=bbox - ) - - -def get_hierarchy_pos(graph: nx.DiGraph, max_line_length: int = 6) -> Tuple[Dict[Any, Tuple[float, float]], int]: - """By default, returns 'networkx.multipartite_layout' positions based on 'hierarchy_level` from node data - the - property must be set beforehand. - If line of nodes reaches 'max_line_length', the result is the combination of 'networkx.multipartite_layout' and - 'networkx.spring_layout'. - :param graph: the graph. - :param max_line_length: the limit for common nodes horizontal or vertical line. - """ - longest_path = nx.dag_longest_path(graph, weight=None) - longest_sequence = len(longest_path) - - pos = nx.multipartite_layout(graph, subset_key='hierarchy_level') - - y_level_nodes_count = {} - for x, _ in pos.values(): - y_level_nodes_count[x] = y_level_nodes_count.get(x, 0) + 1 - nodes_on_level = y_level_nodes_count[x] - if nodes_on_level > longest_sequence: - longest_sequence = nodes_on_level - - if longest_sequence > max_line_length: - pos = {n: np.array(x_y) + (np.random.random(2) - 0.5) * 0.001 for n, x_y in pos.items()} - pos = nx.spring_layout(graph, k=2, iterations=5, pos=pos, seed=42) - - return pos, longest_sequence + + nodes_amount = len(pos) + font_size = GraphVisualizer._get_scaled_font_size(nodes_amount, font_size_scale * 0.75) + _, y_span = GraphVisualizer._get_x_y_span(pos) + bias = calculate_labels_bias(ax, y_span) + + if nodes_labels: + draw_node_labels(nodes_labels, ax, bias, font_size, self.nx_graph, pos) + + if edges_labels: + draw_edge_labels(edges_labels, ax, bias, font_size, self.nx_graph, pos) + + @staticmethod + def _get_hierarchy_pos_by_distance_to_primary_level(nx_graph: nx.DiGraph, nodes: Dict + ) -> Dict[Any, Tuple[float, float]]: + """By default, returns 'networkx.multipartite_layout' positions based on 'hierarchy_level` + from node data - the property must be set beforehand. + :param graph: the graph. + """ + for node_id, node_data in nx_graph.nodes(data=True): + node_data['hierarchy_level'] = distance_to_primary_level(nodes[node_id]) + + return nx.multipartite_layout(nx_graph, subset_key='hierarchy_level') + + @staticmethod + def _get_x_y_span(pos: Dict[Any, Tuple[float, float]]) -> Tuple[int, int]: + pos_x, pos_y = np.split(np.array(tuple(pos.values())), 2, axis=1) + x_span = max(pos_x) - min(pos_x) + y_span = max(pos_y) - min(pos_y) + return x_span, y_span + + @staticmethod + def _define_node_names_placement(node_size): + if node_size >= 1000: + node_names_placement = 'nodes' + else: + node_names_placement = 'legend' + return node_names_placement def remove_old_files_from_dir(dir_: Path, time_interval=datetime.timedelta(minutes=10)): diff --git a/test/unit/visualisation/test_visualisation_utils.py b/test/unit/visualisation/test_visualisation_utils.py index cd470c696..cbd875a29 100644 --- a/test/unit/visualisation/test_visualisation_utils.py +++ b/test/unit/visualisation/test_visualisation_utils.py @@ -1,9 +1,8 @@ from golem.core.adapter import DirectAdapter from golem.core.dag.convert import graph_structure_as_nx_graph -from golem.core.dag.graph_utils import distance_to_primary_level from golem.core.optimisers.fitness.multi_objective_fitness import MultiObjFitness from golem.core.optimisers.opt_history_objects.individual import Individual -from golem.visualisation.graph_viz import get_hierarchy_pos +from golem.visualisation.graph_viz import GraphVisualizer from golem.visualisation.opt_viz_extra import extract_objectives from test.unit.utils import graph_first @@ -41,12 +40,10 @@ def test_hierarchy_pos(): 1: ['a', 'b'], 2: ['a']} - graph, node_labels = graph_structure_as_nx_graph(graph) - for n, data in graph.nodes(data=True): - data['hierarchy_level'] = distance_to_primary_level(node_labels[n]) - node_labels[n] = str(node_labels[n]) + nx_graph, nodes_dict = graph_structure_as_nx_graph(graph) + node_labels = {uid: str(node) for uid, node in nodes_dict.items()} - pos, _ = get_hierarchy_pos(graph) + pos = GraphVisualizer._get_hierarchy_pos_by_distance_to_primary_level(nx_graph, nodes_dict) comparable_lists_y = make_comparable_lists(pos, real_hierarchy_levels_y, node_labels, 1, reverse=True) comparable_lists_x = make_comparable_lists(pos, real_hierarchy_levels_x,