Skip to content

Commit

Permalink
plot_ppc related fixes (#2283)
Browse files Browse the repository at this point in the history
* modify plot_ppc default labeling

* several plot_ppc fixes

* Multiple docstring fixes and improvements
* Avoid adding empty legends, don't add anything instead
* Ensure plot behaves correctly even if chain and draw aren't the first
  dimensions positionally
  - Add test to check behaviour and avoid regression

* black and changelog
  • Loading branch information
OriolAbril authored Oct 30, 2023
1 parent 4557004 commit c2d968f
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 44 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
### New features

### Maintenance and fixes
- Update requirements: matplotlib>=3.5, pandas>=1.4.0, numpy>=1.22.0
- Update requirements: matplotlib>=3.5, pandas>=1.4.0, numpy>=1.22.0 ([2280](https://github.com/arviz-devs/arviz/pull/2280))
- Fix behaviour of `plot_ppc` when dimension order isn't `chain, draw, ...` ([2283](https://github.com/arviz-devs/arviz/pull/2283))
- Avoid repeating the variable name in `plot_ppc`, `plot_bpv`, `plot_loo_pit`... when repeated. ([2283](https://github.com/arviz-devs/arviz/pull/2283))

### Deprecation

### Documentation
- Several fixes in `plot_ppc` docstring ([2283](https://github.com/arviz-devs/arviz/pull/2283))

## v0.16.1 (2023 Jul 18)

Expand Down
2 changes: 2 additions & 0 deletions arviz/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def var_pp_to_str(self, var_name, pp_var_name):
"""WIP."""
var_name_str = self.var_name_to_str(var_name)
pp_var_name_str = self.var_name_to_str(pp_var_name)
if var_name_str == pp_var_name_str:
return f"{var_name_str}"
return f"{var_name_str} / {pp_var_name_str}"

def model_name_to_str(self, model_name):
Expand Down
2 changes: 0 additions & 2 deletions arviz/plots/backends/matplotlib/ppcplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,6 @@ def plot_ppc(
if legend:
if i == 0:
ax_i.legend(fontsize=xt_labelsize * 0.75)
else:
ax_i.legend([])

if backend_show(show):
plt.show()
Expand Down
83 changes: 42 additions & 41 deletions arviz/plots/ppcplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,37 +50,34 @@ def plot_ppc(
Parameters
----------
data: az.InferenceData object
data : InferenceData
:class:`arviz.InferenceData` object containing the observed and posterior/prior
predictive data.
kind: str
Type of plot to display ("kde", "cumulative", or "scatter"). Defaults to `kde`.
alpha: float
kind : str, default "kde"
Type of plot to display ("kde", "cumulative", or "scatter").
alpha : float, optional
Opacity of posterior/prior predictive density curves.
Defaults to 0.2 for ``kind = kde`` and cumulative, for scatter defaults to 0.7.
mean: bool
mean : bool, default True
Whether or not to plot the mean posterior/prior predictive distribution.
Defaults to ``True``.
observed: bool, default True
observed : bool, default True
Whether or not to plot the observed data.
observed_rug: bool, default False
observed_rug : bool, default False
Whether or not to plot a rug plot for the observed data. Only valid if `observed` is
`True` and for kind `kde` or `cumulative`.
color: str
Valid matplotlib ``color``. Defaults to ``C0``.
color: list
color : list, optional
List with valid matplotlib colors corresponding to the posterior/prior predictive
distribution, observed data and mean of the posterior/prior predictive distribution.
Defaults to ["C0", "k", "C1"].
grid : tuple
grid : tuple, optional
Number of rows and columns. Defaults to None, the rows and columns are
automatically inferred.
figsize: tuple
figsize : tuple, optional
Figure size. If None, it will be defined automatically.
textsize: float
textsize : float, optional
Text size scaling factor for labels, titles and lines. If None, it will be
autoscaled based on ``figsize``.
data_pairs: dict
data_pairs : dict, optional
Dictionary containing relations between observed data and posterior/prior predictive data.
Dictionary structure:
Expand All @@ -90,84 +87,86 @@ def plot_ppc(
For example, ``data_pairs = {'y' : 'y_hat'}``
If None, it will assume that the observed data and the posterior/prior
predictive data have the same variable name.
var_names: list of variable names
var_names : list of str, optional
Variables to be plotted, if `None` all variable are plotted. Prefix the
variables by ``~`` when you want to exclude them from the plot.
filter_vars: {None, "like", "regex"}, optional, default=None
filter_vars : {None, "like", "regex"}, default None
If `None` (default), interpret var_names as the real variables names. If "like",
interpret var_names as substrings of the real variables names. If "regex",
interpret var_names as regular expressions on the real variables names. A la
``pandas.filter``.
coords: dict
coords : dict, optional
Dictionary mapping dimensions to selected coordinates to be plotted.
Dimensions without a mapping specified will include all coordinates for
that dimension. Defaults to including all coordinates for all
dimensions if None.
flatten: list
flatten : list
List of dimensions to flatten in ``observed_data``. Only flattens across the coordinates
specified in the ``coords`` argument. Defaults to flattening all of the dimensions.
flatten_pp: list
flatten_pp : list
List of dimensions to flatten in posterior_predictive/prior_predictive. Only flattens
across the coordinates specified in the ``coords`` argument. Defaults to flattening all
of the dimensions. Dimensions should match flatten excluding dimensions for ``data_pairs``
parameters. If ``flatten`` is defined and ``flatten_pp`` is None, then
``flatten_pp = flatten``.
num_pp_samples: int
num_pp_samples : int
The number of posterior/prior predictive samples to plot. For ``kind`` = 'scatter' and
``animation = False`` if defaults to a maximum of 5 samples and will set jitter to 0.7.
unless defined. Otherwise it defaults to all provided samples.
random_seed: int
random_seed : int
Random number generator seed passed to ``numpy.random.seed`` to allow
reproducibility of the plot. By default, no seed will be provided
and the plot will change each call if a random sample is specified
by ``num_pp_samples``.
jitter: float
jitter : float, default 0
If ``kind`` is "scatter", jitter will add random uniform noise to the height
of the ppc samples and observed data. By default 0.
animated: bool
of the ppc samples and observed data.
animated : bool, default False
Create an animation of one posterior/prior predictive sample per frame.
Defaults to ``False``. Only works with matploblib backend.
Only works with matploblib backend.
To run animations inside a notebook you have to use the `nbAgg` matplotlib's backend.
Try with `%matplotlib notebook` or `%matplotlib nbAgg`. You can switch back to the
default matplotlib's backend with `%matplotlib inline` or `%matplotlib auto`.
If switching back and forth between matplotlib's backend, you may need to run twice the cell
with the animation.
If you experience problems rendering the animation try setting
`animation_kwargs({'blit':False}`) or changing the matplotlib's backend (e.g. to TkAgg)
If you run the animation from a script write `ax, ani = az.plot_ppc(.)`
``animation_kwargs({'blit':False})`` or changing the matplotlib's backend (e.g. to TkAgg)
If you run the animation from a script write ``ax, ani = az.plot_ppc(.)``
animation_kwargs : dict
Keywords passed to :class:`matplotlib.animation.FuncAnimation`. Ignored with
matplotlib backend.
legend : bool
Add legend to figure. By default ``True``.
labeller : labeller instance, optional
legend : bool, default True
Add legend to figure.
labeller : labeller, optional
Class providing the method ``make_pp_label`` to generate the labels in the plot titles.
Read the :ref:`label_guide` for more details and usage examples.
ax: numpy array-like of matplotlib axes or bokeh figures, optional
ax : numpy array-like of matplotlib_axes or bokeh figures, optional
A 2D array of locations into which to plot the densities. If not supplied, Arviz will create
its own array of plot areas (and return it).
backend: str, optional
backend : str, optional
Select plotting backend {"matplotlib","bokeh"}. Default to "matplotlib".
backend_kwargs: bool, optional
backend_kwargs : dict, optional
These are kwargs specific to the backend being used, passed to
:func:`matplotlib.pyplot.subplots` or :func:`bokeh.plotting.figure`.
For additional documentation check the plotting method of the backend.
group: {"prior", "posterior"}, optional
group : {"prior", "posterior"}, optional
Specifies which InferenceData group should be plotted. Defaults to 'posterior'.
Other value can be 'prior'.
show: bool, optional
show : bool, optional
Call backend show function.
Returns
-------
axes: matplotlib axes or bokeh figures
axes : matplotlib_axes or bokeh_figures
ani : matplotlib.animation.FuncAnimation, optional
Only provided if `animated` is ``True``.
See Also
--------
plot_bpv: Plot Bayesian p-value for observed data and Posterior/Prior predictive.
plot_lm: Posterior predictive and mean plots for regression-like data.
plot_ppc: plot for posterior/prior predictive checks.
plot_ts: Plot timeseries data.
plot_bpv : Plot Bayesian p-value for observed data and Posterior/Prior predictive.
plot_loo_pit : Plot for posterior predictive checks using cross validation.
plot_lm : Posterior predictive and mean plots for regression-like data.
plot_ts : Plot timeseries data.
Examples
--------
Expand Down Expand Up @@ -308,6 +307,7 @@ def plot_ppc(
skip_dims=set(flatten),
var_names=var_names,
combined=True,
dim_order=["chain", "draw"],
)
),
"plot_ppc",
Expand All @@ -322,6 +322,7 @@ def plot_ppc(
var_names=pp_var_names,
skip_dims=set(flatten_pp),
combined=True,
dim_order=["chain", "draw"],
),
)
]
Expand Down
23 changes: 23 additions & 0 deletions arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from matplotlib import animation
from pandas import DataFrame
from scipy.stats import gaussian_kde, norm
import xarray as xr

from ...data import from_dict, load_arviz_data
from ...plots import (
Expand Down Expand Up @@ -732,6 +733,28 @@ def test_plot_ppc(models, kind, alpha, animated, observed, observed_rug):
assert axes


def test_plot_ppc_transposed():
idata = load_arviz_data("rugby")
idata.map(
lambda ds: ds.assign(points=xr.concat((ds.home_points, ds.away_points), "field")),
groups="observed_vars",
inplace=True,
)
assert idata.posterior_predictive.points.dims == ("field", "chain", "draw", "match")
ax = plot_ppc(
idata,
kind="scatter",
var_names="points",
flatten=["field"],
coords={"match": ["Wales Italy"]},
random_seed=3,
num_pp_samples=8,
)
x, y = ax.get_lines()[2].get_data()
assert not np.isclose(y[0], 0)
assert np.all(np.array([40, 43, 10, 9]) == x)


@pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
@pytest.mark.parametrize("jitter", [None, 0, 0.1, 1, 3])
@pytest.mark.parametrize("animated", [False, True])
Expand Down

0 comments on commit c2d968f

Please sign in to comment.