Append ECG neural network tuning example (#199)
* Added a paragraph to the documentation describing how the framework tunes two real and one discrete parameter, and fixed the example problem code for finding real and discrete parameters. * Corrected the target score. * Added new examples. * Corrected the documentation of the examples.
Showing 9 changed files with 533 additions and 0 deletions.
145 changes: 145 additions & 0 deletions
examples/Machine_learning/NeuralNetwork/Segmentation/Problem/Cardio2D.py
import random

from examples.Machine_learning.NeuralNetwork.Segmentation.scripts.dataset import SegmentationDataset
from examples.Machine_learning.NeuralNetwork.Segmentation.scripts.metric import AllMetricTracker
from iOpt.trial import Point
from iOpt.trial import FunctionValue
from iOpt.problem import Problem
from typing import Dict
from datetime import datetime
import os
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
import torch
import torch.nn as nn
import numpy as np
from lightning.pytorch import LightningModule
from examples.Machine_learning.NeuralNetwork.Segmentation.scripts.metric import SegmentationMetric
from examples.Machine_learning.NeuralNetwork.Segmentation.scripts.model import Encoder, Decoder, UNet

# LightningModule wrapping the U-Net ECG segmentation model; p and q are the
# hyperparameters tuned by the Cardio2D problem below.
class UnetModule(LightningModule):
    def __init__(self, kernel_size=23, q=1.2, label_smoothing=0, p=0.75):
        super().__init__()
        self.save_hyperparameters()
        encoder = Encoder(12, kernel_size=kernel_size, q=q, p=p)
        decoder = Decoder(encoder, 4)

        self.model = UNet(encoder, decoder)
        self.loss = nn.CrossEntropyLoss(ignore_index=4, label_smoothing=label_smoothing)

        self.p_metric = SegmentationMetric('p', 'all', return_type='f1', samples=150)
        self.t_metric = SegmentationMetric('t', 'all', return_type='f1', samples=150)
        self.qrs_metric = SegmentationMetric('qrs', 'all', return_type='f1', samples=150)

    def predict(self, x):
        if isinstance(x, np.ndarray):
            x = torch.Tensor(x)
        x = x.unsqueeze(0) if len(x.shape) == 2 else x
        x = x.to(self.device)
        logits = self.model(x)
        y_pred = logits.argmax(axis=1)
        return y_pred.cpu().detach().numpy()

    def training_step(self, batch):
        _, x, y = batch
        logits = self.model(x)
        loss = self.loss(logits, y)
        dict_ = {'train_loss': loss}
        self.log_dict(dict_, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch):
        _, x, y = batch
        logits = self.model(x)
        loss = self.loss(logits, y)
        dict_ = {'val_loss': loss}

        metrics = self.get_metric(x, y, 'val')
        dict_.update(metrics)

        self.log_dict(dict_, on_epoch=True, on_step=False)

        return loss

    def get_metric(self, x, y_true, prefix):
        y_true = y_true.cpu().detach().numpy()
        y_pred = self.predict(x)
        p_f1_score = self.p_metric(y_pred, y_true)
        qrs_f1_score = self.qrs_metric(y_pred, y_true)
        t_f1_score = self.t_metric(y_pred, y_true)
        metric_dict = {f'{prefix}_p_wave': p_f1_score, f'{prefix}_qrs_wave': qrs_f1_score, f'{prefix}_t_wave': t_f1_score}
        return metric_dict

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.model.parameters())
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.3, patience=50)
        return [optimizer], [{"scheduler": scheduler,
                              "interval": "epoch",
                              "monitor": "train_loss"}]


def get_dataset(paths):
    # Load raw signals and the corresponding segmentation masks by file name.
    return [np.load(f'data/signals/{x}') for x in paths], \
           [np.load(f'data/masks/{x}') for x in paths]


class Cardio2D(Problem):
    def __init__(self, p_bound: Dict[str, float], q_bound: Dict[str, float]):
        super(Cardio2D, self).__init__()
        self.dimension = 2
        self.number_of_float_variables = 2
        self.number_of_discrete_variables = 0
        self.number_of_objectives = 1
        self.number_of_constraints = 0

        ecg_list = sorted(os.listdir('data/signals/'))
        ecg_list = [x for x in ecg_list if x.split('_')[-1] != 'unsupervised.npy']

        train_list, test_list = train_test_split(ecg_list, test_size=0.2, shuffle=True, random_state=42)

        # Recordings marked 'unsupervised' are kept out of the split and always added to the training set.
        for x in sorted(os.listdir('data/signals/')):
            if x.split('_')[-1] == 'unsupervised.npy':
                train_list.append(x)

        x_train, y_train = get_dataset(train_list)
        x_test, y_test = get_dataset(test_list)

        train_dataset = SegmentationDataset('cpu', train_list, x_train, y_train, common_mask=True, for_train=True)
        val_dataset = SegmentationDataset('cpu', test_list, x_test, y_test, common_mask=True)

        self.train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
        self.val_loader = DataLoader(val_dataset, batch_size=32)

        self.float_variable_names = np.array(["P parameter", "Q parameter"], dtype=str)
        self.lower_bound_of_float_variables = np.array([p_bound['low'], q_bound['low']], dtype=np.double)
        self.upper_bound_of_float_variables = np.array([p_bound['up'], q_bound['up']], dtype=np.double)

    def calculate(self, point: Point, function_value: FunctionValue) -> FunctionValue:
        p, q = point.float_variables[0], point.float_variables[1]

        now = datetime.now().strftime('%d.%m.%Y_%H:%M:%S')

        checkpoint = ModelCheckpoint(dirpath='models/',
                                     filename=f"{random.uniform(1, 100):.9f}" + " " + f"{p:.9f}" + '_' + f"{q:.9f}" + '_' + '{epoch}_{val_p_wave:.6f}_{val_qrs_wave:.6f}_{val_t_wave:.6f}',
                                     monitor='val_p_wave',
                                     save_top_k=3,
                                     mode='max')
        early_stopping = EarlyStopping(monitor='val_loss',
                                       patience=300)

        cb = AllMetricTracker()
        model = UnetModule(p=p, q=q)
        trainer = Trainer(max_epochs=1_000_000, callbacks=[checkpoint, early_stopping, cb])
        try:
            trainer.fit(model, self.train_loader, self.val_loader)
        except Exception as err:
            print(f"Unexpected {err=}, {type(err)=}")

        print('p ' + f"{p:.9f}")
        print('q ' + f"{q:.9f}")
        # iOpt minimizes the objective, so return the negated best validation P-wave score.
        function_value.value = -cb.best_p_valscore
        print(-cb.best_p_valscore)
        return function_value
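
For orientation, here is a minimal standalone-inference sketch using the UnetModule defined above. It assumes the Encoder/Decoder/UNet definitions from the scripts/model.py added in this commit are importable, and the random array merely stands in for a real 12-lead ECG recording in the (12, length) layout used throughout this example:

import numpy as np

# Illustrative only: random data in place of a real 12-lead ECG recording.
model = UnetModule(p=0.75, q=1.2)              # default kernel_size and label_smoothing
ecg = np.random.randn(12, 4000).astype(np.float32)
labels = model.predict(ecg)                    # one row of per-sample class indices per recording
print(labels.shape)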
61 changes: 61 additions & 0 deletions
examples/Machine_learning/NeuralNetwork/Segmentation/UnetExample.py
import shutil

import numpy as np
from examples.Machine_learning.NeuralNetwork.Segmentation.Problem.Cardio2D import Cardio2D
from iOpt.output_system.listeners.console_outputers import ConsoleOutputListener
from iOpt.solver import Solver
from iOpt.solver_parametrs import SolverParameters
import hashlib
import os
from pathlib import Path

import requests
from tqdm import tqdm


def _get_hash(path: Path) -> str:
    # Compute the SHA-256 digest of a file in 8 KiB chunks.
    file_hash = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(8192):
            file_hash.update(chunk)
    return file_hash.hexdigest()

def download(path: Path, public_key: str) -> None:
    # Resolve the direct download link for a public Yandex Disk file; skip the
    # download if a size- and hash-verified copy already exists, otherwise stream it to disk.
    url = "https://cloud-api.yandex.net/v1/disk/public/resources"
    params = {"public_key": f"https://disk.yandex.ru/d/{public_key}"}

    response = requests.get(url, params=params).json()
    download_url = response["file"]
    file_size = response["size"]
    sha256 = response["sha256"]

    response = requests.get(download_url, stream=True)

    if path.is_file() and os.path.getsize(path) == file_size:
        print(f"File already downloaded: {path}")
        if _get_hash(path) == sha256:
            return

    with tqdm(total=file_size, unit="B", unit_scale=True) as progress_bar:
        with open(path, "wb") as f:
            for data in response.iter_content(1024):
                progress_bar.update(len(data))
                f.write(data)

if __name__ == "__main__":
    # Fetch and unpack the ECG dataset on first run, then tune p and q with iOpt.
    if not os.path.exists('data'):
        path = Path('data.zip')
        download(path, 'Oqxcid6uX58kYQ')
        shutil.unpack_archive('data.zip', 'data', format="zip")
        os.remove('data.zip')

    p_value_bound = {'low': 0.0, 'up': 1.0}
    q_value_bound = {'low': 1.0, 'up': 1.6}
    problem = Cardio2D(p_value_bound, q_value_bound)
    method_params = SolverParameters(r=np.double(3.0), iters_limit=10)
    solver = Solver(problem, parameters=method_params)
    cfol = ConsoleOutputListener(mode='full')
    solver.add_listener(cfol)
    solver_info = solver.solve()
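
The script caps the search at iters_limit=10 to keep the example quick; a longer run only needs different SolverParameters. A sketch under the same problem definition, with purely illustrative values:

# Sketch: same Cardio2D problem, larger optimization budget (illustrative values).
method_params = SolverParameters(r=np.double(3.0), iters_limit=100)
solver = Solver(problem, parameters=method_params)
solver.add_listener(ConsoleOutputListener(mode='full'))
solver_info = solver.solve()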
Empty file.
Empty file.
73 changes: 73 additions & 0 deletions
examples/Machine_learning/NeuralNetwork/Segmentation/scripts/dataset.py
import torch
import numpy as np


class SegmentationDataset(torch.utils.data.Dataset):
    def __init__(self, device, paths, signals, masks=None, common_mask=False, for_train=False):

        self._device = device
        self._paths = paths
        self._signals = [torch.Tensor(x).to(device) for x in signals]
        # Masks are optional: unlabeled signals can be served without them.
        self._masks = [torch.LongTensor(x).to(device) for x in masks] if masks is not None else None

        self.begin_noise, self.end_noise = 1e-3, 3e-3
        self.begin_ampl, self.end_ampl = 0, 0.3

        self.begin_freq, self.end_freq = 0, 0.009

        self.prob_isoline = 0.7
        self.prob_reverse = 0.5
        self.sub_len = 4000

        self.common_mask = common_mask
        self.for_train = for_train

    def reverse_ecg(self, signal):
        # Randomly flip the sign of each lead with probability prob_reverse.
        result = torch.zeros_like(signal, device=self._device)
        for i, x in enumerate(signal):
            sign = 2 * (np.random.rand() < self.prob_reverse) - 1
            result[i] = sign * x
        return result

    def __len__(self):
        return len(self._signals)

    def __getitem__(self, i):
        if not self.for_train:
            return self._paths[i], self._signals[i], self.skip_borders(self._masks[i][0])

        # Training: take a random crop of sub_len samples, add Gaussian noise and random lead inversion.
        shift = np.random.randint(0, 5000 - self.sub_len - 1)
        noise = self.begin_noise + (self.end_noise - self.begin_noise) * np.random.rand()
        signal = self._signals[i][:, shift:shift + self.sub_len] + torch.normal(0, noise,
                                                                                size=(self.sub_len,),
                                                                                device=self._device)

        signal = self.reverse_ecg(signal)

        if self._masks is None:
            return self._paths[i], signal

        mask = self._masks[i][:, shift: shift + self.sub_len]
        indexes = torch.randperm(12, device=self._device)

        if self.common_mask:
            mask = mask[0]
        else:
            mask = mask[indexes]

        return self._paths[i], signal[indexes], self.skip_borders(mask)

    def skip_borders(self, mask):
        # Relabel the border regions (at least 500 samples on each side, extended to the
        # nearest full wave boundary) with class 4, which CrossEntropyLoss ignores (ignore_index=4).
        wave_start = torch.logical_and(torch.roll(mask, 1) == 0, mask != 0).type(torch.uint8)
        wave_finish = torch.logical_and(torch.roll(mask, -1) == 0, mask != 0).type(torch.uint8)

        indexes_starts, = torch.where(wave_start == 1)
        indexes_finish, = torch.where(wave_finish == 1)

        left_skip = indexes_starts[indexes_starts > 500][0]
        right_skip = indexes_finish[indexes_finish < len(mask) - 500][-1]

        mask_copy = torch.clone(mask)
        mask_copy[:left_skip] = 4
        mask_copy[right_skip:] = 4

        return mask_copy
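
A minimal sketch of how this dataset class is typically wired into a DataLoader, mirroring the setup in Cardio2D.__init__ above; train_list, x_train and y_train are assumed to be the file names, signal arrays and mask arrays produced by get_dataset:

from torch.utils.data import DataLoader

# Assumes x_train/y_train are lists of (12, length) signal and mask arrays
# loaded from data/signals/ and data/masks/, as in Cardio2D.get_dataset.
train_dataset = SegmentationDataset('cpu', train_list, x_train, y_train,
                                    common_mask=True, for_train=True)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
path, signal, mask = train_dataset[0]          # random sub_len-sample crop with augmentation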