Skip to content

Commit

Permalink
FEAT(plot): overlay Execution STEPS on diagrams
Browse files Browse the repository at this point in the history
  • Loading branch information
ankostis committed Oct 5, 2019
1 parent fef1a2a commit 54ec53a
Showing 1 changed file with 28 additions and 10 deletions.
38 changes: 28 additions & 10 deletions graphkit/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import time
import networkx as nx

from io import StringIO

from .base import Operation

Expand Down Expand Up @@ -394,7 +393,7 @@ def plot(self, filename=None, show=False):
Supported arguments: filename, show
See :func:`network.plot_graph()`
"""
return plot_graph(self.graph, filename=filename, show=show)
return plot_graph(self.graph, filename, show, self.steps)


def ready_to_schedule_operation(op, has_executed, graph):
Expand Down Expand Up @@ -445,9 +444,9 @@ def get_data_node(name, graph):
return None


def plot_graph(graph, filename=None, show=False):
def plot_graph(graph, filename=None, show=False, steps=None):
"""
Plot a *Graphviz* graph and return it, if no other argument provided.
Plot a *Graphviz* graph/steps and return it, if no other argument provided.
:param graph:
what to plot
Expand All @@ -457,6 +456,8 @@ def plot_graph(graph, filename=None, show=False):
:param boolean show:
If this is set to True, use matplotlib to show the graph diagram
(Default: False)
:param steps:
a list of nodes & instructions to overlay on the diagram
:returns:
An instance of the pydot graph
Expand All @@ -469,18 +470,23 @@ def plot_graph(graph, filename=None, show=False):
assert graph is not None

def get_node_name(a):
if isinstance(a, DataPlaceholderNode):
return a
return a.name
if isinstance(a, Operation):
return a.name
return a

g = pydot.Dot(graph_type="digraph")

# draw nodes
for nx_node in graph.nodes():
for nx_node in graph.nodes:
kw = {}
if isinstance(nx_node, DataPlaceholderNode):
node = pydot.Node(name=nx_node, shape="rect")
if nx_node in steps:
kw = {'color': 'red', 'style': 'bold'}
node = pydot.Node(name=nx_node, shape="rect", **kw)
else:
node = pydot.Node(name=nx_node.name, shape="circle")
if nx_node in steps:
kw = {'style': 'bold'}
node = pydot.Node(name=nx_node.name, shape="circle", **kw)
g.add_node(node)

# draw edges
Expand All @@ -490,6 +496,18 @@ def get_node_name(a):
edge = pydot.Edge(src=src_name, dst=dst_name)
g.add_edge(edge)

# draw steps sequence
if steps and len(steps) > 1:
it1 = iter(steps)
it2 = iter(steps); next(it2)
for i, (src, dst) in enumerate(zip(it1, it2), 1):
src_name = get_node_name(src)
dst_name = get_node_name(dst)
edge = pydot.Edge(
src=src_name, dst=dst_name, label=str(i), style="dotted",
penwidth='2')
g.add_edge(edge)

# save plot
if filename:
_basename, ext = os.path.splitext(filename)
Expand Down

0 comments on commit 54ec53a

Please sign in to comment.