diff --git a/innofw/core/datamodules/lightning_datamodules/detection_coco.py b/innofw/core/datamodules/lightning_datamodules/detection_coco.py index 8dc3e1e2..37db42ad 100644 --- a/innofw/core/datamodules/lightning_datamodules/detection_coco.py +++ b/innofw/core/datamodules/lightning_datamodules/detection_coco.py @@ -17,7 +17,6 @@ from innofw.core.datasets.coco import CocoDataset from innofw.core.datasets.coco import DicomCocoDataset from innofw.core.datasets.coco import DicomCocoDatasetInfer -from innofw.core.datasets.coco import DicomCocoDatasetRTK from innofw.utils.data_utils.preprocessing.dicom_handler import dicom_to_img from innofw.utils.data_utils.preprocessing.dicom_handler import img_to_dicom from innofw.utils.dm_utils.utils import find_file_by_ext @@ -238,167 +237,3 @@ def setup_infer(self): str(self.infer), Augmentation(aug), ) - - -class CustomNormalize: - def __call__(self, image, **kwargs): - image = (image - image.min()) / (image.max() - image.min() + 1e-8) - return image - - -class DicomCocoComplexingDataModule(BaseLightningDataModule): - task = ["image-detection", "image-segmentation"] - dataset = DicomCocoDatasetRTK - - def __init__( - self, - train=None, - test=None, - infer=None, - val_size: float = 0.2, - num_workers: int = 1, - augmentations=None, - stage=None, - batch_size=32, - transform=None, - val_split=0.2, - test_split=0.1, - *args, - **kwargs, - ): - super().__init__( - train, - test, - infer, - batch_size, - num_workers, - stage, - *args, - **kwargs, - ) - - def setup(self, stage=None): - pass - - def setup_train_test_val(self, **kwargs): - pass - - def setup_infer(self): - if self.aug: - transform = Augmentation(self.aug["test"]) - else: - - transform = albu.Compose( - [ - albu.Resize(256, 256), - albu.Lambda(image=CustomNormalize()), - ToTensorV2(transpose_mask=True), - ] - ) - if str(self.predict_source).endswith("mrt"): - self.predict_source = self.predict_source.parent - cont = os.listdir(self.predict_source) - assert "ct" in cont, f"No CT data in {self.predict_source}" - assert "mrt" in cont, f"No MRT data in {self.predict_source}" - - self.predict_dataset = [ - self.dataset( - data_dir=os.path.join(str(self.predict_source), "ct"), - transform=transform, - ), - self.dataset( - data_dir=os.path.join(str(self.predict_source), "mrt"), - transform=transform, - ), - ] - self.predict_dataset = torch.utils.data.ConcatDataset(self.predict_dataset) - - def train_dataloader(self): - return torch.utils.data.DataLoader( - self.train_dataset, - batch_size=self.batch_size, - shuffle=True, - num_workers=self.num_workers, - ) - - def val_dataloader(self): - return torch.utils.data.DataLoader( - self.val_dataset, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.num_workers, - ) - - def test_dataloader(self): - return torch.utils.data.DataLoader( - self.test_dataset, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.num_workers, - ) - - def predict_dataloader(self): - """shuffle should be turned off""" - return torch.utils.data.DataLoader( - self.predict_dataset, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.num_workers, - ) - - def save_preds(self, preds, stage: Stages, dst_path: pathlib.Path): - """we assume that shuffle is turned off - - Args: - preds: - stage: - dst_path: - - Returns: - - """ - - total_iter = 0 - for tensor_batch in preds: - for i in range(tensor_batch.shape[0]): - path = self.predict_dataset[total_iter]["path"] - output = tensor_batch[i].cpu().detach().numpy() - output = np.max(output, axis=0) - output = np.expand_dims(output, axis=0) - output = np.transpose(output, (1, 2, 0)) - if "/ct/" in path: - prefix = "_ct" - else: - prefix = "_mrt" - path = os.path.join(dst_path, f"{prefix}_{total_iter}.npy") - np.save(path, output) - total_iter += 1 - - -class DicomCocoDataModuleRTK(DicomCocoComplexingDataModule): - def setup_infer(self): - if self.aug: - transform = Augmentation(self.aug["test"]) - else: - - transform = albu.Compose( - [ - albu.Resize(256, 256), - albu.Lambda(image=CustomNormalize()), - ToTensorV2(transpose_mask=True), - ] - ) - self.predict_dataset = self.dataset( - data_dir=str(self.predict_source), transform=transform - ) - - def save_preds(self, preds, stage: Stages, dst_path: pathlib.Path): - prefix = "mask" - for batch_idx, tensor_batch in enumerate(preds): - for i in range(tensor_batch.shape[0]): - output = tensor_batch[i].cpu().detach().numpy() - output = np.max(output, axis=0) - output = np.expand_dims(output, axis=0) - output = np.transpose(output, (1, 2, 0)) - path = os.path.join(dst_path, f"{prefix}_{batch_idx}_{i}.npy") - np.save(path, output) diff --git a/innofw/core/datasets/coco.py b/innofw/core/datasets/coco.py index adc04018..6493b364 100644 --- a/innofw/core/datasets/coco.py +++ b/innofw/core/datasets/coco.py @@ -1,18 +1,12 @@ import os -import json from pathlib import Path import cv2 import numpy as np import torch -import pydicom -from gitdb.util import basename from torch.utils.data import Dataset -from innofw.utils.data_utils.preprocessing.dicom_handler import ( - dicom_to_img, - dicom_to_raster, -) +from innofw.utils.data_utils.preprocessing.dicom_handler import dicom_to_img class CocoDataset(Dataset): @@ -128,7 +122,9 @@ class DicomCocoDatasetInfer(Dataset): def __init__(self, dicom_dir, transforms=None): self.images = [] - self.paths = [os.path.join(dicom_dir, d) for d in os.listdir(dicom_dir)] + self.paths = [ + os.path.join(dicom_dir, d) for d in os.listdir(dicom_dir) + ] for dicom in self.paths: self.images.append(transforms(image=dicom_to_img(dicom))["image"]) @@ -170,7 +166,9 @@ def __init__(self, annotations, root_dir, transforms=None): self.root_dir = Path(root_dir) self.image_list = annotations["image_name"].values self.domain_list = annotations["domain"].values - self.boxes = [self.decodeString(item) for item in annotations["BoxesString"]] + self.boxes = [ + self.decodeString(item) for item in annotations["BoxesString"] + ] self.transforms = transforms def __len__(self): @@ -219,189 +217,7 @@ def decodeString(self, BoxesString): return boxes except: print(BoxesString) - print("Submission is not well formatted. empty boxes will be returned") + print( + "Submission is not well formatted. empty boxes will be returned" + ) return np.zeros((0, 4)) - - -class DicomCocoDatasetRTK(Dataset): - def __init__(self, *args, **kwargs): - """ - Args: - data_dir (str): Путь к директории с DICOM файлами и COCO аннотациями. - transform (callable, optional): Трансформации, применяемые к изображениям и маскам. - """ - data_dir = kwargs["data_dir"] - data_dir = os.path.abspath(data_dir) - assert os.path.isdir(data_dir), f"Invalid path {data_dir}" - self.transform = kwargs.get("transform", None) - - # Поиск COCO аннотаций в директории - self.dicom_paths = [] - - coco_path = None - for root, _, files in os.walk(data_dir): - - for file in files: - basename = os.path.basename(file) - filename, ext = os.path.splitext(basename) - if ext == ".json": - coco_path = os.path.join(data_dir, root, file) - elif ext in ["", ".dcm"]: - dicom_path = os.path.join(data_dir, root, file) - if pydicom.misc.is_dicom(dicom_path): - self.dicom_paths += [dicom_path] - if not coco_path: - # raise FileNotFoundError( - print( - f"COCO аннотации не найдены в директории {data_dir}." - ) - self.coco_found = False - else: - self.coco_found = True - - if not self.dicom_paths: - raise FileNotFoundError(f"Dicom не найдены в директории {data_dir}.") - - import re - - def extract_digits(s): - out = re.findall(r"\d+", s) - out = "".join(out) - return int(out) - - # Загрузка COCO аннотаций - if self.coco_found: - with open(coco_path, "r") as f: - self.coco = json.load(f) - self.categories = self.coco["categories"] - self.annotations = self.coco["annotations"] - self.num_classes = len(self.categories) - - self.images = self.coco["images"] - self.image_id_to_annotations = {image["id"]: [] for image in self.images} - for ann in self.annotations: - self.image_id_to_annotations[ann["image_id"]].append(ann) - - if len(self.images) != len(self.dicom_paths): - new_images = [] - for img in self.images: - for dicom_path in self.dicom_paths: - if dicom_path.endswith(img["file_name"]): - new_images += [img] - self.images = new_images - - - self.images.sort(key=lambda x: extract_digits(x["file_name"])) - else: - self.dicom_paths.sort() - - def __len__(self): - if self.coco_found: - return len(self.images) - else: - return len(self.dicom_paths) - - def get_dicom(self, i): - - dicom_path = self.dicom_paths[i] - dicom = pydicom.dcmread(dicom_path) - image = dicom_to_raster(dicom) - - if self.transform: - transformed = self.transform(image=image) - image = transformed["image"] - - if type(image) == torch.Tensor: - image = image.float() - - out = {"image": image, "path": dicom_path} - return out - - - def __getitem__(self, idx): - """ - - Args: - idx: - - Returns: - A dictionary with keys - "image": image - "mask": mask - "path": dicom_path - "raw_image": dicom_image - - - """ - if not self.coco_found: - return self.get_dicom(idx) - image_info = self.images[idx] - for dicom_path in self.dicom_paths: - if dicom_path.endswith(image_info["file_name"]): - break - else: - print(self.dicom_paths, image_info["file_name"]) - raise FileNotFoundError(f"Dicom {dicom_path} не найден.") - dicom = pydicom.dcmread(dicom_path) - image = dicom_to_raster(dicom) - - anns = self.image_id_to_annotations[image_info["id"]] - mask = self.get_mask(anns, image_info) - - if self.transform: - transformed = self.transform(image=image, mask=mask) - image = transformed["image"] - mask = transformed["mask"] - - raw = dicom.pixel_array - - if type(image) == torch.Tensor: - image = image.float() - shape = image.shape[1:] - add_raw = False - else: - shape = image.shape[:2] - add_raw = True - - out = {"image": image, "mask": mask, "path": dicom_path} - - if add_raw: - if raw.shape[:2] != shape: - # no need to apply all transforms - raw = cv2.resize(raw, shape) - out["raw_image"] = raw - return out - - def get_mask(self, anns, image_info): - mask = np.zeros( - (image_info["height"], image_info["width"], self.num_classes), - dtype=np.uint8, - ) - for ann in anns: - segmentation = ann["segmentation"] - category_id = ( - ann["category_id"] - 1 - ) # Приведение category_id к индексу слоя - if isinstance(segmentation, list): # полигональная аннотация - for polygon in segmentation: - poly_mask = self._polygon_to_mask( - polygon, image_info["height"], image_info["width"] - ) - mask[:, :, category_id][poly_mask > 0] = 1 - return mask - - @staticmethod - def _polygon_to_mask(polygon, height, width): - mask = np.zeros((height, width), dtype=np.uint8) - polygon = np.array(polygon).reshape(-1, 2) - mask = cv2.fillPoly(mask, [polygon.astype(int)], color=1) - return mask - - def setup_infer(self): - pass - - def infer_dataloader(self): - return self - - def predict_dataloader(self): - return self