From bf35ebcdf3d017813e5af49a3bbfee11b752a672 Mon Sep 17 00:00:00 2001 From: pengzhendong <275331498@qq.com> Date: Tue, 14 Nov 2023 16:46:23 +0800 Subject: [PATCH] fix bug of ddp --- wetts/vits/utils/task.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/wetts/vits/utils/task.py b/wetts/vits/utils/task.py index afae44c..2459d1f 100644 --- a/wetts/vits/utils/task.py +++ b/wetts/vits/utils/task.py @@ -3,6 +3,7 @@ import json import logging import os +from pathlib import Path import torch @@ -189,9 +190,7 @@ def get_hparams(init=True): args = parser.parse_args() model_dir = args.model - - if not os.path.exists(model_dir): - os.makedirs(model_dir) + Path(model_dir).mkdir(parents=True, exist_ok=True) config_path = args.config config_save_path = os.path.join(model_dir, "config.json") @@ -247,8 +246,7 @@ def get_logger(model_dir, filename="train.log"): formatter = logging.Formatter( "%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") - if not os.path.exists(model_dir): - os.makedirs(model_dir) + Path(model_dir).mkdir(parents=True, exist_ok=True) h = logging.FileHandler(os.path.join(model_dir, filename)) h.setLevel(logging.INFO) h.setFormatter(formatter)