forked from bmaltais/kohya_ss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkohya_gui.py
147 lines (131 loc) · 4.38 KB
/
kohya_gui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import gradio as gr
import os
import argparse
from dreambooth_gui import dreambooth_tab
from finetune_gui import finetune_tab
from textual_inversion_gui import ti_tab
from library.utilities import utilities_tab
from lora_gui import lora_tab
from library.class_lora_tab import LoRATools
import os
from library.custom_logging import setup_logging
from localization_ext import add_javascript
# Set up logging.
# Module-level logger obtained from the project's setup_logging() helper;
# UI() reports its progress through log.info() on this object.
log = setup_logging()
def _read_text_file(path, default=None):
    """Return the UTF-8 text of *path*, or *default* if the file does not exist.

    Args:
        path: Path of the file to read.
        default: Value returned when *path* does not exist.

    Returns:
        The file contents as a string, or *default*.
    """
    try:
        with open(path, 'r', encoding='utf8') as file:
            return file.read()
    except FileNotFoundError:
        return default


def _build_launch_kwargs(kwargs):
    """Translate UI() keyword options into keyword arguments for ``launch``.

    ``server_name`` is always forwarded (``None`` when no ``listen`` address
    was given); every other option is forwarded only when it is enabled.

    Args:
        kwargs: The keyword-argument dict received by ``UI``.

    Returns:
        A dict suitable for ``interface.launch(**result)``.
    """
    launch_kwargs = {'server_name': kwargs.get('listen')}

    username = kwargs.get('username')
    password = kwargs.get('password')
    # Authentication is enabled only when both credentials are non-empty.
    if username and password:
        launch_kwargs['auth'] = (username, password)

    # `or 0` also covers an explicit server_port=None, which would otherwise
    # raise TypeError in the comparison below.
    server_port = kwargs.get('server_port') or 0
    if server_port > 0:
        launch_kwargs['server_port'] = server_port

    inbrowser = kwargs.get('inbrowser', False)
    if inbrowser:
        launch_kwargs['inbrowser'] = inbrowser

    share = kwargs.get('share', False)
    if share:
        launch_kwargs['share'] = share

    return launch_kwargs


def UI(**kwargs):
    """Build the kohya_ss Gradio interface and launch it.

    Recognised keyword arguments (all optional):
        language: Passed to ``add_javascript``.
        headless: Forwarded to every tab builder; defaults to False.
        listen: Address passed to ``launch`` as ``server_name``.
        username, password: Enable authentication when both are non-empty.
        server_port: Port to listen on; values <= 0 (or None) leave it unset.
        inbrowser: Forwarded to ``launch`` when truthy.
        share: Forwarded to ``launch`` when truthy.

    Optional files read relative to the current working directory:
        ./style.css  -- appended to the interface CSS.
        ./.release   -- release string shown in the title, About tab and footer.
        ./README.md  -- rendered in the README tab.
    A missing file is tolerated; the corresponding text is left empty.
    """
    add_javascript(kwargs.get('language'))

    headless = kwargs.get('headless', False)
    log.info(f'headless: {headless}')

    css = ''
    style = _read_text_file('./style.css')
    if style is not None:
        log.info('Load CSS...')
        css += style + '\n'

    # Both values are used unconditionally below, so they must always be
    # bound; default to '' so the GUI still starts without these files.
    # strip() keeps a trailing newline in .release out of the window title.
    release = _read_text_file('./.release', default='').strip()
    readme = _read_text_file('./README.md', default='')

    interface = gr.Blocks(
        css=css, title=f'Kohya_ss GUI {release}', theme=gr.themes.Default()
    )

    with interface:
        with gr.Tab('Dreambooth'):
            (
                train_data_dir_input,
                reg_data_dir_input,
                output_dir_input,
                logging_dir_input,
            ) = dreambooth_tab(headless=headless)
        with gr.Tab('LoRA'):
            lora_tab(headless=headless)
        with gr.Tab('Textual Inversion'):
            ti_tab(headless=headless)
        with gr.Tab('Finetuning'):
            finetune_tab(headless=headless)
        with gr.Tab('Utilities'):
            # The path inputs returned by the Dreambooth tab are handed to
            # the utilities tab.
            utilities_tab(
                train_data_dir_input=train_data_dir_input,
                reg_data_dir_input=reg_data_dir_input,
                output_dir_input=output_dir_input,
                logging_dir_input=logging_dir_input,
                enable_copy_info_button=True,
                headless=headless,
            )
            # LoRA tools live under Utilities so they do not sit beside the
            # top-level 'LoRA' training tab with an identical label.
            with gr.Tab('LoRA'):
                _ = LoRATools(headless=headless)
        with gr.Tab('About'):
            gr.Markdown(f'kohya_ss GUI release {release}')
            with gr.Tab('README'):
                gr.Markdown(readme)

        # Footer carrying the release string (styled via the ver-class CSS class).
        html_str = f"""
        <html>
            <body>
                <div class="ver-class">{release}</div>
            </body>
        </html>
        """
        gr.HTML(html_str)

    # Show the interface
    interface.launch(**_build_launch_kwargs(kwargs))
if __name__ == '__main__':
    # Command-line entry point: parse the launcher options and hand them to UI().
    parser = argparse.ArgumentParser()

    # Network and access options.
    parser.add_argument('--listen', type=str, default='127.0.0.1',
                        help='IP to listen on for connections to Gradio')
    parser.add_argument('--username', type=str, default='',
                        help='Username for authentication')
    parser.add_argument('--password', type=str, default='',
                        help='Password for authentication')
    parser.add_argument('--server_port', type=int, default=0,
                        help='Port to run the server listener on')

    # On/off switches, registered in the same order as before.
    for switch, description in (
        ('--inbrowser', 'Open in browser'),
        ('--share', 'Share the gradio UI'),
        ('--headless', 'Is the server headless'),
    ):
        parser.add_argument(switch, action='store_true', help=description)

    parser.add_argument('--language', type=str, default=None,
                        help='Set custom language')

    # Every option's dest matches one of UI()'s keyword names, so the parsed
    # namespace can be forwarded wholesale.
    UI(**vars(parser.parse_args()))