Skip to content

Commit

Permalink
CVT Archive heatmap flexibility
Browse files Browse the repository at this point in the history
  • Loading branch information
btjanaka committed Sep 6, 2023
1 parent c286612 commit f3a0384
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 19 deletions.
59 changes: 40 additions & 19 deletions ribs/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import axes
from matplotlib.cm import ScalarMappable
from scipy.spatial import Voronoi # pylint: disable=no-name-in-module

Expand Down Expand Up @@ -69,7 +68,8 @@ def _validate_heatmap_visual_args(aspect, cbar, measure_dim, valid_dims,
f"Invalid arg aspect='{aspect}'; must be 'auto', 'equal', or float")
if measure_dim not in valid_dims:
raise ValueError(error_msg_measure_dim)
if not (cbar == "auto" or isinstance(cbar, axes.Axes) or cbar is None):
if not (cbar == "auto" or isinstance(cbar, matplotlib.axes.Axes) or
cbar is None):
raise ValueError(f"Invalid arg cbar={cbar}; must be 'auto', None, "
"or matplotlib.axes.Axes")

Expand All @@ -79,7 +79,7 @@ def _set_cbar(t, ax, cbar, cbar_kwargs):
cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs
if cbar == "auto":
ax.figure.colorbar(t, ax=ax, **cbar_kwargs)
elif isinstance(cbar, axes.Axes):
elif isinstance(cbar, matplotlib.axes.Axes):
cbar.figure.colorbar(t, ax=cbar, **cbar_kwargs)


Expand Down Expand Up @@ -261,17 +261,18 @@ def grid_archive_heatmap(archive,
def cvt_archive_heatmap(archive,
ax=None,
*,
plot_centroids=False,
plot_samples=False,
transpose_measures=False,
cmap="magma",
aspect="auto",
ms=1,
lw=0.5,
ec="black",
vmin=None,
vmax=None,
cbar="auto",
cbar_kwargs=None):
cbar_kwargs=None,
plot_centroids=False,
plot_samples=False,
ms=1):
"""Plots heatmap of a :class:`~ribs.archives.CVTArchive` with 2D measure
space.
Expand Down Expand Up @@ -314,9 +315,6 @@ def cvt_archive_heatmap(archive,
archive (CVTArchive): A 2D :class:`~ribs.archives.CVTArchive`.
ax (matplotlib.axes.Axes): Axes on which to plot the heatmap.
If ``None``, the current axis will be used.
plot_centroids (bool): Whether to plot the cluster centroids.
plot_samples (bool): Whether to plot the samples used when generating
the clusters.
transpose_measures (bool): By default, the first measure in the archive
will appear along the x-axis, and the second will be along the
y-axis. To switch this behavior (i.e. to transpose the axes), set
Expand All @@ -329,8 +327,11 @@ def cvt_archive_heatmap(archive,
aspect ('auto', 'equal', float): The aspect ratio of the heatmap (i.e.
height/width). Defaults to ``'auto'``. ``'equal'`` is the same as
``aspect=1``.
ms (float): Marker size for both centroids and samples.
lw (float): Line width when plotting the voronoi diagram.
lw (float): Line width when plotting the Voronoi diagram.
ec (matplotlib color): Edge color of the cells in the Voronoi diagram.
See `here
<https://matplotlib.org/stable/tutorials/colors/colors.html>`_ for
more info on specifying colors in Matplotlib.
vmin (float): Minimum objective value to use in the plot. If ``None``,
the minimum objective value in the archive is used.
vmax (float): Maximum objective value to use in the plot. If ``None``,
Expand All @@ -342,14 +343,24 @@ def cvt_archive_heatmap(archive,
the colorbar on the specified Axes.
cbar_kwargs (dict): Additional kwargs to pass to
:func:`~matplotlib.pyplot.colorbar`.
plot_centroids (bool): Whether to plot the cluster centroids.
plot_samples (bool): Whether to plot the samples used when generating
the clusters.
ms (float): Marker size for both centroids and samples.
Raises:
ValueError: The archive is not 2D.
ValueError: ``plot_samples`` is passed in but the archive does not have
samples (e.g., due to using custom centroids during construction).
"""
_validate_heatmap_visual_args(
aspect, cbar, archive.measure_dim, [2],
"Heatmaps can only be plotted for 2D CVTArchive")

