From 5aafe1ece2de08974b02acd544219cf4bb1a4f31 Mon Sep 17 00:00:00 2001 From: AI-Casanova <54461896+AI-Casanova@users.noreply.github.com> Date: Tue, 9 Jan 2024 22:52:48 -0600 Subject: [PATCH] Refactor cache Disable double-load --- modules/sd_models.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 0a0df6c82..a9f2f5dc7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -418,8 +418,15 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): return None if checkpoint_info in checkpoints_loaded: shared.log.info("Model weights loading: from cache") + checkpoints_loaded.move_to_end(checkpoint_info, last=True) # FIFO -> LRU cache return checkpoints_loaded[checkpoint_info] res = read_state_dict(checkpoint_info.filename) + if shared.opts.sd_checkpoint_cache > 0 and shared.backend == shared.Backend.ORIGINAL: + # cache newly loaded model + checkpoints_loaded[checkpoint_info] = res + # clean up cache if limit is reached + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + checkpoints_loaded.popitem(last=False) timer.record("load") return res @@ -440,9 +447,6 @@ def load_model_weights(model: torch.nn.Module, checkpoint_info: CheckpointInfo, return False del state_dict timer.record("apply") - if shared.opts.sd_checkpoint_cache > 0 and shared.backend == shared.Backend.ORIGINAL: - # cache newly loaded model - checkpoints_loaded[checkpoint_info] = model.state_dict().copy() if shared.opts.opt_channelslast: model.to(memory_format=torch.channels_last) timer.record("channels") @@ -464,9 +468,6 @@ def load_model_weights(model: torch.nn.Module, checkpoint_info: CheckpointInfo, else: model.model.diffusion_model.to(devices.dtype_unet) model.first_stage_model.to(devices.dtype_vae) - # clean up cache if limit is reached - while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: - checkpoints_loaded.popitem(last=False) model.sd_model_hash = checkpoint_info.calculate_shorthash() model.sd_model_checkpoint = checkpoint_info.filename model.sd_checkpoint_info = checkpoint_info @@ -1222,7 +1223,7 @@ def reload_model_weights(sd_model=None, info=None, reuse_dict=False, op='model') unload_model_weights(op=op) sd_model = None timer = Timer() - state_dict = get_checkpoint_state_dict(checkpoint_info, timer) + state_dict = get_checkpoint_state_dict(checkpoint_info, timer) if shared.backend == shared.Backend.ORIGINAL else None # TODO Revist after Diffusers enables state_dict loading checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) timer.record("config") if sd_model is None or checkpoint_config != getattr(sd_model, 'used_config', None):