From f64d942ecb22743ba406c1fcdb54f8d885fd3a34 Mon Sep 17 00:00:00 2001 From: "OriolAbril(HEL)" Date: Thu, 3 Feb 2022 15:13:53 +0200 Subject: [PATCH 1/6] draft subsampling bootstrap for mcse --- arviz/stats/diagnostics.py | 42 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/arviz/stats/diagnostics.py b/arviz/stats/diagnostics.py index cace5a2ec6..a1d6e2df52 100644 --- a/arviz/stats/diagnostics.py +++ b/arviz/stats/diagnostics.py @@ -341,7 +341,7 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None): ) -def mcse(data, *, var_names=None, method="mean", prob=None, dask_kwargs=None): +def mcse(data, *, var_names=None, method="mean", prob=None, func=None, dask_kwargs=None): """Calculate Markov Chain Standard Error statistic. Parameters @@ -398,6 +398,7 @@ def mcse(data, *, var_names=None, method="mean", prob=None, dask_kwargs=None): "sd": _mcse_sd, "median": _mcse_median, "quantile": _mcse_quantile, + "func": _mcse_func_sbm, } if method not in methods: raise TypeError( @@ -410,6 +411,9 @@ def mcse(data, *, var_names=None, method="mean", prob=None, dask_kwargs=None): if method == "quantile" and prob is None: raise TypeError("Quantile (prob) information needs to be defined.") + if method == "func" and func is None: + raise TypeError("func argument needs to be defined.") + if isinstance(data, np.ndarray): data = np.atleast_2d(data) if len(data.shape) < 3: @@ -430,7 +434,11 @@ def mcse(data, *, var_names=None, method="mean", prob=None, dask_kwargs=None): dataset = dataset if var_names is None else dataset[var_names] ufunc_kwargs = {"ravel": False} - func_kwargs = {} if prob is None else {"prob": prob} + func_kwargs = {} + if prob is not None: + func_kwargs["prob"] = prob + elif func is not None: + func_kwargs["func"] = func return _wrap_xarray_ufunc( mcse_func, dataset, @@ -820,6 +828,36 @@ def _mcse_mean(ary): return mcse_mean_value +def _mcse_func_sbm(ary, func): + """Compute the Markov Chain error on an 
arbitrary function.""" + ary = np.asarray(ary) + if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)): + return np.nan + ess = _ess_mean(ary) + func_estimate_sd = _sbm(ary, func) + mcse_func_value = func_estimate_sd / np.sqrt(ess) + return mcse_func_value + +def _sbm(ary, func): + """Subsampling bootstrap method. + + References + ---------- + .. [1] Doss, Charles R., et al. "Markov chain Monte Carlo estimation of quantiles." + *Electronic Journal of Statistics* 8.2 (2014): 2448-2478. + https://doi.org/10.1214/14-EJS957 + + """ + flat_ary = np.ravel(ary) + n = len(flat_ary) + b = int(np.sqrt(n)) + func_estimates = np.empty(n-b) + for i in range(n-b): + sub_ary = flat_ary[i:i+b] + func_estimates[i] = func(sub_ary) + func_estimate_sd = np.sqrt(b * np.var(func_estimates, ddof=0)) + return func_estimate_sd + def _mcse_sd(ary): """Compute the Markov Chain sd error.""" _numba_flag = Numba.numba_flag From 7373179291ce47dc5834440ce5e571d68de997cd Mon Sep 17 00:00:00 2001 From: "OriolAbril(HEL)" Date: Fri, 11 Mar 2022 16:45:32 +0200 Subject: [PATCH 2/6] fix n term in sbm mcse method --- arviz/stats/diagnostics.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/arviz/stats/diagnostics.py b/arviz/stats/diagnostics.py index a1d6e2df52..82959ad6b2 100644 --- a/arviz/stats/diagnostics.py +++ b/arviz/stats/diagnostics.py @@ -417,6 +417,10 @@ def mcse(data, *, var_names=None, method="mean", prob=None, func=None, dask_kwar if isinstance(data, np.ndarray): data = np.atleast_2d(data) if len(data.shape) < 3: + if data.size < 1000 and method == "func": + warnings.warn( + "Not enough samples for reliable estimate of MCSE for arbitrary functions" + ) if prob is not None: return mcse_func(data, prob=prob) # pylint: disable=unexpected-keyword-arg @@ -429,6 +433,10 @@ def mcse(data, *, var_names=None, method="mean", prob=None, func=None, dask_kwar raise TypeError(msg) dataset = convert_to_dataset(data, group="posterior") + if 
(dataset.dims["chain"] * dataset.dims["draw"]) < 1000 and method == "func": + warnings.warn( + "Not enough samples for reliable estimate of MCSE for arbitrary functions" + ) var_names = _var_names(var_names, dataset) dataset = dataset if var_names is None else dataset[var_names] @@ -831,11 +839,11 @@ def _mcse_mean(ary): def _mcse_func_sbm(ary, func): """Compute the Markov Chain error on an arbitrary function.""" ary = np.asarray(ary) - if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)): + if _not_valid(ary, shape_kwargs=dict(min_draws=10, min_chains=1)): return np.nan - ess = _ess_mean(ary) + n = ary.size func_estimate_sd = _sbm(ary, func) - mcse_func_value = func_estimate_sd / np.sqrt(ess) + mcse_func_value = func_estimate_sd / np.sqrt(n) return mcse_func_value def _sbm(ary, func): From f123d255a8745f7402ae6432eb67acc209ea4bd7 Mon Sep 17 00:00:00 2001 From: "Oriol (ZBook)" Date: Mon, 14 Mar 2022 11:01:16 +0200 Subject: [PATCH 3/6] fix behaviour on arrays --- arviz/stats/diagnostics.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/arviz/stats/diagnostics.py b/arviz/stats/diagnostics.py index 82959ad6b2..bdf10105cb 100644 --- a/arviz/stats/diagnostics.py +++ b/arviz/stats/diagnostics.py @@ -414,6 +414,12 @@ def mcse(data, *, var_names=None, method="mean", prob=None, func=None, dask_kwar if method == "func" and func is None: raise TypeError("func argument needs to be defined.") + func_kwargs = {} + if prob is not None: + func_kwargs["prob"] = prob + elif func is not None: + func_kwargs["func"] = func + if isinstance(data, np.ndarray): data = np.atleast_2d(data) if len(data.shape) < 3: @@ -421,16 +427,13 @@ def mcse(data, *, var_names=None, method="mean", prob=None, func=None, dask_kwar warnings.warn( "Not enough samples for reliable estimate of MCSE for arbitrary functions" ) - if prob is not None: - return mcse_func(data, prob=prob) # pylint: disable=unexpected-keyword-arg - - return mcse_func(data) - 
- msg = ( - "Only uni-dimensional ndarray variables are supported." - " Please transform first to dataset with `az.convert_to_dataset`." - ) - raise TypeError(msg) + return mcse_func(data, **func_kwargs) + else: + msg = ( + "Only uni-dimensional ndarray variables are supported." + " Please transform first to dataset with `az.convert_to_dataset`." + ) + raise TypeError(msg) dataset = convert_to_dataset(data, group="posterior") if (dataset.dims["chain"] * dataset.dims["draw"]) < 1000 and method == "func": @@ -442,11 +445,6 @@ def mcse(data, *, var_names=None, method="mean", prob=None, func=None, dask_kwar dataset = dataset if var_names is None else dataset[var_names] ufunc_kwargs = {"ravel": False} - func_kwargs = {} - if prob is not None: - func_kwargs["prob"] = prob - elif func is not None: - func_kwargs["func"] = func return _wrap_xarray_ufunc( mcse_func, dataset, From 697d9a5ebce9b9a64902e6b4564bd696cdf7bdab Mon Sep 17 00:00:00 2001 From: "Oriol (ZBook)" Date: Wed, 23 Mar 2022 20:29:44 +0200 Subject: [PATCH 4/6] fix issue with numba version --- arviz/stats/diagnostics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/stats/diagnostics.py b/arviz/stats/diagnostics.py index bdf10105cb..b792cb2df3 100644 --- a/arviz/stats/diagnostics.py +++ b/arviz/stats/diagnostics.py @@ -827,7 +827,7 @@ def _mcse_mean(ary): return np.nan ess = _ess_mean(ary) if _numba_flag: - sd = _sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1)) + sd = _sqrt(svar(np.ravel(ary), ddof=1), 0) else: sd = np.std(ary, ddof=1) mcse_mean_value = sd / np.sqrt(ess) From 00bdf7dd6ea061e9b3de6bf6726cc5601fc6a5fa Mon Sep 17 00:00:00 2001 From: "Oriol (ZBook)" Date: Mon, 15 Aug 2022 12:01:31 +0200 Subject: [PATCH 5/6] updates to api --- arviz/stats/diagnostics.py | 84 ++++++++++++++++++++++++++------------ 1 file changed, 59 insertions(+), 25 deletions(-) diff --git a/arviz/stats/diagnostics.py b/arviz/stats/diagnostics.py index b792cb2df3..69ac15bcfa 100644 --- 
a/arviz/stats/diagnostics.py +++ b/arviz/stats/diagnostics.py @@ -341,40 +341,70 @@ def rhat(data, *, var_names=None, method="rank", dask_kwargs=None): ) -def mcse(data, *, var_names=None, method="mean", prob=None, func=None, dask_kwargs=None): - """Calculate Markov Chain Standard Error statistic. +def mcse( + data, + *, + var_names=None, + method="mean", + prob=None, + func=None, + mcse_kwargs=None, + func_kwargs=None, + dask_kwargs=None, +): + r"""Calculate Markov Chain Standard Error statistic. Parameters ---------- - data : obj + data : InferenceData-like or 2D array-like Any object that can be converted to an :class:`arviz.InferenceData` object Refer to documentation of :func:`arviz.convert_to_dataset` for details For ndarray: shape = (chain, draw). For n-dimensional ndarray transform first to dataset with ``az.convert_to_dataset``. - var_names : list + var_names : list of str, optional Names of variables to include in the rhat report - method : str - Select mcse method. Valid methods are: + method : {'mean', 'sd', 'median', 'quantile', 'func'}, optional + The method to use when estimating the MCSE. - "mean" - "sd" - "median" - "quantile" + - "func" - prob : float + Methods "mean", "sd", "median" and "quantile" are described in [1]_. + + prob : float, optional Quantile information. + func : callable, optional + Summary function whose MCSE should be calculated. Only used when + method is "func". + TODO: add call signature info, something like ``func(ary, **func_kwargs)`` + func_kwargs : dict, optional + Keyword arguments passed to *func* when calling it. dask_kwargs : dict, optional Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`. Returns ------- xarray.Dataset - Return the msce dataset + Dataset with the MCSE results + + Other Parameters + ---------------- + mcse_kwargs : dict, optional + Extra keyword arguments passed to the MCSE estimation method. See Also -------- - ess : Compute autocovariance estimates for every lag for the input array. 
summary : Create a data frame with summary statistics. plot_mcse : Plot quantile or local Monte Carlo Standard Error. + ess : Compute autocovariance estimates for every lag for the input array. + + References + ---------- + .. [1] Vehtari, Aki, et al. "Rank-normalization, folding, and localization: an improved + $\hat{R}$ for assessing convergence of MCMC (with discussion)." + Bayesian analysis 16.2 (2021): 667-718. https://doi.org/10.1214/20-BA1221 Examples -------- @@ -414,11 +444,12 @@ def mcse(data, *, var_names=None, method="mean", prob=None, func=None, dask_kwar if method == "func" and func is None: raise TypeError("func argument needs to be defined.") - func_kwargs = {} + mcse_kwargs = {} if mcse_kwargs is None else mcse_kwargs if prob is not None: - func_kwargs["prob"] = prob + mcse_kwargs.setdefault("prob", prob) elif func is not None: - func_kwargs["func"] = func + mcse_kwargs.setdefault("func", func) + mcse_kwargs.setdefault("func_kwargs", func_kwargs) if isinstance(data, np.ndarray): data = np.atleast_2d(data) @@ -427,7 +458,7 @@ def mcse(data, *, var_names=None, method="mean", prob=None, func=None, dask_kwar warnings.warn( "Not enough samples for reliable estimate of MCSE for arbitrary functions" ) - return mcse_func(data, **func_kwargs) + return mcse_func(data, **mcse_kwargs) else: msg = ( "Only uni-dimensional ndarray variables are supported." 
@@ -437,9 +468,7 @@ def mcse(data, *, var_names=None, method="mean", prob=None, func=None, dask_kwar dataset = convert_to_dataset(data, group="posterior") if (dataset.dims["chain"] * dataset.dims["draw"]) < 1000 and method == "func": - warnings.warn( - "Not enough samples for reliable estimate of MCSE for arbitrary functions" - ) + warnings.warn("Not enough samples for reliable estimate of MCSE for arbitrary functions") var_names = _var_names(var_names, dataset) dataset = dataset if var_names is None else dataset[var_names] @@ -449,7 +478,7 @@ def mcse(data, *, var_names=None, method="mean", prob=None, func=None, dask_kwar mcse_func, dataset, ufunc_kwargs=ufunc_kwargs, - func_kwargs=func_kwargs, + func_kwargs=mcse_kwargs, dask_kwargs=dask_kwargs, ) @@ -834,17 +863,22 @@ def _mcse_mean(ary): return mcse_mean_value -def _mcse_func_sbm(ary, func): +def _mcse_func_sbm(ary, func, b=None, func_kwargs=None): """Compute the Markov Chain error on an arbitrary function.""" ary = np.asarray(ary) if _not_valid(ary, shape_kwargs=dict(min_draws=10, min_chains=1)): return np.nan n = ary.size - func_estimate_sd = _sbm(ary, func) + if b is None: + b = int(np.sqrt(n)) + if func_kwargs is None: + func_kwargs = {} + func_estimate_sd = _sbm(ary, func, b=b, func_kwargs=func_kwargs) mcse_func_value = func_estimate_sd / np.sqrt(n) return mcse_func_value -def _sbm(ary, func): + +def _sbm(ary, func, b, func_kwargs): """Subsampling bootstrap method. 
References @@ -856,14 +890,14 @@ def _sbm(ary, func): """ flat_ary = np.ravel(ary) n = len(flat_ary) - b = int(np.sqrt(n)) - func_estimates = np.empty(n-b) - for i in range(n-b): - sub_ary = flat_ary[i:i+b] - func_estimates[i] = func(sub_ary) + func_estimates = np.empty(n - b) + for i in range(n - b): + sub_ary = flat_ary[i : i + b] + func_estimates[i] = func(sub_ary, **func_kwargs) func_estimate_sd = np.sqrt(b * np.var(func_estimates, ddof=0)) return func_estimate_sd + def _mcse_sd(ary): """Compute the Markov Chain sd error.""" _numba_flag = Numba.numba_flag From 13dd989037df15d3a93a41d63c07b1f91dae84b7 Mon Sep 17 00:00:00 2001 From: "OriolAbril(HEL)" Date: Thu, 18 Aug 2022 01:56:29 +0300 Subject: [PATCH 6/6] add var_func argument --- arviz/stats/diagnostics.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/arviz/stats/diagnostics.py b/arviz/stats/diagnostics.py index 69ac15bcfa..5f3bcd4e40 100644 --- a/arviz/stats/diagnostics.py +++ b/arviz/stats/diagnostics.py @@ -863,7 +863,7 @@ def _mcse_mean(ary): return mcse_mean_value -def _mcse_func_sbm(ary, func, b=None, func_kwargs=None): +def _mcse_func_sbm(ary, func, b=None, var_func=np.var, func_kwargs=None): """Compute the Markov Chain error on an arbitrary function.""" ary = np.asarray(ary) if _not_valid(ary, shape_kwargs=dict(min_draws=10, min_chains=1)): @@ -873,12 +873,12 @@ def _mcse_func_sbm(ary, func, b=None, func_kwargs=None): b = int(np.sqrt(n)) if func_kwargs is None: func_kwargs = {} - func_estimate_sd = _sbm(ary, func, b=b, func_kwargs=func_kwargs) + func_estimate_sd = _sbm(ary, func, b=b, var_func=var_func, func_kwargs=func_kwargs) mcse_func_value = func_estimate_sd / np.sqrt(n) return mcse_func_value -def _sbm(ary, func, b, func_kwargs): +def _sbm(ary, func, b, var_func, func_kwargs): """Subsampling bootstrap method. 
References @@ -894,7 +894,7 @@ def _sbm(ary, func, b, func_kwargs): for i in range(n - b): sub_ary = flat_ary[i : i + b] func_estimates[i] = func(sub_ary, **func_kwargs) - func_estimate_sd = np.sqrt(b * np.var(func_estimates, ddof=0)) + func_estimate_sd = np.sqrt(b * var_func(func_estimates)) return func_estimate_sd