Skip to content

Commit

Permalink
Make from_dict more flexible, and add from_pytree
Browse files Browse the repository at this point in the history
  • Loading branch information
ColCarroll committed Nov 14, 2023
1 parent 2034668 commit 78b9baf
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 3 deletions.
3 changes: 2 additions & 1 deletion arviz/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .io_cmdstan import from_cmdstan
from .io_cmdstanpy import from_cmdstanpy
from .io_datatree import from_datatree, to_datatree
from .io_dict import from_dict
from .io_dict import from_dict, from_pytree
from .io_emcee import from_emcee
from .io_json import from_json, to_json
from .io_netcdf import from_netcdf, to_netcdf
Expand Down Expand Up @@ -38,6 +38,7 @@
"from_cmdstanpy",
"from_datatree",
"from_dict",
"from_pytree",
"from_json",
"from_pyro",
"from_numpyro",
Expand Down
50 changes: 48 additions & 2 deletions arviz/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

import numpy as np
import tree
import xarray as xr

try:
Expand Down Expand Up @@ -66,6 +67,46 @@ def wrapped(cls: RequiresArgTypeT) -> Optional[RequiresReturnTypeT]:

return wrapped

def _yield_flat_up_to(shallow_tree, input_tree, path=()):
    """Yield ``(path, value)`` pairs of ``input_tree`` flattened up to ``shallow_tree``.

    Adapted from dm-tree (https://github.com/google-deepmind/tree) so that
    lists may appear as leaves: only mappings, namedtuples and attrs
    instances are descended into; text/bytes and everything else is yielded
    as a leaf.

    Args:
      shallow_tree: Nested structure. Traversal goes no deeper than its leaf
        nodes.
      input_tree: Nested structure whose paths and values are returned. Must
        have the same upper structure as ``shallow_tree``.
      path: Tuple. Only used when recursing: the path from the root of the
        original ``shallow_tree`` down to the ``shallow_tree`` argument of
        this recursive call.

    Yields:
      Pairs of ``(path, value)``, where ``path`` is the tuple path of a leaf
      node in ``shallow_tree`` and ``value`` is the value of the
      corresponding node in ``input_tree``.
    """
    # pylint: disable=protected-access
    # A node is descended into only when it is not text/bytes AND is one of
    # the supported container kinds (short-circuit order matches the checks
    # dm-tree performs).
    descend = not isinstance(shallow_tree, tree._TEXT_OR_BYTES) and (
        isinstance(shallow_tree, tree.collections_abc.Mapping)
        or tree._is_namedtuple(shallow_tree)
        or tree._is_attrs(shallow_tree)
    )
    if not descend:
        yield (path, input_tree)
        return

    # Index the input by key so each shallow key can pick out its counterpart.
    children = dict(tree._yield_sorted_items(input_tree))
    for key, shallow_child in tree._yield_sorted_items(shallow_tree):
        child_path = path + (key,)
        input_child = children[key]
        yield from _yield_flat_up_to(shallow_child, input_child, path=child_path)
    # pylint: enable=protected-access


def _flatten_with_path(structure):
    """Return a list of ``(path, value)`` pairs for every leaf of ``structure``.

    The structure is flattened against itself via ``_yield_flat_up_to``, so
    the leaves are whatever that helper treats as leaves.
    """
    leaf_pairs = _yield_flat_up_to(structure, structure)
    return list(leaf_pairs)


def generate_dims_coords(
shape,
Expand Down Expand Up @@ -255,7 +296,7 @@ def numpy_to_data_array(
return xr.DataArray(ary, coords=coords, dims=dims)


def dict_to_dataset(
def pytree_to_dataset(
data,
*,
attrs=None,
Expand All @@ -266,7 +307,7 @@ def dict_to_dataset(
index_origin=None,
skip_event_dims=None,
):
"""Convert a dictionary of numpy arrays to an xarray.Dataset.
"""Convert a pytree of numpy arrays to an xarray.Dataset.
Parameters
----------
Expand Down Expand Up @@ -302,6 +343,10 @@ def dict_to_dataset(
"""
if dims is None:
dims = {}
try:
data = {'__'.join(map(str, k)): v for k, v in _flatten_with_path(data)}
except TypeError:
pass

data_vars = {
key: numpy_to_data_array(
Expand All @@ -317,6 +362,7 @@ def dict_to_dataset(
}
return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))

# Alias so that code referring to ``dict_to_dataset`` gets ``pytree_to_dataset``.
dict_to_dataset = pytree_to_dataset

def make_attrs(attrs=None, library=None):
"""Make standard attributes to attach to xarray datasets.
Expand Down
4 changes: 4 additions & 0 deletions arviz/data/converters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""High level conversion functions."""
import numpy as np
import tree
import xarray as xr

from .base import dict_to_dataset
Expand Down Expand Up @@ -105,6 +106,8 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
dataset = obj.to_dataset()
elif isinstance(obj, dict):
dataset = dict_to_dataset(obj, coords=coords, dims=dims)
elif tree.is_nested(obj) and not isinstance(obj, (list, tuple)):
dataset = dict_to_dataset(obj, coords=coords, dims=dims)
elif isinstance(obj, np.ndarray):
dataset = dict_to_dataset({"x": obj}, coords=coords, dims=dims)
elif isinstance(obj, (list, tuple)) and isinstance(obj[0], str) and obj[0].endswith(".csv"):
Expand All @@ -118,6 +121,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
"xarray dataarray",
"xarray dataset",
"dict",
"pytree",
"netcdf filename",
"numpy array",
"pystan fit",
Expand Down
2 changes: 2 additions & 0 deletions arviz/data/io_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,3 +458,5 @@ def from_dict(
attrs=attrs,
**kwargs,
).to_inference_data()

# ``from_pytree`` is exposed as a second name for ``from_dict``.
from_pytree = from_dict
13 changes: 13 additions & 0 deletions arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,6 +1074,19 @@ def test_dict_to_dataset():
assert set(dataset.a.coords) == {"chain", "draw"}
assert set(dataset.b.coords) == {"chain", "draw", "c"}

def test_nested_dict_to_dataset():
    """A nested dict converts to a dataset with ``__``-joined variable names."""
    nested = {
        "top": {
            "a": np.random.randn(100),
            "b": np.random.randn(1, 100, 10),
        },
        "d": np.random.randn(100),
    }
    coords = {"c": np.arange(10)}
    dims = {"top__b": ["c"]}
    ds = convert_to_dataset(nested, coords=coords, dims=dims)

    # Nested keys are flattened; the top-level key is kept as-is.
    assert set(ds.data_vars) == {"top__a", "top__b", "d"}
    assert set(ds.coords) == {"chain", "draw", "c"}

    # Only the variable that was given the extra dim carries the "c" coord.
    assert set(ds.top__a.coords) == {"chain", "draw"}
    assert set(ds.top__b.coords) == {"chain", "draw", "c"}
    assert set(ds.d.coords) == {"chain", "draw"}


def test_dict_to_dataset_event_dims_error():
datadict = {"a": np.random.randn(1, 100, 10)}
Expand Down

0 comments on commit 78b9baf

Please sign in to comment.