Skip to content

Commit

Permalink
More mypy fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps committed May 28, 2024
1 parent 3039971 commit 2967a80
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/careamics/lightning_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def __init__(
self.extension_filter: str = extension_filter

# Pytorch dataloader parameters
self.dataloader_params = (
self.dataloader_params: Dict[str, Any] = (
data_config.dataloader_params if data_config.dataloader_params else {}
)

Expand Down Expand Up @@ -325,6 +325,11 @@ def setup(self, *args: Any, **kwargs: Any) -> None:
"""
# if numpy array
if self.data_type == SupportedData.ARRAY:
# mypy checks
assert isinstance(self.train_data, np.ndarray)
if self.train_data_target is not None:
assert isinstance(self.train_data_target, np.ndarray)

# train dataset
self.train_dataset: DatasetType = InMemoryDataset(
data_config=self.data_config,
Expand All @@ -334,6 +339,11 @@ def setup(self, *args: Any, **kwargs: Any) -> None:

# validation dataset
if self.val_data is not None:
# mypy checks
assert isinstance(self.val_data, np.ndarray)
if self.val_data_target is not None:
assert isinstance(self.val_data_target, np.ndarray)

# create its own dataset
self.val_dataset: DatasetType = InMemoryDataset(
data_config=self.data_config,
Expand Down

0 comments on commit 2967a80

Please sign in to comment.