
Commit

tests for anomaly detection module
KGallyamov committed May 19, 2024
1 parent b06c208 commit 80ff952
Showing 3 changed files with 85 additions and 11 deletions.
40 changes: 29 additions & 11 deletions tests/fixtures/config/losses.py
@@ -1,6 +1,5 @@
from omegaconf import DictConfig


jaccard_loss_w_target = DictConfig(
    {
        "name": "Segmentation",
@@ -63,25 +62,25 @@
        "description": "something",
        "task": ["text-vae", "text-vae-forward", "text-vae-reverse"],
        "implementations": {
            "torch": {
                "mse": {
                    "weight": 1.0,
                    "object": {
                        "_target_": "torch.nn.MSELoss"
                    }
                },
                "target_loss": {
                    "weight": 1.0,
                    "object": {
                        "_target_": "torch.nn.MSELoss"
                    }
                },
                "kld": {
                    "weight": 0.1,
                    "object": {
                        "_target_": "innofw.core.losses.kld.KLD"
                    }
                }
            }
        }
    }
)
@@ -93,15 +92,15 @@
        "task": ["text-ner"],
        "implementations": {
            "torch": {
                "FocalLoss": {
                    "weight": 1,
                    "object": {
                        "_target_": "innofw.core.losses.focal_loss.FocalLoss",
                        "gamma": 2
                    }
                }
            }
        }
    }
)

@@ -123,3 +122,22 @@
        },
    },
)

l1_loss_w_target = DictConfig(
    {
        "name": "L1 loss",
        "description": "L1 loss",
        "task": ["anomaly-detection-timeseries"],
        "implementations": {
            "torch": {
                "L1Loss": {
                    "weight": 1,
                    "reduction": "sum",
                    "object": {
                        "_target_": "torch.nn.L1Loss"
                    }
                }
            }
        }
    }
)
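
These loss fixtures use the "_target_" convention that Hydra uses for object construction: get_losses presumably resolves each "object" entry into a loss instance and applies the per-loss "weight". As a rough sketch of how one of these entries could be turned into callables (an illustration under that assumption, not innofw's actual get_losses code; resolve_torch_losses is a hypothetical helper):

from hydra.utils import instantiate
from omegaconf import DictConfig


def resolve_torch_losses(loss_cfg: DictConfig):
    # Illustrative only: turn a loss fixture into (name, weight, callable) tuples
    resolved = []
    for name, spec in loss_cfg.implementations.torch.items():
        loss_fn = instantiate(spec.object)  # builds e.g. torch.nn.L1Loss from "_target_"
        resolved.append((name, spec.get("weight", 1.0), loss_fn))
    return resolved

Note that in l1_loss_w_target the "reduction": "sum" key sits beside "weight" rather than inside "object", so whether it reaches the torch.nn.L1Loss constructor depends on how innofw interprets that level of the config.
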
10 changes: 10 additions & 0 deletions tests/fixtures/config/models.py
@@ -148,4 +148,14 @@
            "pretrained_model_name_or_path": "dmis-lab/biobert-base-cased-v1.2"
        }
    }
)

lstm_autoencoder_w_target = DictConfig(
    {
        "name": "lstm autoencoder",
        "description": "lstm autoencoder",
        "_target_": "innofw.core.models.torch.architectures.autoencoders.timeseries_lstm.RecurrentAutoencoder",
        "seq_len": 140,
        "n_features": 1
    }
)
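
Here seq_len=140 and n_features=1 describe single-channel time series windows of 140 steps, which is the input shape an LSTM encoder/decoder reconstructs. A minimal sketch of what a recurrent autoencoder with this constructor signature might look like (an assumption for illustration only; the actual RecurrentAutoencoder referenced by the "_target_" path may differ in depth, hidden sizes and other details):

import torch
import torch.nn as nn


class ToyRecurrentAutoencoder(nn.Module):
    # Illustrative stand-in, not the innofw implementation
    def __init__(self, seq_len: int = 140, n_features: int = 1, embedding_dim: int = 64):
        super().__init__()
        self.seq_len = seq_len
        # Encoder compresses the whole window into a single hidden state
        self.encoder = nn.LSTM(n_features, embedding_dim, batch_first=True)
        # Decoder unrolls that state back over seq_len steps
        self.decoder = nn.LSTM(embedding_dim, embedding_dim, batch_first=True)
        self.output_layer = nn.Linear(embedding_dim, n_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, n_features)
        _, (hidden, _) = self.encoder(x)
        repeated = hidden[-1].unsqueeze(1).repeat(1, self.seq_len, 1)
        decoded, _ = self.decoder(repeated)
        return self.output_layer(decoded)  # reconstruction, same shape as x

In autoencoder-based anomaly detection, the reconstruction error (here the L1 loss defined in the losses fixture) typically serves as the anomaly score.
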
@@ -0,0 +1,46 @@
from omegaconf import DictConfig

from innofw.constants import Frameworks, Stages
from innofw.core.datamodules.lightning_datamodules.anomaly_detection_timeseries_dm import \
    TimeSeriesLightningDataModule
from innofw.core.models.torch.lightning_modules import (
    AnomalyDetectionTimeSeriesLightningModule,
)
from innofw.utils.framework import get_datamodule
from tests.fixtures.config.datasets import anomaly_detection_timeseries_datamodule_cfg_w_target
from innofw.utils.framework import get_losses
from innofw.utils.framework import get_model
from tests.fixtures.config import losses as fixt_losses
from tests.fixtures.config import models as fixt_models
from tests.fixtures.config import optimizers as fixt_optimizers
from tests.fixtures.config import trainers as fixt_trainers


def test_anomaly_detection():
    cfg = DictConfig(
        {
            "models": fixt_models.lstm_autoencoder_w_target,
            "trainer": fixt_trainers.trainer_cfg_w_cpu_devices,
            "losses": fixt_losses.l1_loss_w_target,
        }
    )
    # Build the model and the weighted loss set from the fixture configs
    model = get_model(cfg.models, cfg.trainer)
    losses = get_losses(cfg, "anomaly-detection-timeseries", Frameworks.torch)
    optimizer_cfg = DictConfig(fixt_optimizers.adam_optim_w_target)

    module = AnomalyDetectionTimeSeriesLightningModule(
        model=model, losses=losses, optimizer_cfg=optimizer_cfg
    )

    assert module is not None

    # Set up the time-series datamodule and exercise a train and a val step
    # on a batch drawn from the train dataloader
    datamodule: TimeSeriesLightningDataModule = get_datamodule(
        anomaly_detection_timeseries_datamodule_cfg_w_target,
        Frameworks.torch,
        task="anomaly-detection-timeseries",
    )
    datamodule.setup(Stages.train)

    for stage in ["train", "val"]:
        module.stage_step(
            stage, next(iter(datamodule.train_dataloader())), do_logging=True
        )
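
Assuming standard pytest discovery for this new test module, the smoke test can be selected with pytest -k test_anomaly_detection; the trainer_cfg_w_cpu_devices fixture suggests it is intended to run on CPU. Both loop iterations feed a batch from train_dataloader(), so the "val" stage is exercised without requiring a separate validation split.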
