Skip to content

Commit

Permalink
Add option to specify chunks in broadcast_to_shape
Browse files Browse the repository at this point in the history
  • Loading branch information
bouweandela committed Dec 12, 2023
1 parent c11d8bb commit 0bb5c47
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 4 deletions.
27 changes: 27 additions & 0 deletions lib/iris/tests/unit/util/test_broadcast_to_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,33 @@ def test_lazy_masked(self, mocked_compute):
b[i, :, j, :].compute().T, m.compute()
)

@mock.patch.object(dask.base, "compute", wraps=dask.base.compute)
def test_lazy_chunks(self, mocked_compute):
    """Requested chunks apply only to new or length-1 source dimensions.

    The chunks passed alongside the target shape are honoured on
    dimensions that do not exist in the source array and on source
    dimensions of size 1; elsewhere the source chunking is retained.
    """
    source = da.ma.masked_array(
        data=[[1, 2, 3, 4, 5]],
        mask=[[0, 1, 0, 0, 0]],
    ).rechunk((1, 2))
    requested_chunks = (
        1,  # honoured: this target dimension is new
        2,  # honoured: the mapped source dimension has size 1
        3,  # ignored: broadcasting does not rechunk existing data
    )
    result = broadcast_to_shape(
        source,
        dim_map=(1, 2),
        shape=(3, 4, 5),
        chunks=requested_chunks,
    )
    # Building the broadcast result must not trigger any computation.
    mocked_compute.assert_not_called()
    # Every (plane, row) slice of the result should equal the single
    # source row, mask included.
    for plane in range(3):
        for row in range(4):
            self.assertMaskedArrayEqual(
                result[plane, row, :].compute(), source[0].compute()
            )
    assert result.chunks == ((1, 1, 1), (2, 2), (2, 2, 1))

def test_masked_degenerate(self):
# masked arrays can have degenerate masks too
a = np.random.random([2, 3])
Expand Down
30 changes: 26 additions & 4 deletions lib/iris/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import iris.exceptions


def broadcast_to_shape(array, shape, dim_map):
def broadcast_to_shape(array, shape, dim_map, chunks=None):
"""
Broadcast an array to a given shape.
Expand All @@ -53,6 +53,14 @@ def broadcast_to_shape(array, shape, dim_map):
to, so the first element of *dim_map* gives the index of *shape*
that corresponds to the first dimension of *array* etc.
* chunks :class:`tuple`, optional
If the source array is a :class:`dask.array.Array` and a value is
provided, then the result will use these chunks instead of the same
chunks as the source array. Setting chunks explicitly as part of
broadcast_to_shape is more efficient than rechunking afterwards. The
values provided here will only be used along dimensions that are new on
the result or have size 1 on the source array.
Examples:
Broadcasting an array of shape (2, 3) to the shape (5, 2, 6, 3)
Expand All @@ -74,27 +82,41 @@ def broadcast_to_shape(array, shape, dim_map):
See more at :doc:`/userguide/real_and_lazy_data`.
"""
if isinstance(array, da.Array):
if chunks is not None:
chunks = list(chunks)
for src_idx, tgt_idx in enumerate(dim_map):
# Only use the specified chunks along new dimensions or on
# dimensions that have size 1 in the source array.
if array.shape[src_idx] != 1:
chunks[tgt_idx] = array.chunks[src_idx]
broadcast = functools.partial(
da.broadcast_to, shape=shape, chunks=chunks
)
else:
broadcast = functools.partial(np.broadcast_to, shape=shape)

n_orig_dims = len(array.shape)
n_new_dims = len(shape) - n_orig_dims
array = array.reshape(array.shape + (1,) * n_new_dims)

# Get dims in required order.
array = np.moveaxis(array, range(n_orig_dims), dim_map)
new_array = np.broadcast_to(array, shape)
new_array = broadcast(array)

if ma.isMA(array):
# broadcast_to strips masks so we need to handle them explicitly.
mask = ma.getmask(array)
if mask is ma.nomask:
new_mask = ma.nomask
else:
new_mask = np.broadcast_to(mask, shape)
new_mask = broadcast(mask)
new_array = ma.array(new_array, mask=new_mask)

elif is_lazy_masked_data(array):
# broadcast_to strips masks so we need to handle them explicitly.
mask = da.ma.getmaskarray(array)
new_mask = da.broadcast_to(mask, shape)
new_mask = broadcast(mask)
new_array = da.ma.masked_array(new_array, new_mask)

return new_array
Expand Down

0 comments on commit 0bb5c47

Please sign in to comment.