From 80ff9528c8cca80e9aaf169c3bcdd7c777736e9e Mon Sep 17 00:00:00 2001
From: Karim Galliamov <45235968+KGallyamov@users.noreply.github.com>
Date: Sun, 19 May 2024 15:35:19 +0300
Subject: [PATCH] tests for anomaly detection module

---
 tests/fixtures/config/losses.py               | 40 +++++++++++-----
 tests/fixtures/config/models.py               | 10 ++++
 .../test_anomaly_detection.py                 | 46 +++++++++++++++++++
 3 files changed, 85 insertions(+), 11 deletions(-)
 create mode 100644 tests/integration/models/torch/lighting_modules/test_anomaly_detection.py

diff --git a/tests/fixtures/config/losses.py b/tests/fixtures/config/losses.py
index 23cc088e..32a147ed 100644
--- a/tests/fixtures/config/losses.py
+++ b/tests/fixtures/config/losses.py
@@ -1,6 +1,5 @@
 from omegaconf import DictConfig
 
-
 jaccard_loss_w_target = DictConfig(
     {
         "name": "Segmentation",
@@ -63,25 +62,25 @@
         "description": "something",
         "task": ["text-vae", "text-vae-forward", "text-vae-reverse"],
         "implementations": {
-            "torch":{
-                "mse":{
+            "torch": {
+                "mse": {
                     "weight": 1.0,
-                    "object":{
+                    "object": {
                         "_target_": "torch.nn.MSELoss"}
-                 },
+                },
                 "target_loss": {
                     "weight": 1.0,
                     "object": {
                         "_target_": "torch.nn.MSELoss"
-                        }
+                    }
                 },
                 "kld": {
                     "weight": 0.1,
                     "object": {
                         "_target_": "innofw.core.losses.kld.KLD"
-                        }
-                        }
                     }
+                }
+            }
         }
     }
 )
@@ -93,15 +92,15 @@
         "task": ["text-ner"],
         "implementations": {
             "torch": {
-                "FocalLoss":{
+                "FocalLoss": {
                     "weight": 1,
                     "object": {
                         "_target_": "innofw.core.losses.focal_loss.FocalLoss",
                         "gamma": 2
                     }
                 }
-                }
-                }
+            }
+        }
     }
 )
 
@@ -123,3 +122,22 @@
         },
     },
 )
+
+l1_loss_w_target = DictConfig(
+    {
+        "name": "L1 loss",
+        "description": "L1 loss",
+        "task": ["anomaly-detection-timeseries"],
+        "implementations": {
+            "torch": {
+                "L1Loss": {
+                    "weight": 1,
+                    "reduction": "sum",
+                    "object": {
+                        "_target_": "torch.nn.L1Loss"
+                    }
+                }
+            }
+        }
+    }
+)
diff --git a/tests/fixtures/config/models.py b/tests/fixtures/config/models.py
index 8e2e0d12..932482f4 100644
--- a/tests/fixtures/config/models.py
+++ b/tests/fixtures/config/models.py
@@ -148,4 +148,14 @@
             "pretrained_model_name_or_path": "dmis-lab/biobert-base-cased-v1.2"
         }
     }
+)
+
+lstm_autoencoder_w_target = DictConfig(
+    {
+        "name": "lstm autoencoder",
+        "description": "lstm autoencoder",
+        "_target_": "innofw.core.models.torch.architectures.autoencoders.timeseries_lstm.RecurrentAutoencoder",
+        "seq_len": 140,
+        "n_features": 1
+    }
 )
\ No newline at end of file
diff --git a/tests/integration/models/torch/lighting_modules/test_anomaly_detection.py b/tests/integration/models/torch/lighting_modules/test_anomaly_detection.py
new file mode 100644
index 00000000..e41226e8
--- /dev/null
+++ b/tests/integration/models/torch/lighting_modules/test_anomaly_detection.py
@@ -0,0 +1,46 @@
+from omegaconf import DictConfig
+
+from innofw.constants import Frameworks, Stages
+from innofw.core.datamodules.lightning_datamodules.anomaly_detection_timeseries_dm import \
+    TimeSeriesLightningDataModule
+from innofw.core.models.torch.lightning_modules import (
+    AnomalyDetectionTimeSeriesLightningModule
+)
+from innofw.utils.framework import get_datamodule
+from tests.fixtures.config.datasets import anomaly_detection_timeseries_datamodule_cfg_w_target
+from innofw.utils.framework import get_losses
+from innofw.utils.framework import get_model
+from tests.fixtures.config import losses as fixt_losses
+from tests.fixtures.config import models as fixt_models
+from tests.fixtures.config import optimizers as fixt_optimizers
+from tests.fixtures.config import trainers as fixt_trainers
+
+
+def test_anomaly_detection():
+    cfg = DictConfig(
+        {
+            "models": fixt_models.lstm_autoencoder_w_target,
+            "trainer": fixt_trainers.trainer_cfg_w_cpu_devices,
+            "losses": fixt_losses.l1_loss_w_target,
+        }
+    )
+    model = get_model(cfg.models, cfg.trainer)
+    losses = get_losses(cfg, "anomaly-detection-timeseries", Frameworks.torch)
+    optimizer_cfg = DictConfig(fixt_optimizers.adam_optim_w_target)
+
+    module = AnomalyDetectionTimeSeriesLightningModule(
+        model=model, losses=losses, optimizer_cfg=optimizer_cfg
+    )
+
+    assert module is not None
+
+    datamodule: TimeSeriesLightningDataModule = get_datamodule(
+        anomaly_detection_timeseries_datamodule_cfg_w_target,
+        Frameworks.torch,
+        task="anomaly-detection-timeseries"
+    )
+    datamodule.setup(Stages.train)
+
+    for stage in ["train", "val"]:
+        module.stage_step(stage, next(iter(datamodule.train_dataloader())),
+                          do_logging=True)