# merge_lora.py
# forked from StartHua/Comfyui_CXH_FluxLoraMerge
import os
import time
import sys
import torch
from tqdm import tqdm
from safetensors.torch import load_file, save_file
import comfy.utils
import comfy.sd
import folder_paths

def merge_loras_mix(main_lora_model, merge_lora_model, weight_percentages, merge_type):
    """Merges two LoRA models using multiple weight percentages."""
    merged_models = []
    for weight in weight_percentages:
        merged_model = merge_loras_weighted(main_lora_model, merge_lora_model, weight / 100, merge_type)
        merged_models.append((weight / 100, merged_model))
    return merged_models
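
# Example (hypothetical state dicts): merge_loras_mix(lora_a, lora_b, [25, 50, 75], 'adaptive')
# returns [(0.25, merged_1), (0.5, merged_2), (0.75, merged_3)]; percentages are
# given on a 0-100 scale and converted to fractions internally.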

def merge_loras_weighted(main_lora_model, merge_lora_model, main_weight, merge_type='adaptive'):
    """Merges two LoRA models using adaptive or manual merge with a specified main weight."""
    merged_model = {}
    all_keys = set(main_lora_model.keys()).union(set(merge_lora_model.keys()))
    with tqdm(total=len(all_keys), desc="Merging LoRA models", unit="layer") as pbar:
        for key in all_keys:
            if key in main_lora_model and key in merge_lora_model:
                if merge_type == 'adaptive':
                    merged_model[key] = adaptive_merge(main_lora_model[key], merge_lora_model[key], main_weight)
                else:
                    merged_model[key] = manual_merge(main_lora_model[key], merge_lora_model[key], main_weight)
            elif key in main_lora_model:
                merged_model[key] = main_lora_model[key]
            else:
                merged_model[key] = merge_lora_model[key]
            pbar.update(1)
    return merged_model
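
# Example usage with hypothetical file names (the state dicts come from safetensors' load_file):
#   main_sd = load_file("style_a.safetensors")
#   extra_sd = load_file("style_b.safetensors")
#   merged_sd = merge_loras_weighted(main_sd, extra_sd, main_weight=0.7, merge_type='manual')
# Keys present in only one of the two LoRAs are copied through unchanged.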

def additive_merge(main_lora_model, merge_lora_model, add_weight):
    """Always use 100% of the first model and add the second model at a specified percentage."""
    merged_model = {}
    all_keys = set(main_lora_model.keys()).union(set(merge_lora_model.keys()))
    with tqdm(total=len(all_keys), desc="Additive Merging LoRA models", unit="layer") as pbar:
        for key in all_keys:
            if key in main_lora_model and key in merge_lora_model:
                tensor1 = main_lora_model[key]
                tensor2 = merge_lora_model[key]
                if tensor1.size() != tensor2.size():
                    tensor1, tensor2 = pad_tensors(tensor1, tensor2)
                merged_model[key] = tensor1 + (add_weight * tensor2)
            elif key in main_lora_model:
                merged_model[key] = main_lora_model[key]
            else:
                merged_model[key] = add_weight * merge_lora_model[key]
            pbar.update(1)
    return merged_model
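
# Note: unlike the weighted merges above, the main LoRA is never scaled down here;
# with add_weight=0.5 a shared key becomes main + 0.5 * merge, and keys unique to
# the second LoRA are added at 0.5 strength.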

def adaptive_merge(tensor1, tensor2, main_weight):
    """Merges two tensors using adaptive weights based on their L2 norms."""
    if tensor1.size() != tensor2.size():
        tensor1, tensor2 = pad_tensors(tensor1, tensor2)
    norm1 = torch.norm(tensor1)
    norm2 = torch.norm(tensor2)
    adaptive_weight1 = norm1 / (norm1 + norm2)
    adaptive_weight2 = norm2 / (norm1 + norm2)
    final_weight1 = adaptive_weight1 * main_weight + (1 - adaptive_weight2) * (1 - main_weight)
    final_weight2 = 1 - final_weight1
    return final_weight1 * tensor1 + final_weight2 * tensor2
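
# Worked example: with norm1 = 3.0 and norm2 = 1.0, adaptive_weight1 = 0.75 and
# adaptive_weight2 = 0.25. Since the two adaptive weights sum to 1,
# (1 - adaptive_weight2) == adaptive_weight1, so
#   final_weight1 = adaptive_weight1 * main_weight + adaptive_weight1 * (1 - main_weight)
#                 = adaptive_weight1
# meaning the blend is driven by the tensor norms and main_weight cancels out here.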

def manual_merge(tensor1, tensor2, main_weight):
    """Merges two tensors using fixed weights based on user input."""
    if tensor1.size() != tensor2.size():
        tensor1, tensor2 = pad_tensors(tensor1, tensor2)
    return main_weight * tensor1 + (1 - main_weight) * tensor2

def save_merged_lora(merged_model, lora_folder, main_lora_file, merge_lora_file, weight, merge_type):
    """Saves the merged LoRA model with an appropriate name."""
    main_name = os.path.splitext(main_lora_file)[0]
    merge_name = os.path.splitext(merge_lora_file)[0]
    if merge_type == 'adaptive':
        strategy_code = f"A{int(weight * 100)}"
    elif merge_type == 'additive':
        strategy_code = f"ADDI{int(weight * 100)}"
    else:  # manual
        strategy_code = f"M{int(weight * 100)}"
    merged_lora_name = f"mrg_{main_name}_{strategy_code}_{merge_name}.safetensors"
    merged_lora_path = os.path.join(lora_folder, merged_lora_name)
    save_file(merged_model, merged_lora_path)
    print(f"Merged LoRA saved as: {merged_lora_name}")

def pad_tensors(tensor1, tensor2):
    """Pads tensors to the same size if they differ."""
    max_size = [max(s1, s2) for s1, s2 in zip(tensor1.size(), tensor2.size())]
    padded1 = torch.zeros(max_size, device=tensor1.device, dtype=tensor1.dtype)
    padded2 = torch.zeros(max_size, device=tensor2.device, dtype=tensor2.dtype)
    padded1[tuple(slice(0, s) for s in tensor1.size())] = tensor1
    padded2[tuple(slice(0, s) for s in tensor2.size())] = tensor2
    return padded1, padded2
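
# Example: padding a (16, 32) tensor against a (16, 64) tensor yields two
# (16, 64) tensors, with the smaller one zero-filled in the extra columns.
# Both inputs must have the same number of dimensions for this to work.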

def pad_all_tensors(tensors):
    """Pads all tensors in the list to match the maximum size across all tensors."""
    if not tensors:
        return []
    # Determine the max size across all tensors
    max_size = [max(t.size(dim) for t in tensors) for dim in range(len(tensors[0].size()))]
    # Pad each tensor to the max size
    padded_tensors = []
    for tensor in tensors:
        padded_tensor = torch.zeros(max_size, device=tensor.device, dtype=tensor.dtype)
        slices = tuple(slice(0, s) for s in tensor.size())
        padded_tensor[slices] = tensor
        padded_tensors.append(padded_tensor)
    return padded_tensors

def god_mode(lora_folder, merge_strategy='adaptive'):
    """
    Merges multiple LoRA models simultaneously using the specified strategy, constrained by available memory.

    Args:
        lora_folder: The folder containing LoRA models to merge.
        merge_strategy: The merging strategy to use ('adaptive' or 'additive').

    Returns:
        Path to the final merged model saved to disk, or None if nothing could be merged or saved.
    """
    # Load all LoRA models from the folder with a progress bar
    lora_files = [f for f in os.listdir(lora_folder) if f.endswith('.safetensors')]
    if not lora_files:
        print("No LoRA models found to merge.")
        return None
    print(f"Loading {len(lora_files)} LoRA models...")
    lora_models = []
    largest_file_size = 0
    largest_file_name = ''
    with tqdm(total=len(lora_files), desc="Loading LoRA Models", unit="file") as pbar:
        for file in lora_files:
            file_path = os.path.join(lora_folder, file)
            file_size = os.path.getsize(file_path)
            if file_size > largest_file_size:
                largest_file_size = file_size
                largest_file_name = file
            try:
                lora_model = load_file(file_path)
                lora_models.append(lora_model)
            except Exception as e:
                print(f"Error loading model {file}: {e}")
            pbar.update(1)
    if not lora_models:
        print("No LoRA models were successfully loaded.")
        return None
    print(f"Largest input file: {largest_file_name} ({largest_file_size} bytes)")
    print(f"Starting merge with {len(lora_models)} LoRA models using the {merge_strategy} strategy.")
    # Initialize the merged model with zero tensors for every key found in any model
    all_keys = set().union(*(model.keys() for model in lora_models))
    merged_model = {key: torch.zeros_like(next(model[key] for model in lora_models if key in model))
                    for key in all_keys}
    total_input_tensors = 0
    total_merged_tensors = 0
    for key in tqdm(all_keys, desc="Merging tensors", unit="tensor"):
        tensors = [model[key] for model in lora_models if key in model]
        total_input_tensors += len(tensors)
        if not tensors:
            print(f"Warning: No tensors found for key: {key}")
            continue
        try:
            padded_tensors = pad_all_tensors(tensors)
            if merge_strategy == 'adaptive':
                merged_model[key] = adaptive_merge_multiple(padded_tensors)
            elif merge_strategy == 'additive':
                merged_model[key] = additive_merge_multiple(padded_tensors)
            else:
                raise ValueError(f"Unknown merge strategy: {merge_strategy}")
            total_merged_tensors += 1
        except Exception as e:
            print(f"Error merging tensors for key {key}: {e}")
            # Instead of skipping, fall back to this key's tensor from the first loaded model, if present
            fallback_tensor = lora_models[0].get(key)
            if fallback_tensor is not None:
                merged_model[key] = fallback_tensor
                print(f"Using fallback tensor from the first loaded model for key {key}")
            else:
                print(f"Warning: Skipping key {key} due to errors")
    print(f"Total input tensors: {total_input_tensors}")
    print(f"Total merged tensors: {total_merged_tensors}")
    # Determine the strategy code for the filename (matches the codes used in save_merged_lora)
    strategy_code = 'A' if merge_strategy == 'adaptive' else 'ADDI'
    # Create the filename using the mrg_<name>_<strategy><weight>_<suffix> naming convention
    merged_filename = f"mrg_final_merged_{strategy_code}100_god_mode.safetensors"
    merged_file_path = os.path.join(lora_folder, merged_filename)
    # Save the final merged model
    try:
        save_file(merged_model, merged_file_path)
        merged_file_size = os.path.getsize(merged_file_path)
        print(f"Merged file saved as: {merged_filename}")
        print(f"Merged file size: {merged_file_size} bytes")
        if merged_file_size < largest_file_size:
            print("Warning: Merged file is smaller than the largest input file. Some data may have been lost in the process.")
        else:
            print("Merged file is larger than or equal to the largest input file, as expected.")
    except Exception as e:
        print(f"Error saving merged model: {e}")
        return None
    return merged_file_path
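
# Example usage with a hypothetical folder (god_mode is a standalone helper and is
# not called by the CXH_Lora_Merge node below):
#   god_mode("/path/to/ComfyUI/models/loras/to_merge", merge_strategy='adaptive')
# This loads every .safetensors file in the folder, merges them key by key, and
# writes mrg_final_merged_A100_god_mode.safetensors back into the same folder.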

def adaptive_merge_multiple(tensors):
    """Merges multiple tensors using adaptive weights based on their L2 norms."""
    try:
        norms = [torch.norm(tensor) for tensor in tensors]
        total_norm = sum(norms)
        weights = [norm / total_norm for norm in norms]
        # Calculate the final merged tensor
        merged_tensor = sum(w * t for w, t in zip(weights, tensors))
        return merged_tensor
    except Exception as e:
        print(f"Error in adaptive_merge_multiple: {e}")
        return torch.zeros_like(tensors[0])

def additive_merge_multiple(tensors):
    """Merges multiple tensors using additive merging with equal weighting."""
    try:
        weight = 1.0 / len(tensors)
        merged_tensor = sum(weight * tensor for tensor in tensors)
        return merged_tensor
    except Exception as e:
        print(f"Error in additive_merge_multiple: {e}")
        return torch.zeros_like(tensors[0])
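
# Note: despite the name, this is an equal-weight average; with three input tensors
# each gets weight 1/3, whereas the two-model additive_merge above keeps the main
# LoRA at full strength and only scales the second one.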

class CXH_Lora_Merge:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "savename": ("STRING", {"default": ""}),
                "main_lora": (folder_paths.get_filename_list("loras"), {"tooltip": "The main LoRA to merge."}),
                "merge_lora": (folder_paths.get_filename_list("loras"), {"tooltip": "The LoRA to merge into the main LoRA."}),
                "merge_type": (["adaptive", "manual", "additive"],),
                "weight": ("INT", {"default": 50, "min": 0, "max": 100, "step": 1}),
            }
        }

    RETURN_TYPES = ()
    RETURN_NAMES = ()
    FUNCTION = "gen"
    OUTPUT_NODE = True
    CATEGORY = "CXH/model"

    def gen(self, savename, main_lora, merge_lora, merge_type, weight):
        lora_path_1 = os.path.join(folder_paths.models_dir, "loras", main_lora)
        lora_path_2 = os.path.join(folder_paths.models_dir, "loras", merge_lora)
        save_lora = os.path.join(folder_paths.models_dir, "loras", savename + ".safetensors")
        print(lora_path_1)
        main_lora_model = load_file(lora_path_1)
        merge_lora_model = load_file(lora_path_2)
        if merge_type == 'adaptive':
            merged_models = merge_loras_mix(main_lora_model, merge_lora_model, [weight], merge_type)
        elif merge_type == 'additive':
            merged_model = additive_merge(main_lora_model, merge_lora_model, weight / 100)
            merged_models = [(weight / 100, merged_model)]
        else:  # manual (fixed-weight) merge
            merged_model = merge_loras_weighted(main_lora_model, merge_lora_model, weight / 100, merge_type)
            merged_models = [(weight / 100, merged_model)]
        for _weight, merged_model in merged_models:
            print(f"Merged LoRA saved as: {save_lora}")
            save_file(merged_model, save_lora)
        return ()
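
# Registration sketch (illustrative; in this repository the mappings are most likely
# defined in the package's __init__.py rather than here). A ComfyUI custom node is
# exposed to the UI through module-level dictionaries along these lines:
#
#   NODE_CLASS_MAPPINGS = {"CXH_Lora_Merge": CXH_Lora_Merge}
#   NODE_DISPLAY_NAME_MAPPINGS = {"CXH_Lora_Merge": "CXH Lora Merge"}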