From 2967a803c5e9975b4185b9cae99e995b10e8d971 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Tue, 28 May 2024 13:32:25 +0200 Subject: [PATCH] More mypy fixes --- src/careamics/lightning_datamodule.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/careamics/lightning_datamodule.py b/src/careamics/lightning_datamodule.py index 695ded065..92dc7c86f 100644 --- a/src/careamics/lightning_datamodule.py +++ b/src/careamics/lightning_datamodule.py @@ -260,7 +260,7 @@ def __init__( self.extension_filter: str = extension_filter # Pytorch dataloader parameters - self.dataloader_params = ( + self.dataloader_params: Dict[str, Any] = ( data_config.dataloader_params if data_config.dataloader_params else {} ) @@ -325,6 +325,11 @@ def setup(self, *args: Any, **kwargs: Any) -> None: """ # if numpy array if self.data_type == SupportedData.ARRAY: + # mypy checks + assert isinstance(self.train_data, np.ndarray) + if self.train_data_target is not None: + assert isinstance(self.train_data_target, np.ndarray) + # train dataset self.train_dataset: DatasetType = InMemoryDataset( data_config=self.data_config, @@ -334,6 +339,11 @@ def setup(self, *args: Any, **kwargs: Any) -> None: # validation dataset if self.val_data is not None: + # mypy checks + assert isinstance(self.val_data, np.ndarray) + if self.val_data_target is not None: + assert isinstance(self.val_data_target, np.ndarray) + # create its own dataset self.val_dataset: DatasetType = InMemoryDataset( data_config=self.data_config,