Skip to content

Commit

Permalink
Update main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
DJC-GO-SOLO authored Oct 29, 2022
1 parent aa49822 commit 495a131
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions code/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,10 @@ def parse_args():

for epoch in range(1, args.epochs + 1):
print(f'******************** Training Epoch: {epoch} ********************')
# train_fn(train_dataloader, model, criterion1, optimizer,args,device,epoch,scheduler)
train_fn(train_dataloader, model, criterion1, optimizer,args,device,epoch,scheduler)
r4_1,r4_2,mrr = validate_fn(val_dataloader, model, criterion2,args,device,epoch)
logger.info(f"Epoch: {epoch:02}. Valid. R4_1:{r4_1} R4_2:{r4_2} MRR:{mrr}")
if r4_1 >= best:
best = r4_1
logger.info(f'{r4_1} model saved')
torch.save(model.state_dict(), os.path.join(args.output_dir, f"best_model.bin"))
torch.save(model.state_dict(), os.path.join(args.output_dir, f"best_model.bin"))

0 comments on commit 495a131

Please sign in to comment.