if plot_samples and archive.samples is None:
raise ValueError("Samples are not available for this archive, but "
"`plot_samples` was passed in.")

if aspect is None:
aspect = "auto"

Expand All @@ -360,12 +371,15 @@ def cvt_archive_heatmap(archive,
lower_bounds = archive.lower_bounds
upper_bounds = archive.upper_bounds
centroids = archive.centroids
samples = archive.samples
if transpose_measures:
lower_bounds = np.flip(lower_bounds)
upper_bounds = np.flip(upper_bounds)
centroids = np.flip(centroids, axis=1)
samples = np.flip(samples, axis=1)

if plot_samples:
samples = archive.samples
if transpose_measures:
samples = np.flip(samples, axis=1)

# Retrieve and initialize the axis.
ax = plt.gca() if ax is None else ax
Expand Down Expand Up @@ -404,6 +418,10 @@ def cvt_archive_heatmap(archive,
min_obj = min_obj if vmin is None else vmin
max_obj = max_obj if vmax is None else vmax

# If the min and max are the same, we set a sensible default range.
if min_obj == max_obj:
min_obj, max_obj = min_obj - 0.01, max_obj + 0.01

# Shade the regions.
#
# Note: by default, the first region will be an empty list -- see:
Expand All @@ -415,23 +433,26 @@ def cvt_archive_heatmap(archive,
# `polygon` is also O(n) anyway.
if -1 not in region:
if objective is None:
color = "white"
# Transparent white (RGBA format) -- this ensures that if a
# figure is saved with a transparent background, the empty cells
# will also be transparent.
color = (1.0, 1.0, 1.0, 0.0)
else:
normalized_obj = np.clip(
(objective - min_obj) / (max_obj - min_obj), 0.0, 1.0)
color = cmap(normalized_obj)
polygon = [vor.vertices[i] for i in region]
ax.fill(*zip(*polygon), color=color, ec="k", lw=lw)
polygon = vor.vertices[region]
ax.fill(*zip(*polygon), color=color, ec=ec, lw=lw)

# Create a colorbar.
mappable = ScalarMappable(cmap=cmap)
mappable.set_clim(min_obj, max_obj)

# Plot the sample points and centroids.
if plot_samples:
ax.plot(samples[:, 0], samples[:, 1], "o", c="gray", ms=ms)
ax.plot(samples[:, 0], samples[:, 1], "o", c="grey", ms=ms)
if plot_centroids:
ax.plot(centroids[:, 0], centroids[:, 1], "ko", ms=ms)
ax.plot(centroids[:, 0], centroids[:, 1], "o", c="black", ms=ms)

# Create color bar.
_set_cbar(mappable, ax, cbar, cbar_kwargs)
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
28 changes: 28 additions & 0 deletions tests/visualize/visualize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,14 @@ def test_sliding_archive_mismatch_xy_with_boundaries():
sliding_boundaries_archive_heatmap(archive, boundary_lw=0.5)


@image_comparison(baseline_images=["cvt_archive_heatmap_vmin_equals_vmax"],
remove_text=False,
extensions=["png"])
def test_cvt_archive_heatmap_vmin_equals_vmax(cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive, vmin=-0.5, vmax=-0.5)


@image_comparison(baseline_images=["cvt_archive_heatmap_with_centroids"],
remove_text=False,
extensions=["png"])
Expand All @@ -628,6 +636,26 @@ def test_cvt_archive_heatmap_with_samples(cvt_archive):
cvt_archive_heatmap(cvt_archive, plot_samples=True)


def test_cvt_archive_heatmap_no_samples_error():
# This archive has no samples since custom centroids were passed in.
archive = CVTArchive(solution_dim=2,
cells=2,
ranges=[(-1, 1), (-1, 1)],
custom_centroids=[[0, 0], [1, 1]])

# Thus, plotting samples on this archive should fail.
with pytest.raises(ValueError):
cvt_archive_heatmap(archive, lw=3.0, ec="grey", plot_samples=True)


@image_comparison(baseline_images=["cvt_archive_heatmap_voronoi_style"],
remove_text=False,
extensions=["png"])
def test_cvt_archive_heatmap_voronoi_style(cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive, lw=3.0, ec="grey")


#
# Parallel coordinate plot test
#
Expand Down

0 comments on commit f3a0384

Please sign in to comment.