diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index 9b6d11c3dc..ce8a8a28eb 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -236,6 +236,10 @@ def __delattr__(self, group: str) -> None: self._groups_warmup.remove(group) object.__delattr__(self, group) + def __delitem__(self, key: str) -> None: + """Delete an item from the InferenceData object using del idata[key].""" + self.__delattr__(key) + @property def _groups_all(self) -> List[str]: return self._groups + self._groups_warmup diff --git a/arviz/tests/base_tests/test_data.py b/arviz/tests/base_tests/test_data.py index ae8c3e4cad..2f1179a476 100644 --- a/arviz/tests/base_tests/test_data.py +++ b/arviz/tests/base_tests/test_data.py @@ -496,7 +496,7 @@ def test_sel_chain_prior(self): with pytest.raises(KeyError): idata.sel(inplace=False, chain_prior=True, chain=[0, 1, 3]) - @pytest.mark.parametrize("use", ("del", "delattr")) + @pytest.mark.parametrize("use", ("del", "delattr", "delitem")) def test_del(self, use): # create inference data object data = np.random.normal(size=(4, 500, 8)) @@ -523,6 +523,8 @@ def test_del(self, use): # Use del method if use == "del": del idata.sample_stats + elif use == "delitem": + del idata["sample_stats"] else: delattr(idata, "sample_stats")