Skip to content

Commit

Permalink
controlnet load lock
Browse files Browse the repository at this point in the history
Signed-off-by: Vladimir Mandic <[email protected]>
  • Loading branch information
vladmandic committed Jan 10, 2025
1 parent b48e464 commit 85a6aca
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 66 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
- flux support on-the-fly quantization for bnb of unet only
- control restore pipeline before running hires
- restore args after batch run
- control add load lock

## Update for 2024-12-31

Expand Down
135 changes: 69 additions & 66 deletions modules/control/units/controlnet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import time
import threading
from typing import Union
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, FluxPipeline, StableDiffusion3Pipeline, ControlNetModel
from modules.control.units import detect
Expand Down Expand Up @@ -116,6 +117,7 @@
all_models.update(predefined_f1)
all_models.update(predefined_sd3)
cache_dir = 'models/control/controlnet'
load_lock = threading.Lock()


def find_models():
Expand Down Expand Up @@ -236,73 +238,74 @@ def load_safetensors(self, model_id, model_path):
self.model = cls.from_single_file(model_path, config=config, **self.load_config)

def load(self, model_id: str = None, force: bool = True) -> str:
try:
t0 = time.time()
model_id = model_id or self.model_id
if model_id is None or model_id == 'None':
self.reset()
return
if model_id not in all_models:
log.error(f'Control {what} unknown model: id="{model_id}" available={list(all_models)}')
return
model_path = all_models[model_id]
if model_path == '':
return
if model_path is None:
log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id')
return
if 'lora' in model_id.lower():
self.model = model_path
return
if model_id == self.model_id and not force:
# log.debug(f'Control {what} model: id="{model_id}" path="{model_path}" already loaded')
return
log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}"')
cls, _config = self.get_class(model_id)
if model_path.endswith('.safetensors'):
self.load_safetensors(model_id, model_path)
else:
kwargs = {}
if '/bin' in model_path:
model_path = model_path.replace('/bin', '')
self.load_config['use_safetensors'] = False
if cls is None:
log.error(f'Control {what} model load failed: id="{model_id}" unknown base model')
with load_lock:
try:
t0 = time.time()
model_id = model_id or self.model_id
if model_id is None or model_id == 'None':
self.reset()
return
if variants.get(model_id, None) is not None:
kwargs['variant'] = variants[model_id]
self.model = cls.from_pretrained(model_path, **self.load_config, **kwargs)
if self.model is None:
return
if self.dtype is not None:
self.model.to(self.dtype)
if "ControlNet" in opts.nncf_compress_weights:
try:
log.debug(f'Control {what} model NNCF Compress: id="{model_id}"')
from installer import install
install('nncf==2.7.0', quiet=True)
from modules.sd_models_compile import nncf_compress_model
self.model = nncf_compress_model(self.model)
except Exception as e:
log.error(f'Control {what} model NNCF Compression failed: id="{model_id}" error={e}')
elif "ControlNet" in opts.optimum_quanto_weights:
try:
log.debug(f'Control {what} model Optimum Quanto: id="{model_id}"')
model_quant.load_quanto('Load model: type=ControlNet')
from modules.sd_models_compile import optimum_quanto_model
self.model = optimum_quanto_model(self.model)
except Exception as e:
log.error(f'Control {what} model Optimum Quanto failed: id="{model_id}" error={e}')
if self.device is not None:
self.model.to(self.device)
t1 = time.time()
self.model_id = model_id
log.debug(f'Control {what} model loaded: id="{model_id}" path="{model_path}" cls={cls.__name__} time={t1-t0:.2f}')
return f'{what} loaded model: {model_id}'
except Exception as e:
log.error(f'Control {what} model load failed: id="{model_id}" error={e}')
errors.display(e, f'Control {what} load')
return f'{what} failed to load model: {model_id}'
if model_id not in all_models:
log.error(f'Control {what} unknown model: id="{model_id}" available={list(all_models)}')
return
model_path = all_models[model_id]
if model_path == '':
return
if model_path is None:
log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id')
return
if 'lora' in model_id.lower():
self.model = model_path
return
if model_id == self.model_id and not force:
# log.debug(f'Control {what} model: id="{model_id}" path="{model_path}" already loaded')
return
log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}"')
cls, _config = self.get_class(model_id)
if model_path.endswith('.safetensors'):
self.load_safetensors(model_id, model_path)
else:
kwargs = {}
if '/bin' in model_path:
model_path = model_path.replace('/bin', '')
self.load_config['use_safetensors'] = False
if cls is None:
log.error(f'Control {what} model load failed: id="{model_id}" unknown base model')
return
if variants.get(model_id, None) is not None:
kwargs['variant'] = variants[model_id]
self.model = cls.from_pretrained(model_path, **self.load_config, **kwargs)
if self.model is None:
return
if self.dtype is not None:
self.model.to(self.dtype)
if "ControlNet" in opts.nncf_compress_weights:
try:
log.debug(f'Control {what} model NNCF Compress: id="{model_id}"')
from installer import install
install('nncf==2.7.0', quiet=True)
from modules.sd_models_compile import nncf_compress_model
self.model = nncf_compress_model(self.model)
except Exception as e:
log.error(f'Control {what} model NNCF Compression failed: id="{model_id}" error={e}')
elif "ControlNet" in opts.optimum_quanto_weights:
try:
log.debug(f'Control {what} model Optimum Quanto: id="{model_id}"')
model_quant.load_quanto('Load model: type=ControlNet')
from modules.sd_models_compile import optimum_quanto_model
self.model = optimum_quanto_model(self.model)
except Exception as e:
log.error(f'Control {what} model Optimum Quanto failed: id="{model_id}" error={e}')
if self.device is not None:
self.model.to(self.device)
t1 = time.time()
self.model_id = model_id
log.info(f'Control {what} model loaded: id="{model_id}" path="{model_path}" cls={cls.__name__} time={t1-t0:.2f}')
return f'{what} loaded model: {model_id}'
except Exception as e:
log.error(f'Control {what} model load failed: id="{model_id}" error={e}')
errors.display(e, f'Control {what} load')
return f'{what} failed to load model: {model_id}'


class ControlNetPipeline():
Expand Down

0 comments on commit 85a6aca

Please sign in to comment.