Skip to content

Commit

Permalink
refact(plot): separate graphviz building from IO
Browse files Browse the repository at this point in the history
  • Loading branch information
ankostis committed Oct 5, 2019
1 parent b08a363 commit 3a87959
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 87 deletions.
2 changes: 1 addition & 1 deletion graphkit/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def plot(self, filename=None, show=False, jupyter=None,
:param str filename:
Write diagram into a file.
Common extensions are ``.png .dot .jpg .jpeg .pdf .svg``
call :func:`network.supported_plot_formats()` for more.
call :func:`plot.supported_plot_formats()` for more.
:param show:
If it evaluates to true, opens the diagram in a matplotlib window.
If it equals `-1`, it plots but does not open the Window.
Expand Down
2 changes: 1 addition & 1 deletion graphkit/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def plot(self, filename=None, show=False, jupyter=None,
:param str filename:
Write diagram into a file.
Common extensions are ``.png .dot .jpg .jpeg .pdf .svg``
call :func:`network.supported_plot_formats()` for more.
call :func:`plot.supported_plot_formats()` for more.
:param show:
If it evaluates to true, opens the diagram in a matplotlib window.
If it equals `-1``, it plots but does not open the Window.
Expand Down
173 changes: 90 additions & 83 deletions graphkit/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,90 @@ def supported_plot_formats():
return [".%s" % f for f in pydot.Dot().formats]


def build_pydot(graph, steps=None, inputs=None, outputs=None, solution=None):
""" Build a Graphviz graph """
import pydot

assert graph is not None

def get_node_name(a):
if isinstance(a, Operation):
return a.name
return a

dot = pydot.Dot(graph_type="digraph")

# draw nodes
for nx_node in graph.nodes:
kw = {}
if isinstance(nx_node, str):
# Only DeleteInstructions data in steps.
if nx_node in steps:
kw = {"color": "red", "penwidth": 2}

# SHAPE change if in inputs/outputs.
# tip: https://graphviz.gitlab.io/_pages/doc/info/shapes.html
shape = "rect"
if inputs and outputs and nx_node in inputs and nx_node in outputs:
shape = "hexagon"
else:
if inputs and nx_node in inputs:
shape = "invhouse"
if outputs and nx_node in outputs:
shape = "house"

# LABEL change from solution.
if solution and nx_node in solution:
kw["style"] = "filled"
kw["fillcolor"] = "gray"
# kw["tooltip"] = nx_node, solution.get(nx_node)
node = pydot.Node(name=nx_node, shape=shape, URL="fdgfdf", **kw)
else: # Operation
kw = {}
shape = "oval" if isinstance(nx_node, NetworkOperation) else "circle"
if nx_node in steps:
kw["style"] = "bold"
node = pydot.Node(name=nx_node.name, shape=shape, **kw)

dot.add_node(node)

# draw edges
for src, dst in graph.edges:
src_name = get_node_name(src)
dst_name = get_node_name(dst)
kw = {}
if isinstance(dst, Operation) and any(
n == src and isinstance(n, optional) for n in dst.needs
):
kw["style"] = "dashed"
edge = pydot.Edge(src=src_name, dst=dst_name, **kw)
dot.add_edge(edge)

# draw steps sequence
if steps and len(steps) > 1:
it1 = iter(steps)
it2 = iter(steps)
next(it2)
for i, (src, dst) in enumerate(zip(it1, it2), 1):
src_name = get_node_name(src)
dst_name = get_node_name(dst)
edge = pydot.Edge(
src=src_name,
dst=dst_name,
label=str(i),
style="dotted",
color="green",
fontcolor="green",
fontname="bold",
fontsize=18,
penwidth=3,
arrowhead="vee",
)
dot.add_edge(edge)

return dot


def plot_graph(
graph,
filename=None,
Expand Down Expand Up @@ -55,7 +139,7 @@ def plot_graph(
:param str filename:
Write diagram into a file.
Common extensions are ``.png .dot .jpg .jpeg .pdf .svg``
call :func:`network.supported_plot_formats()` for more.
call :func:`plot.supported_plot_formats()` for more.
:param show:
If it evaluates to true, opens the diagram in a matplotlib window.
If it equals `-1``, it plots but does not open the Window.
Expand Down Expand Up @@ -93,84 +177,7 @@ def plot_graph(
>>> pipeline.plot('plot.svg', inputs=inputs, solution=solution, outputs=['asked', 'b1']);
"""
import pydot

assert graph is not None

def get_node_name(a):
if isinstance(a, Operation):
return a.name
return a

g = pydot.Dot(graph_type="digraph")

# draw nodes
for nx_node in graph.nodes:
kw = {}
if isinstance(nx_node, str):
# Only DeleteInstructions data in steps.
if nx_node in steps:
kw = {"color": "red", "penwidth": 2}

# SHAPE change if in inputs/outputs.
# tip: https://graphviz.gitlab.io/_pages/doc/info/shapes.html
shape = "rect"
if inputs and outputs and nx_node in inputs and nx_node in outputs:
shape = "hexagon"
else:
if inputs and nx_node in inputs:
shape = "invhouse"
if outputs and nx_node in outputs:
shape = "house"

# LABEL change from solution.
if solution and nx_node in solution:
kw["style"] = "filled"
kw["fillcolor"] = "gray"
# kw["tooltip"] = nx_node, solution.get(nx_node)
node = pydot.Node(name=nx_node, shape=shape, URL="fdgfdf", **kw)
else: # Operation
kw = {}
shape = "oval" if isinstance(nx_node, NetworkOperation) else "circle"
if nx_node in steps:
kw["style"] = "bold"
node = pydot.Node(name=nx_node.name, shape=shape, **kw)

g.add_node(node)

# draw edges
for src, dst in graph.edges:
src_name = get_node_name(src)
dst_name = get_node_name(dst)
kw = {}
if isinstance(dst, Operation) and any(
n == src and isinstance(n, optional) for n in dst.needs
):
kw["style"] = "dashed"
edge = pydot.Edge(src=src_name, dst=dst_name, **kw)
g.add_edge(edge)

# draw steps sequence
if steps and len(steps) > 1:
it1 = iter(steps)
it2 = iter(steps)
next(it2)
for i, (src, dst) in enumerate(zip(it1, it2), 1):
src_name = get_node_name(src)
dst_name = get_node_name(dst)
edge = pydot.Edge(
src=src_name,
dst=dst_name,
label=str(i),
style="dotted",
color="green",
fontcolor="green",
fontname="bold",
fontsize=18,
penwidth=3,
arrowhead="vee",
)
g.add_edge(edge)
dot = build_pydot(graph, steps, inputs, outputs, solution)

# Save plot
#
Expand All @@ -183,26 +190,26 @@ def get_node_name(a):
" File extensions must be one of: %s" % (ext, " ".join(formats))
)

g.write(filename, format=ext.lower()[1:])
dot.write(filename, format=ext.lower()[1:])

## Return an SVG renderable in jupyter.
#
if jupyter:
from IPython.display import SVG

g = SVG(data=g.create_svg())
dot = SVG(data=dot.create_svg())

## Display graph via matplotlib
#
if show:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

png = g.create_png()
png = dot.create_png()
sio = io.BytesIO(png)
img = mpimg.imread(sio)
plt.imshow(img, aspect="equal")
if show != -1:
plt.show()

return g
return dot
4 changes: 2 additions & 2 deletions test/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_plot_formats(pipeline, input_names, outputs, solution, tmp_path):
# ...these are not working on my PC, or travis.
forbidden_formats = ".dia .hpgl .mif .mp .pcl .pic .vtx .xlib".split()
prev_dot = None
for ext in network.supported_plot_formats():
for ext in plot.supported_plot_formats():
if ext not in forbidden_formats:
dot = pipeline.plot(inputs=input_names, outputs=outputs, solution=solution)
assert dot
Expand All @@ -72,7 +72,7 @@ def test_plot_bad_format(pipeline, tmp_path):
pipeline.plot(filename="bad.format")

## Check help msg lists all siupported formats
for ext in network.supported_plot_formats():
for ext in plot.supported_plot_formats():
assert exinfo.match(ext)


Expand Down

0 comments on commit 3a87959

Please sign in to comment.