Commit

black ew
i.beskrovnyy committed Jan 30, 2024
1 parent 67c0a85 commit d0af06e
Showing 3 changed files with 24 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/core/datasets.py
@@ -73,7 +73,7 @@ def _to_memory(self):
n_bufs = int(len(idx) / buffer_size)
idx = (
    idx[: buffer_size * n_bufs].reshape(-1, buffer_size).tolist()
-   + idx[buffer_size * n_bufs:].reshape(1, -1).tolist()
+   + idx[buffer_size * n_bufs :].reshape(1, -1).tolist()
)
with multiprocessing.Pool(self.to_memory_workers) as pool:
    mem_list = []
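The only change in this file is black's slice-spacing rule: when a slice bound is a compound expression, the colon is padded with spaces like a low-priority binary operator, per PEP 8. A minimal sketch of what the reformatted expression computes, using hypothetical values (buffer_size, idx, and n_bufs here are illustrative stand-ins for the dataset's real attributes):

import numpy as np

# Illustrative values only; in the repo these come from the dataset object.
buffer_size = 4
idx = np.arange(10)
n_bufs = int(len(idx) / buffer_size)  # 2 full buffers of 4, remainder of 2

# black pads the colon because the slice bound is a compound expression (PEP 8).
chunks = (
    idx[: buffer_size * n_bufs].reshape(-1, buffer_size).tolist()
    + idx[buffer_size * n_bufs :].reshape(1, -1).tolist()
)
print(chunks)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]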
24 changes: 17 additions & 7 deletions src/core/train.py
@@ -41,7 +41,9 @@ def train_dim(args):

# Optimizer -------------------------------------------------------------
opt = optim.Adam(model.parameters(), lr=args["tr_lr"])
- scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, "min", verbose=True, threshold=0.003, patience=args["tr_lr_patience"])
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(
+     opt, "min", verbose=True, threshold=0.003, patience=args["tr_lr_patience"]
+ )
early_stop = EarlyStopperDim(args["tr_early_stop"])

biasLosses = []
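For reference, ReduceLROnPlateau only acts when the monitored value is passed to scheduler.step(); a minimal sketch of the pattern the training loop below follows, with a hypothetical stand-in model and loss (the real model, loss, and args come from this repo):

import torch
from torch import nn, optim

# Hypothetical stand-in model; illustrative only.
model = nn.Linear(10, 1)
opt = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, "min", threshold=0.003, patience=3)

for epoch in range(10):
    x, y = torch.randn(8, 10), torch.randn(8, 1)
    loss = nn.functional.mse_loss(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    # Passing the monitored loss lets the scheduler cut the LR by `factor`
    # after `patience` epochs without sufficient improvement.
    scheduler.step(loss.item())
    print(epoch, opt.param_groups[0]["lr"])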
@@ -94,8 +96,12 @@ def train_dim(args):
)

directions = 2 if args["td_lstm_bidirectional"] else 1
- h0 = torch.zeros(args["td_lstm_num_layers"] * directions * torch.cuda.device_count(), args["tr_bs"], args["td_lstm_h"])
- c0 = torch.zeros(args["td_lstm_num_layers"] * directions * torch.cuda.device_count(), args["tr_bs"], args["td_lstm_h"])
+ h0 = torch.zeros(
+     args["td_lstm_num_layers"] * directions * torch.cuda.device_count(), args["tr_bs"], args["td_lstm_h"]
+ )
+ c0 = torch.zeros(
+     args["td_lstm_num_layers"] * directions * torch.cuda.device_count(), args["tr_bs"], args["td_lstm_h"]
+ )
for xb_spec, yb_mos, (idx, n_wins) in dl_train:
    # Estimate batch ---------------------------------------------------
    xb_spec = xb_spec.to(dev)
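As context for the tensors being wrapped above: nn.LSTM expects its initial hidden and cell states with shape (num_layers * num_directions, batch, hidden_size). A minimal sketch with hypothetical sizes; the extra torch.cuda.device_count() factor in the diff is specific to this repo's multi-GPU handling and is left out here:

import torch
from torch import nn

# Hypothetical sizes standing in for args["td_lstm_num_layers"], args["tr_bs"], args["td_lstm_h"].
num_layers, batch_size, hidden = 2, 8, 32
bidirectional = True
directions = 2 if bidirectional else 1

lstm = nn.LSTM(
    input_size=16, hidden_size=hidden, num_layers=num_layers,
    bidirectional=bidirectional, batch_first=True,
)

h0 = torch.zeros(num_layers * directions, batch_size, hidden)
c0 = torch.zeros(num_layers * directions, batch_size, hidden)

x = torch.randn(batch_size, 50, 16)  # (batch, seq_len, features)
out, (hn, cn) = lstm(x, (h0, c0))
print(out.shape)  # torch.Size([8, 50, 64]) -> hidden * directions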
@@ -294,7 +300,7 @@ def train_dim(args):

# Scheduler update ---------------------------------------------
scheduler.step(loss)
- earl_stp = early_stop.step(r)
+ earl_stop = early_stop.step(r)

# Print --------------------------------------------------------
ep_runtime = time.time() - tic_epoch
@@ -335,12 +341,16 @@ def train_dim(args):
)

# Early stopping -----------------------------------------------
- if earl_stp:
+ if earl_stop:
    logger.info(
-       "--> Early stopping. best_r_p {:0.2f} best_rmse {:0.2f}".format(early_stop.best_r_p, early_stop.best_rmse)
+       "--> Early stopping. best_r_p {:0.2f} best_rmse {:0.2f}".format(
+           early_stop.best_r_p, early_stop.best_rmse
+       )
    )
    return

# Training done --------------------------------------------------------
- logger.info("--> Training done. best_r_p {:0.2f} best_rmse {:0.2f}".format(early_stop.best_r_p, early_stop.best_rmse))
+ logger.info("--> Training done. best_r_p {:0.2f} best_rmse {:0.2f}".format(
+     early_stop.best_r_p, early_stop.best_rmse)
+ )
return
8 changes: 6 additions & 2 deletions src/utils/train_utils.py
@@ -108,7 +108,9 @@ def get_loss(self, yb, yb_hat, idx):
b = torch.tensor(self.b, dtype=torch.float).to(yb_hat.device)
b = b[idx, :]

- yb_hat_map = (b[:, 0] + b[:, 1] * yb_hat[:, 0] + b[:, 2] * yb_hat[:, 0] ** 2 + b[:, 3] * yb_hat[:, 0] ** 3).view(-1, 1)
+ yb_hat_map = (
+     b[:, 0] + b[:, 1] * yb_hat[:, 0] + b[:, 2] * yb_hat[:, 0] ** 2 + b[:, 3] * yb_hat[:, 0] ** 3
+ ).view(-1, 1)

loss_bias = self._nan_mse(yb_hat_map, yb)
loss_normal = self._nan_mse(yb_hat, yb)
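For context on the expression being wrapped: the bias coefficients b are looked up per sample by index and applied to the raw prediction as a third-order polynomial before the MSE, alongside the unmapped loss. A minimal sketch with hypothetical coefficients; a plain NaN-masked MSE stands in for the repo's _nan_mse helper:

import torch

def nan_mse(y_hat, y):
    # Stand-in for _nan_mse: ignore entries whose target is NaN.
    mask = ~torch.isnan(y)
    return torch.mean((y_hat[mask] - y[mask]) ** 2)

# Hypothetical third-order bias coefficients, one row per group/database.
b_all = torch.tensor([[0.1, 1.0, 0.00, 0.0],
                      [0.0, 0.9, 0.05, 0.0]], dtype=torch.float)
idx = torch.tensor([0, 1, 1])                      # group index of each sample
yb_hat = torch.tensor([[3.2], [2.5], [4.1]])       # raw model predictions
yb = torch.tensor([[3.0], [2.8], [float("nan")]])  # targets, one missing

b = b_all[idx, :]
yb_hat_map = (
    b[:, 0] + b[:, 1] * yb_hat[:, 0] + b[:, 2] * yb_hat[:, 0] ** 2 + b[:, 3] * yb_hat[:, 0] ** 3
).view(-1, 1)

loss_bias = nan_mse(yb_hat_map, yb)  # MSE on bias-mapped predictions
loss_normal = nan_mse(yb_hat, yb)    # MSE on raw predictions
print(loss_bias.item(), loss_normal.item())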
@@ -282,7 +284,9 @@ def eval_results(
)
)
else:
-   logger.info("%-30s r_p_file: %0.2f, rmse_map_file: %0.2f" % (db_name + ":", r["r_p_file"], r["rmse_map_file"]))
+   logger.info(
+       "%-30s r_p_file: %0.2f, rmse_map_file: %0.2f" % (db_name + ":", r["r_p_file"], r["rmse_map_file"])
+   )

# Save individual database results in DataFrame
db_results_df = pd.DataFrame(db_results_df)
