"""Predictive check using densities."""
import warnings
from collections.abc import Mapping, Sequence
from importlib import import_module
from typing import Any, Literal
import numpy as np
import xarray as xr
from arviz_base import rcParams
from arviz_base.labels import BaseLabeller
from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.dist_plot import plot_dist
from arviz_plots.plots.utils import (
get_visual_kwargs,
process_group_variables_coords,
set_wrap_layout,
)
from arviz_plots.plots.utils_plot_types import warn_if_binary, warn_if_discrete
def plot_ppc_dist(
    dt,
    *,
    var_names=None,
    filter_vars=None,
    group="posterior_predictive",
    coords=None,
    sample_dims=None,
    kind=None,
    num_samples=50,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals: Mapping[
        Literal["predictive_dist", "observed_dist", "title"], Sequence[str]
    ] = None,
    visuals: Mapping[
        Literal["predictive_dist", "observed_dist", "title", "remove_axis"],
        Mapping[str, Any] | bool,
    ] = None,
    stats: Mapping[
        Literal["predictive_dist", "observed_dist"], Mapping[str, Any] | xr.Dataset
    ] = None,
    **pc_kwargs,
):
    """
    Plot 1D marginals for the predictive distribution and the observed data.

    Parameters
    ----------
    dt : DataTree
        If group is "posterior_predictive", it should contain the ``posterior_predictive`` and
        ``observed_data`` groups. If group is "prior_predictive", it should contain the
        ``prior_predictive`` group.
    var_names : str or list of str, optional
        One or more variables to be plotted.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars : {None, "like", "regex"}, default None
        If None, interpret var_names as the real variables names.
        If "like", interpret var_names as substrings of the real variables names.
        If "regex", interpret var_names as regular expressions on the real variables names.
    group : str, default "posterior_predictive"
        Group to be plotted. It could also be "prior_predictive".
    coords : dict, optional
        Coordinate selection applied to both the predictive group and ``observed_data``.
    sample_dims : str or sequence of hashable, optional
        Sampled dimensions. Those present in the predictive data are stacked into a single
        ``sample`` dimension from which `num_samples` draws are taken and overlaid.
        Defaults to ``rcParams["data.sample_dims"]``.
    kind : {"auto", "kde", "hist", "ecdf", "dot"}, optional
        How to represent the marginal density. Defaults to ``rcParams["plot.density_kind"]``.
        If "dot" is selected, only the top points of the predictive draws are shown
        (unless ``stats["predictive_dist"]`` is provided).
    num_samples : int, default 50
        Number of predictive samples to plot. If larger than the number of available
        predictive samples, it is reduced to that number and a warning is emitted.
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh"}, optional
    labeller : labeller, optional
        Forwarded to :func:`~arviz_plots.plot_dist` when plotting the predictive distribution.
        Defaults to :class:`~arviz_base.labels.BaseLabeller`.
    aes_by_visuals : mapping of {str : sequence of str}, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals` except "remove_axis".
        When a new plot collection is created here, the "overlay_ppc" aesthetic is mapped
        to the stacked ``sample`` dimension so every selected draw gets its own line.
    visuals : mapping of {str : mapping or bool}, optional
        Valid keys are:

        * predictive_dist, observed_dist -> passed to a function that depends on
          the `kind` argument.

          * "kde" -> passed to :func:`~arviz_plots.visuals.line_xy`
          * "ecdf" -> passed to :func:`~arviz_plots.visuals.ecdf_line`
          * "hist" -> passed to :func:`~arviz_plots.visuals.step_hist`
          * "dot" -> passed to :func:`~arviz_plots.visuals.scatter_xy`

        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`
        * remove_axis -> not passed anywhere, can only be ``False`` to skip calling this function

        predictive_dist defaults to ``alpha=0.3`` and observed_dist to ``color="B1"``.
        observed_dist defaults to False (no observed data is plotted) if group is
        "prior_predictive" or if `dt` has no ``observed_data`` group.
    stats : mapping, optional
        Valid keys are:

        * predictive_dist, observed_dist -> passed to kde, ecdf, ...

    **pc_kwargs
        Passed to :meth:`~arviz_plots.PlotCollection.wrap`

    Returns
    -------
    PlotCollection

    Raises
    ------
    ValueError
        If `kind` is not valid, if none of `sample_dims` is present in the predictive data,
        or if observed_dist is requested but `dt` has no ``observed_data`` group.

    See Also
    --------
    :ref:`plots_intro` :
        General introduction to batteries-included plotting functions, common use and logic overview

    Examples
    --------
    Make a plot of the posterior predictive distribution vs the observed data.
    We use an ECDF representation and customize the colors.

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_ppc_dist, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> radon = load_arviz_data('radon')
        >>> pc = plot_ppc_dist(
        >>>     radon,
        >>>     kind="ecdf",
        >>>     visuals={
        >>>         "predictive_dist": {"color":"C1"},
        >>>         "observed_dist": {"color":"C3"}
        >>>     },
        >>> )

    Faceting and aesthetics mappings happen on unique coordinate values. If there are repeated
    coordinate values, they will be grouped and reduced along with `sample_dims`.
    This example updates the coordinate values to have repeated values and requests
    faceting along the "obs_id" dimension. It also keeps 90 out of the 919 observations
    in the original dataset; otherwise we'd end up with 85 :term:`plots` in the :term:`figure`.

    .. plot::
        :context: close-figs

        >>> county = radon.constant_data["County"][radon.constant_data["county_idx"]]
        >>> reindexed_dt = radon.filter(
        >>>     lambda node: node.name in ("observed_data", "posterior_predictive")
        >>> ).map_over_datasets(
        >>>     lambda node: node.assign_coords(obs_id=county).isel(obs_id=slice(None, 90))
        >>> )
        >>> pc = plot_ppc_dist(reindexed_dt, cols=["obs_id"], kind="ecdf")

    Note how counties with a lot of observations have a smoother ECDF whereas counties
    with only 2-3 observations have only 2-3 steps in their ECDF.

    .. minigallery:: plot_ppc_dist
    """
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    sample_dims = list(sample_dims)
    if kind is None:
        kind = rcParams["plot.density_kind"]
    if kind not in ("kde", "hist", "ecdf", "dot", "auto"):
        raise ValueError("kind must be either 'auto', 'kde', 'hist', 'ecdf' or 'dot'")

    # Work on shallow copies so the caller's mappings are never modified.
    stats = {} if stats is None else dict(stats)
    visuals = {} if visuals is None else dict(visuals)
    aes_by_visuals = {} if aes_by_visuals is None else dict(aes_by_visuals)
    if labeller is None:
        labeller = BaseLabeller()
    if backend is None:
        backend = rcParams["plot.backend"] if plot_collection is None else plot_collection.backend

    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
    rng = np.random.default_rng(4214)

    predictive_dist = process_group_variables_coords(
        dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords
    )
    if "observed_data" in dt:
        observed_dist = process_group_variables_coords(
            dt,
            group="observed_data",
            var_names=var_names,
            filter_vars=filter_vars,
            coords=coords,
        )
    else:
        observed_dist = None

    warn_if_binary(observed_dist, predictive_dist)
    warn_if_discrete(observed_dist, predictive_dist, kind)

    # Non-sample dims of the *selected* data (after var_names/coords filtering); these are
    # the dims plot_dist reduces over to build each density.
    pp_dims = [dim for dim in predictive_dist.dims if dim not in sample_dims]

    # Select a random subset of samples. Only sample dims actually present can be stacked.
    present_sample_dims = [dim for dim in sample_dims if dim in predictive_dist.dims]
    if not present_sample_dims:
        raise ValueError(
            f"None of the sample_dims {sample_dims} are present in the '{group}' group."
        )
    n_pp_samples = int(np.prod([predictive_dist.sizes[dim] for dim in present_sample_dims]))
    if num_samples > n_pp_samples:
        num_samples = n_pp_samples
        warnings.warn(
            "num_samples is larger than the number of predictive samples.", stacklevel=2
        )
    if kind == "dot":
        stats.setdefault("predictive_dist", {"top_only": True})
    pp_sample_ix = rng.choice(n_pp_samples, size=num_samples, replace=False)
    predictive_dist = predictive_dist.stack(sample=present_sample_dims).isel(sample=pp_sample_ix)

    if plot_collection is None:
        pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy()
        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
        # One line per selected draw.
        pc_kwargs["aes"].setdefault("overlay_ppc", ["sample"])
        pc_kwargs.setdefault("cols", "__variable__")
        pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, predictive_dist)
        plot_collection = PlotCollection.wrap(
            predictive_dist,
            backend=backend,
            **pc_kwargs,
        )

    # NOTE(review): aes_by_visuals is forwarded to plot_dist unchanged, so only keys that
    # plot_dist itself looks up take effect; confirm whether the documented
    # "predictive_dist"/"observed_dist" keys are honoured there.

    # Plot the predictive density. plot_dist knows the density as "dist", so translate
    # our public key; the other plot_dist visuals must not be drawn for ppc plots.
    pred_visuals = {
        key: value
        for key, value in visuals.items()
        if key not in ("predictive_dist", "observed_dist")
    }
    for disabled in ("credible_interval", "point_estimate", "point_estimate_text", "rug_plot"):
        pred_visuals.setdefault(disabled, False)
    pred_density_kwargs = get_visual_kwargs(visuals, "predictive_dist")
    if pred_density_kwargs is False:
        pred_visuals["dist"] = False
    else:
        pred_density_kwargs = _copy_visual_kwargs(pred_density_kwargs)
        pred_density_kwargs.setdefault("alpha", 0.3)
        pred_visuals["dist"] = pred_density_kwargs

    plot_collection = plot_dist(
        predictive_dist,
        group=group,
        sample_dims=pp_dims,
        kind=kind,
        labeller=labeller,
        visuals=pred_visuals,
        aes_by_visuals=aes_by_visuals,
        pc_kwargs=pc_kwargs,
        plot_collection=plot_collection,
        stats={"dist": stats.get("predictive_dist", {})},
    )
    if pred_density_kwargs is not False:
        plot_collection.rename_visuals(dist="predictive_dist")

    # Plot the observed density; skipped by default when there is nothing to plot.
    observed_default = False if group == "prior_predictive" or observed_dist is None else None
    observed_density_kwargs = get_visual_kwargs(visuals, "observed_dist", observed_default)
    if observed_density_kwargs is not False:
        if observed_dist is None:
            raise ValueError(
                "visuals['observed_dist'] was requested but dt has no 'observed_data' group."
            )
        observed_density_kwargs = _copy_visual_kwargs(observed_density_kwargs)
        observed_density_kwargs.setdefault("color", "B1")
        observed_visuals = {
            "dist": observed_density_kwargs,
            "credible_interval": False,
            "point_estimate": False,
            "point_estimate_text": False,
            "title": False,
            "rug_plot": False,
            "remove_axis": False,
        }
        plot_collection = plot_dist(
            observed_dist,
            group="observed_data",
            sample_dims=pp_dims,
            kind=kind,
            visuals=observed_visuals,
            aes_by_visuals=aes_by_visuals,
            plot_collection=plot_collection,
            stats={"dist": stats.get("observed_dist", {})},
        )
        plot_collection.rename_visuals(dist="observed_dist")

    return plot_collection


def _copy_visual_kwargs(kwargs):
    """Return a fresh dict of visual kwargs so setting defaults never mutates caller data.

    ``True`` is treated as an empty mapping (no extra keyword arguments).
    """
    if kwargs is True:
        return {}
    return dict(kwargs)