Skip to content

Commit

Permalink
Drop z dimension from ImageStack DataArray before shading (#6378)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr authored Sep 19, 2024
1 parent 88996ed commit b729e34
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 6 deletions.
8 changes: 6 additions & 2 deletions holoviews/core/data/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def retrieve_unit_and_label(dim):
if isinstance(data, xr.DataArray):
kdim_len = len(kdim_param.default) if kdims is None else len(kdims)
vdim_len = len(vdim_param.default) if vdims is None else len(vdims)
if vdim_len > 1 and kdim_len == len(data.dims)-1 and data.shape[-1] == vdim_len:
if kdim_len == len(data.dims)-1 and data.shape[-1] == vdim_len:
packed = True
elif vdims:
vdim = vdims[0]
Expand Down Expand Up @@ -446,7 +446,11 @@ def unpack_scalar(cls, dataset, data):
Given a dataset object and data in the appropriate format for
the interface, return a simple scalar.
"""
if not cls.packed(dataset) and len(data.data_vars) == 1:
if cls.packed(dataset):
array = data.squeeze()
if len(array.shape) == 0:
return array.item()
elif len(data.data_vars) == 1:
array = data[dataset.vdims[0].name].squeeze()
if len(array.shape) == 0:
return array.item()
Expand Down
26 changes: 22 additions & 4 deletions holoviews/operation/datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1286,17 +1286,35 @@ def _process(self, element, key=None):
ydensity = element.ydensity
bounds = element.bounds

# Convert to xarray if not already
if element.interface.datatype != 'xarray':
element = element.clone(datatype=['xarray'])

kdims = element.kdims
if isinstance(element, ImageStack):
vdim = element.vdims
array = element.data
if hasattr(array, "to_array"):
array = array.to_array("z")
array = array.transpose(*[kdim.name for kdim in kdims], ...)
# If data is a xarray Dataset it has to be converted to a
# DataArray, either by selecting the singular value
# dimension or by adding a z-dimension
kdims = [kdim.name for kdim in kdims]
if not element.interface.packed(element):
if len(vdim) == 1:
array = array[vdim[0].name]
else:
array = array.to_array("z")
# If data is 3D then we have one extra constant dimension
if array.ndim > 3:
drop = [d for d in array.dims if d not in kdims+["z"]]
array = array.squeeze(dim=drop)
array = array.transpose(*kdims, ...)
else:
vdim = element.vdims[0].name
array = element.data[vdim]

# Dask is not supported by shade so materialize it
array = array.compute()

shade_opts = dict(
how=self.p.cnorm, min_alpha=self.p.min_alpha, alpha=self.p.alpha
)
Expand Down Expand Up @@ -1335,7 +1353,7 @@ def _process(self, element, key=None):
if self.p.clims:
shade_opts['span'] = self.p.clims
elif ds_version > Version('0.5.0') and self.p.cnorm != 'eq_hist':
shade_opts['span'] = element.range(vdim)
shade_opts['span'] = (array.min().item(), array.max().item())

params = dict(get_param_values(element), kdims=kdims,
bounds=bounds, vdims=RGB.vdims[:],
Expand Down
16 changes: 16 additions & 0 deletions holoviews/tests/operation/test_datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,22 @@ def test_aggregate_points_categorical(self):
actual = img.data
assert (expected.data.to_array("z").values == actual.T.values).all()

def test_aggregate_points_categorical_one_category(self):
points = Points([(0.2, 0.3, 'A'), (0.4, 0.7, 'A'), (0, 0.99, 'A')], vdims='z')
img = aggregate(points, dynamic=False, x_range=(0, 1), y_range=(0, 1),
width=2, height=2, aggregator=ds.by('z', ds.count()))
x = np.array([0.25, 0.75])
y = np.array([0.25, 0.75])
a = np.array([[1, 2], [0, 0]])
xrds = xr.DataArray(
a,
dims=('x', 'y'),
coords={"x": x, "y": y}
)
expected = ImageStack(xrds, kdims=["x", "y"], vdims=["a"])
actual = img.data
assert (expected.data.to_array("z").values == actual.T.values).all()

def test_aggregate_points_categorical_mean(self):
points = Points([(0.2, 0.3, 'A', 0.1), (0.4, 0.7, 'B', 0.2), (0, 0.99, 'C', 0.3)], vdims=['cat', 'z'])
img = aggregate(points, dynamic=False, x_range=(0, 1), y_range=(0, 1),
Expand Down

0 comments on commit b729e34

Please sign in to comment.