Merge remote-tracking branch 'origin/test_coverage' into test_coverage
BarzaH committed May 19, 2024
2 parents 56f85b0 + 5fe06d3 commit b06c208
Showing 3 changed files with 112 additions and 7 deletions.
19 changes: 19 additions & 0 deletions tests/fixtures/config/losses.py
@@ -20,6 +20,25 @@
},
)

multiclass_jaccard_loss_w_target = DictConfig(
{
"name": "Segmentation",
"description": "something",
"task": ["multiclass-image-segmentation"],
"implementations": {
"torch": {
"JaccardLoss": {
"weight": 0.5,
"object": {
"_target_": "pytorch_toolbelt.losses.JaccardLoss",
"mode": "multiclass",
},
},
}
},
},
)
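
For reference, the nested `object` entry follows the standard Hydra target convention, so it can be resolved directly. A minimal sketch, assuming `pytorch_toolbelt` and `hydra-core` are available (how innofw's `get_losses` applies the sibling `weight` key is not shown in this diff):

    from hydra.utils import instantiate

    # Resolve the `_target_` node into a concrete loss object
    loss_obj = instantiate(
        multiclass_jaccard_loss_w_target.implementations.torch.JaccardLoss.object
    )
    # loss_obj is a pytorch_toolbelt.losses.JaccardLoss with mode="multiclass"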

soft_ce_loss_w_target = DictConfig(
{
"name": "Classification",
23 changes: 16 additions & 7 deletions tests/fixtures/config/models.py
@@ -75,6 +75,15 @@
}
)

deeplabv3_plus_w_target_multiclass = DictConfig(
{
"name": "deeplabv3plus",
"description": "something",
"_target_": "segmentation_models_pytorch.DeepLabV3Plus",
"classes": 4,
}
)
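
A quick sanity sketch of what this config resolves to, assuming `get_model` ultimately instantiates the `_target_` after stripping the metadata keys (`name`, `description`); `encoder_weights=None` is used here only to avoid downloading pretrained weights:

    import torch
    from segmentation_models_pytorch import DeepLabV3Plus

    model = DeepLabV3Plus(classes=4, encoder_weights=None)
    logits = model(torch.rand(2, 3, 224, 224))
    assert logits.shape == (2, 4, 224, 224)  # one logit channel per class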

catboost_cfg_w_target = DictConfig(
{
"name": "catboost",
@@ -110,18 +119,18 @@
"_target_": "innofw.core.models.torch.architectures.autoencoders.vae.VAE",
"encoder": {
"_target_": "innofw.core.models.torch.architectures.autoencoders.vae.Encoder",
"in_dim": 609, # len(alphabet) * max(len_mols)
"in_dim": 609, # len(alphabet) * max(len_mols)
"hidden_dim": 128,
"enc_out_dim": 128,
},
},
"decoder": {
"_target_": "innofw.core.models.torch.architectures.autoencoders.vae.GRUDecoder",
"latent_dimension": 128,
"gru_stack_size": 3,
"gru_neurons_num": 128,
"out_dimension": 29, # len(alphabet)
}

}
)

@@ -130,13 +139,13 @@
"name": "biobert-ner",
"description": "bert for token classification biobert-base-cased-v1.2",
"_target_": "innofw.core.models.torch.architectures.token_classification.biobert_ner.BiobertNer",
"model":{
"model": {
"_target_": "transformers.BertForTokenClassification.from_pretrained",
"pretrained_model_name_or_path": "dmis-lab/biobert-base-cased-v1.2"
},
"tokenizer":{
},
"tokenizer": {
"_target_": "transformers.BertTokenizerFast.from_pretrained",
"pretrained_model_name_or_path": "dmis-lab/biobert-base-cased-v1.2"
}
}
}
)
@@ -0,0 +1,77 @@
import torch
from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from innofw.constants import Frameworks
from innofw.constants import SegDataKeys, SegOutKeys
from innofw.core.models.torch.lightning_modules.segmentation import (
MulticlassSemanticSegmentationLightningModule
)
from innofw.utils.framework import get_losses
from innofw.utils.framework import get_model
from tests.fixtures.config import losses as fixt_losses
from tests.fixtures.config import models as fixt_models
from tests.fixtures.config import optimizers as fixt_optimizers
from tests.fixtures.config import trainers as fixt_trainers


class MultiSegDummyDataset(Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples

    def __getitem__(self, index):
        x = torch.rand(3, 224, 224)  # fake RGB image
        y = torch.randint(0, 4, (1, 224, 224))  # per-pixel labels for 4 classes
        return {SegDataKeys.image: x, SegDataKeys.label: y}

def __len__(self):
return self.num_samples


class MultiSegDummyDataModule(LightningDataModule):
def __init__(self, num_samples: int, batch_size: int = 4):
super().__init__()
self.num_samples = num_samples
self.batch_size = batch_size

def setup(self, stage=None):
self.dataset = MultiSegDummyDataset(self.num_samples)

def train_dataloader(self):
return DataLoader(self.dataset, batch_size=self.batch_size)

def val_dataloader(self):
return DataLoader(self.dataset, batch_size=self.batch_size)

def test_dataloader(self):
return DataLoader(self.dataset, batch_size=self.batch_size)
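
Each batch served by this datamodule is a dict with `SegDataKeys.image` of shape `(4, 3, 224, 224)` and `SegDataKeys.label` of shape `(4, 1, 224, 224)`, since `batch_size` defaults to 4; the shape assertion in the test below relies on that batch size.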


def test_multiclasssegmentation_module() -> None:
cfg = DictConfig(
{
"models": fixt_models.deeplabv3_plus_w_target_multiclass,
"trainer": fixt_trainers.trainer_cfg_w_cpu_devices,
"losses": fixt_losses.multiclass_jaccard_loss_w_target,
}
)
model = get_model(cfg.models, cfg.trainer)
losses = get_losses(cfg, "multiclass-image-segmentation", Frameworks.torch)
optimizer_cfg = DictConfig(fixt_optimizers.adam_optim_w_target)

module = MulticlassSemanticSegmentationLightningModule(
model=model, losses=losses, optimizer_cfg=optimizer_cfg
)

assert module is not None

datamodule = MultiSegDummyDataModule(num_samples=8)
datamodule.setup()

    for stage in ["train", "val"]:
        # the dummy datamodule serves identical data for every stage, so the
        # train dataloader is reused for the validation step as well
        batch = next(iter(datamodule.train_dataloader()))
        output = module.stage_step(stage, batch, do_logging=True)
        assert output[SegOutKeys.predictions].shape == (4, 4, 224, 224)
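
For context, a sketch of the tensor contract the final assertion implies, assuming the module forwards raw logits to the configured Jaccard loss (`pytorch_toolbelt`'s multiclass mode expects integer targets without a channel dimension):

    import torch
    from pytorch_toolbelt.losses import JaccardLoss

    logits = torch.rand(4, 4, 224, 224)             # (batch, num_classes, H, W)
    target = torch.randint(0, 4, (4, 1, 224, 224))  # as produced by MultiSegDummyDataset
    loss = JaccardLoss(mode="multiclass")(logits, target.squeeze(1))  # targets (N, H, W)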
