Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] Fix for failing healpix dataloader test #696

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion modulus/datapipes/healpix/data_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def create_time_series_dataset_classic(
for variable in all_variables:
file_name = _get_file_name(src_directory, prefix, variable, suffix)
logger.debug("open nc dataset %s", file_name)
if "sample" in list(xr.open_dataset(file_name).dims.keys()):
if "sample" in list(xr.open_dataset(file_name).sizes.keys()):
ds = xr.open_dataset(file_name, chunks={"sample": batch_size}).rename(
{"sample": "time"}
)
Expand Down
4 changes: 2 additions & 2 deletions modulus/datapipes/healpix/timeseries_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,8 @@ def _get_time_index(self, item):
if self.forecast_mode
else (item + 1) * self.batch_size + self._window_length
)
if not self.drop_last and max_index > self.ds.dims["time"]:
batch_size = self.batch_size - (max_index - self.ds.dims["time"])
if not self.drop_last and max_index > self.ds.sizes["time"]:
batch_size = self.batch_size - (max_index - self.ds.sizes["time"])
else:
batch_size = self.batch_size
return (start_index, max_index), batch_size
Expand Down
34 changes: 15 additions & 19 deletions test/models/dlwp_healpix/test_healpix_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,38 +154,34 @@ def test_HEALPixLayer_initialization(device, multiplier):
def test_HEALPixLayer_forward(device, multiplier):
layer = HEALPixLayer(layer=MulX, multiplier=multiplier)

kernel_size = 3
dilation = 2
in_channels = 4
out_channels = 8

tensor_size = torch.randint(low=2, high=4, size=(1,)).tolist()
tensor_size = [24, 4, *tensor_size, *tensor_size]
tensor_size = [24, in_channels, *tensor_size, *tensor_size]
invar = torch.rand(tensor_size, device=device)
outvar = layer(invar)

assert common.compare_output(outvar, invar * multiplier)

# test nhwc mode and dilation
layer = HEALPixLayer(
layer=torch.nn.Conv2d,
in_channels=4,
out_channels=8,
kernel_size=3,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
device=device,
# dilation=4,
)

outvar = layer(invar)

layer = HEALPixLayer(
layer=torch.nn.Conv2d,
in_channels=4,
out_channels=8,
kernel_size=3,
device=device,
dilation=1,
dilation=dilation,
enable_healpixpad=True,
enable_nhwc=True,
)

assert outvar.shape == layer(invar).shape
assert outvar.stride() != layer(invar).stride()
# size of the padding added by HEALPixLayer
expected_shape = [24, out_channels, tensor_size[-1], tensor_size[-1]]
expected_shape = torch.Size(expected_shape)

assert expected_shape == layer(invar).shape

del layer, outvar, invar
torch.cuda.empty_cache()