From d0af06e44ab3dbbbb6d3625cb191f2a69d34593e Mon Sep 17 00:00:00 2001
From: "i.beskrovnyy"
Date: Tue, 30 Jan 2024 12:06:43 +0300
Subject: [PATCH] black ew

---
 src/core/datasets.py     |  2 +-
 src/core/train.py        | 24 +++++++++++++++++-------
 src/utils/train_utils.py |  8 ++++++--
 3 files changed, 24 insertions(+), 10 deletions(-)

diff --git a/src/core/datasets.py b/src/core/datasets.py
index f7e1e5c..77b1b1e 100644
--- a/src/core/datasets.py
+++ b/src/core/datasets.py
@@ -73,7 +73,7 @@ def _to_memory(self):
         n_bufs = int(len(idx) / buffer_size)
         idx = (
             idx[: buffer_size * n_bufs].reshape(-1, buffer_size).tolist()
-            + idx[buffer_size * n_bufs:].reshape(1, -1).tolist()
+            + idx[buffer_size * n_bufs :].reshape(1, -1).tolist()
         )
         with multiprocessing.Pool(self.to_memory_workers) as pool:
             mem_list = []
diff --git a/src/core/train.py b/src/core/train.py
index 55dcd44..ec94142 100644
--- a/src/core/train.py
+++ b/src/core/train.py
@@ -41,7 +41,9 @@ def train_dim(args):
 
     # Optimizer -------------------------------------------------------------
     opt = optim.Adam(model.parameters(), lr=args["tr_lr"])
-    scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, "min", verbose=True, threshold=0.003, patience=args["tr_lr_patience"])
+    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
+        opt, "min", verbose=True, threshold=0.003, patience=args["tr_lr_patience"]
+    )
     early_stop = EarlyStopperDim(args["tr_early_stop"])
 
     biasLosses = []
@@ -94,8 +96,12 @@ def train_dim(args):
         )
 
         directions = 2 if args["td_lstm_bidirectional"] else 1
-        h0 = torch.zeros(args["td_lstm_num_layers"] * directions * torch.cuda.device_count(), args["tr_bs"], args["td_lstm_h"])
-        c0 = torch.zeros(args["td_lstm_num_layers"] * directions * torch.cuda.device_count(), args["tr_bs"], args["td_lstm_h"])
+        h0 = torch.zeros(
+            args["td_lstm_num_layers"] * directions * torch.cuda.device_count(), args["tr_bs"], args["td_lstm_h"]
+        )
+        c0 = torch.zeros(
+            args["td_lstm_num_layers"] * directions * torch.cuda.device_count(), args["tr_bs"], args["td_lstm_h"]
+        )
         for xb_spec, yb_mos, (idx, n_wins) in dl_train:
             # Estimate batch ---------------------------------------------------
             xb_spec = xb_spec.to(dev)
@@ -294,7 +300,7 @@ def train_dim(args):
 
         # Scheduler update ---------------------------------------------
         scheduler.step(loss)
-        earl_stp = early_stop.step(r)
+        earl_stop = early_stop.step(r)
 
         # Print --------------------------------------------------------
         ep_runtime = time.time() - tic_epoch
@@ -335,12 +341,16 @@ def train_dim(args):
             )
 
         # Early stopping -----------------------------------------------
-        if earl_stp:
+        if earl_stop:
             logger.info(
-                "--> Early stopping. best_r_p {:0.2f} best_rmse {:0.2f}".format(early_stop.best_r_p, early_stop.best_rmse)
+                "--> Early stopping. best_r_p {:0.2f} best_rmse {:0.2f}".format(
+                    early_stop.best_r_p, early_stop.best_rmse
+                )
             )
             return
 
     # Training done --------------------------------------------------------
-    logger.info("--> Training done. best_r_p {:0.2f} best_rmse {:0.2f}".format(early_stop.best_r_p, early_stop.best_rmse))
+    logger.info("--> Training done. best_r_p {:0.2f} best_rmse {:0.2f}".format(
+        early_stop.best_r_p, early_stop.best_rmse)
+    )
     return
diff --git a/src/utils/train_utils.py b/src/utils/train_utils.py
index a366e00..cd35352 100644
--- a/src/utils/train_utils.py
+++ b/src/utils/train_utils.py
@@ -108,7 +108,9 @@ def get_loss(self, yb, yb_hat, idx):
         b = torch.tensor(self.b, dtype=torch.float).to(yb_hat.device)
         b = b[idx, :]
 
-        yb_hat_map = (b[:, 0] + b[:, 1] * yb_hat[:, 0] + b[:, 2] * yb_hat[:, 0] ** 2 + b[:, 3] * yb_hat[:, 0] ** 3).view(-1, 1)
+        yb_hat_map = (
+            b[:, 0] + b[:, 1] * yb_hat[:, 0] + b[:, 2] * yb_hat[:, 0] ** 2 + b[:, 3] * yb_hat[:, 0] ** 3
+        ).view(-1, 1)
 
         loss_bias = self._nan_mse(yb_hat_map, yb)
         loss_normal = self._nan_mse(yb_hat, yb)
@@ -282,7 +284,9 @@ def eval_results(
             )
         )
     else:
-        logger.info("%-30s r_p_file: %0.2f, rmse_map_file: %0.2f" % (db_name + ":", r["r_p_file"], r["rmse_map_file"]))
+        logger.info(
+            "%-30s r_p_file: %0.2f, rmse_map_file: %0.2f" % (db_name + ":", r["r_p_file"], r["rmse_map_file"])
+        )
 
     # Save individual database results in DataFrame
     db_results_df = pd.DataFrame(db_results_df)
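
Note on the earl_stp -> earl_stop rename in src/core/train.py: early_stop.step(r) returns a stop flag, while the EarlyStopperDim instance itself is still needed afterwards for early_stop.best_r_p and early_stop.best_rmse, so the flag has to land in a separate variable rather than overwrite the object. Below is a minimal, self-contained sketch of that pattern; EarlyStopperDim's real internals are not part of this patch, so the stand-in class, its patience logic, and the sample metric dicts are illustrative assumptions that only mirror the interface used here (step(), best_r_p, best_rmse).

class EarlyStopperDimSketch:
    """Stand-in with the same interface as EarlyStopperDim used in train_dim()."""

    def __init__(self, patience):
        self.patience = patience
        self.counter = 0
        self.best_r_p = float("-inf")
        self.best_rmse = float("inf")

    def step(self, r):
        # Update best metrics; return True when training should stop.
        if r["r_p"] > self.best_r_p:
            self.best_r_p = r["r_p"]
            self.best_rmse = r["rmse"]
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


early_stop = EarlyStopperDimSketch(patience=3)

# Fabricated per-epoch validation results, only to exercise the sketch.
for r in [{"r_p": 0.80, "rmse": 0.60}, {"r_p": 0.79, "rmse": 0.62},
          {"r_p": 0.78, "rmse": 0.63}, {"r_p": 0.77, "rmse": 0.65}]:
    # The flag goes into its own variable (earl_stop) so the stopper object,
    # and with it best_r_p / best_rmse, stays available for the log message.
    earl_stop = early_stop.step(r)
    if earl_stop:
        print("--> Early stopping. best_r_p {:0.2f} best_rmse {:0.2f}".format(
            early_stop.best_r_p, early_stop.best_rmse))
        break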