Skip to content

Commit

Permalink
Rewrite mask_between_doys to add spatial dims support
Browse files Browse the repository at this point in the history
  • Loading branch information
aulemahal committed Jan 31, 2025
1 parent 36b1862 commit ac20249
Showing 1 changed file with 70 additions and 74 deletions.
144 changes: 70 additions & 74 deletions src/xclim/core/calendar.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,7 @@ def doy_to_days_since(
# 2 cases:
# val is a day in the same year as its index : da - offset
# val is a day in the next year : da + doy_max - offset
out = xr.where(dac > base_doy, dac, dac + doy_max) - start_doy
out = xr.where(dac >= base_doy, dac, dac + doy_max) - start_doy
out.attrs.update(da.attrs)
if start is not None:
out.attrs.update(units=f"days after {start}")
Expand Down Expand Up @@ -1156,91 +1156,87 @@ def mask_between_doys(
Parameters
----------
da : xr.DataArray or xr.Dataset
Input data.
doy_bounds : 2-tuple of integers or xr.DataArray
Input data. It must have a time coordinate.
doy_bounds : 2-tuple of integers or DataArray
The bounds as (start, end) of the period of interest expressed in day-of-year, integers going from
1 (January 1st) to 365 or 366 (December 31st). If a combination of int and xr.DataArray is given,
the int day-of-year corresponds to the year of the xr.DataArray.
1 (January 1st) to 365 or 366 (December 31st).
If DataArrays are passed, they must have the same coordinates on the dimensions they share.
They may have a time dimension, in which case the masking is done independently for each period defined by the coordinate,
which means the time coordinate must have an inferable frequency (see :py:func:`xr.infer_freq`).
Timesteps of the input not appearing in the time coordinate of the bounds are masked as "outside the bounds".
Missing values (nan) in the bounds are treated as an open bound (same as a None in a slice).
include_bounds : 2-tuple of booleans
Whether the bounds of `doy_bounds` should be inclusive or not.
Returns
-------
xr.DataArray or xr.Dataset
Boolean mask array with the same shape as `da` with True value inside the period of
interest and False outside.
xr.DataArray
Boolean array with the same time coordinate as `da` and any other dimension present on the bounds.
True value inside the period of interest and False outside.
"""
if isinstance(doy_bounds[0], int) and isinstance(doy_bounds[1], int):
if isinstance(doy_bounds[0], int) and isinstance(doy_bounds[1], int): # Simple case
mask = da.time.dt.dayofyear.isin(_get_doys(*doy_bounds, include_bounds))

else:
cal = get_calendar(da, dim="time")

start, end = doy_bounds
# convert ints to DataArrays
if isinstance(start, int):
start = xr.where(end.isnull(), np.nan, start)
start = start.convert_calendar(cal)
start.attrs["calendar"] = cal
else:
start = start.convert_calendar(cal)
start.attrs["calendar"] = cal
start = doy_to_days_since(start)

if isinstance(end, int):
end = xr.where(start.isnull(), np.nan, end)
end = end.convert_calendar(cal)
end.attrs["calendar"] = cal
else:
end = end.convert_calendar(cal)
end.attrs["calendar"] = cal
end = doy_to_days_since(end)

freq = []
for bound in [start, end]:
try:
freq.append(xr.infer_freq(bound.time))
except ValueError:
freq.append(None)
freq = set(freq) - {None}
if len(freq) != 1:
raise ValueError(
f"Non-inferrable resampling frequency or inconsistent frequencies. Got start, end = {freq}."
start = xr.full_like(end, start)
elif isinstance(end, int):
end = xr.full_like(start, end)
# Ensure they both have the same dims
# align join='exact' will fail on common but different coords, broadcast will add missing coords
start, end = xr.broadcast(*xr.align(start, end, join="exact"))

if not include_bounds[0]:
start += 1
if not include_bounds[1]:
end -= 1

if "time" in start.dims:
freq = xr.infer_freq(start.time)
# Convert the doy bounds to a duration since the beginning of each period defined in the bound's time coordinate
# Also ensures the bounds share the same calendar as the input
# Any missing value is replaced with the min/max of possible values
calkws = dict(
calendar=da.time.dt.calendar, use_cftime=(da.time.dtype == "O")
)
start = doy_to_days_since(start.convert_calendar(**calkws)).fillna(0)
end = doy_to_days_since(end.convert_calendar(**calkws)).fillna(366)

out = []
# For each period, flag the timesteps whose number of days since the period start lies between start and end
for base_time, indexes in da.resample(time=freq).groups.items():
group = da.isel(time=indexes)

if base_time in start.time:
start_d = start.sel(time=base_time)
end_d = end.sel(time=base_time)

# select days between start and end for group
days = (group.time - base_time).dt.days
days = days.where(days >= 0)
mask = (days >= start_d) & (days <= end_d)
else: # This group has no defined bounds : put False in the mask
# Array with the same shape as the "mask" in the other case : broadcast of time and bounds dims
template = xr.broadcast(
group.time.dt.day, start.isel(time=0, drop=True)
)[0]
mask = xr.full_like(template, False, dtype="bool")
out.append(mask)
mask = xr.concat(out, dim="time")
else: # Only "Spatial" dims, we can't constrain as in days since, so there are two cases
doys = da.time.dt.dayofyear # for readability
# Any missing value is replaced with the min/max of possible values
start = start.fillna(1)
end = end.fillna(366)
mask = xr.where(
start <= end,
(doys >= start)
& (doys <= end), # case 1 : start <= end, ROI is within a calendar year
~(
(doys >= end) & (doys <= start)
), # case 2 : start > end, ROI crosses the new year
)
else:
freq = freq.pop()

out = []
for base_time, indexes in da.resample(time=freq).groups.items():
# get group slice
group = da.isel(time=indexes)

if base_time in start.time:
start_d = start.sel(time=base_time)
else:
start_d = None
if base_time in end.time:
end_d = end.sel(time=base_time)
else:
end_d = None

if start_d is not None and end_d is not None:
if not include_bounds[0]:
start_d += 1
if not include_bounds[1]:
end_d -= 1

# select days between start and end for group
days = (group.time - base_time).dt.days
days[days < 0] = np.nan

mask = (days >= start_d) & (days <= end_d)
else:
# Get an array with the good shape and put False
mask = start.isel(time=0).drop_vars("time").expand_dims(time=group.time)
mask = xr.full_like(mask, False)

out.append(mask)
mask = xr.concat(out, dim="time")
return mask


Expand Down

0 comments on commit ac20249

Please sign in to comment.