Skip to content

Commit

Permalink
Merge pull request #2701 from AI-Casanova/refactor-cache
Browse files Browse the repository at this point in the history
Refactor cache
  • Loading branch information
vladmandic authored Jan 10, 2024
2 parents 3352be2 + 5aafe1e commit 8c65942
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,8 +418,15 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
return None
if checkpoint_info in checkpoints_loaded:
shared.log.info("Model weights loading: from cache")
checkpoints_loaded.move_to_end(checkpoint_info, last=True) # FIFO -> LRU cache
return checkpoints_loaded[checkpoint_info]
res = read_state_dict(checkpoint_info.filename)
if shared.opts.sd_checkpoint_cache > 0 and shared.backend == shared.Backend.ORIGINAL:
# cache newly loaded model
checkpoints_loaded[checkpoint_info] = res
# clean up cache if limit is reached
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False)
timer.record("load")
return res

Expand All @@ -440,9 +447,6 @@ def load_model_weights(model: torch.nn.Module, checkpoint_info: CheckpointInfo,
return False
del state_dict
timer.record("apply")
if shared.opts.sd_checkpoint_cache > 0 and shared.backend == shared.Backend.ORIGINAL:
# cache newly loaded model
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
if shared.opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
timer.record("channels")
Expand All @@ -464,9 +468,6 @@ def load_model_weights(model: torch.nn.Module, checkpoint_info: CheckpointInfo,
else:
model.model.diffusion_model.to(devices.dtype_unet)
model.first_stage_model.to(devices.dtype_vae)
# clean up cache if limit is reached
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False)
model.sd_model_hash = checkpoint_info.calculate_shorthash()
model.sd_model_checkpoint = checkpoint_info.filename
model.sd_checkpoint_info = checkpoint_info
Expand Down Expand Up @@ -1222,7 +1223,7 @@ def reload_model_weights(sd_model=None, info=None, reuse_dict=False, op='model')
unload_model_weights(op=op)
sd_model = None
timer = Timer()
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
state_dict = get_checkpoint_state_dict(checkpoint_info, timer) if shared.backend == shared.Backend.ORIGINAL else None # TODO Revisit after Diffusers enables state_dict loading
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
timer.record("config")
if sd_model is None or checkpoint_config != getattr(sd_model, 'used_config', None):
Expand Down

0 comments on commit 8c65942

Please sign in to comment.