-
Notifications
You must be signed in to change notification settings - Fork 76
/
train_localtoon.py
36 lines (27 loc) · 1.05 KB
/
train_localtoon.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import argparse
import os

import cv2
from modelscope.trainers.cv import CartoonTranslationTrainer
def build_model_id(style):
    """Return the ModelScope model id of the pretrained cartoon model for *style*.

    ``"anime"`` selects the base model id; any other value ``s`` is inserted
    as a ``-s`` suffix, e.g. ``"3d"`` ->
    ``damo/cv_unet_person-image-cartoon-3d_compound-models``.
    """
    suffix = "" if style == "anime" else "-" + style
    return "damo/cv_unet_person-image-cartoon" + suffix + "_compound-models"


def main(args):
    """Train a cartoon-translation model on a local photo/cartoon dataset.

    Args:
        args: parsed command-line namespace. Reads ``data_dir`` (must contain
            ``face_photo`` and ``face_cartoon`` sub-directories), ``work_dir``
            (forwarded to the trainer as its work directory), ``style``
            (chooses the pretrained model to start from, see
            ``build_model_id``) and, optionally, ``max_steps`` (defaults to
            300000 when absent or None).
    """
    data_photo = os.path.join(args.data_dir, "face_photo")
    data_cartoon = os.path.join(args.data_dir, "face_cartoon")

    # getattr keeps callers whose namespace has no max_steps attribute working.
    max_steps = getattr(args, "max_steps", None)
    if max_steps is None:
        max_steps = 300000

    trainer = CartoonTranslationTrainer(
        model=build_model_id(args.style),
        work_dir=args.work_dir,
        photo=data_photo,
        cartoon=data_cartoon,
        max_steps=max_steps,
    )
    trainer.train()
def parse_args(argv=None):
    """Parse command-line options for the cartoon-translation training script.

    Args:
        argv: list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``data_dir``, ``work_dir``, ``style`` and
        ``max_steps``.
    """
    parser = argparse.ArgumentParser(
        description="Train a portrait cartoon-translation model with ModelScope.")
    parser.add_argument("--data_dir", type=str, default='',
                        help="Path to training images.")
    parser.add_argument("--work_dir", type=str, default='',
                        help="Path to save results.")
    parser.add_argument("--style", type=str, default='anime',
                        help="Pretrained style to resume training from: 'anime' "
                             "uses the base model, any other value X uses the "
                             "'-X' model variant.")
    parser.add_argument("--max_steps", type=int, default=300000,
                        help="Number of training steps.")
    return parser.parse_args(argv)


if __name__ == '__main__':
    main(parse_args())