Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
vladmandic committed May 20, 2023
1 parent e8ddc6e commit 0891b30
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ venv
/*.sh
/*.txt
/*.mp3
/*.lnk
!webui.bat
!webui.sh

Expand Down
1 change: 1 addition & 0 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Stuff to be added...

Stuff to be investigated...

- Gradio `app_kwargs`: <https://github.com/gradio-app/gradio/issues/4054>

## Merge PRs

Expand Down
6 changes: 5 additions & 1 deletion modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,15 +223,19 @@ def read_state_dict(checkpoint_file, map_location=None): # pylint: disable=unuse
_, extension = os.path.splitext(checkpoint_file)
if shared.opts.stream_load:
if extension.lower() == ".safetensors":
shared.log.debug('Model weights loading: type=safetensors mode=buffered')
buffer = f.read()
pl_sd = safetensors.torch.load(buffer)
else:
shared.log.debug('Model weights loading: type=checkpoint mode=buffered')
buffer = io.BytesIO(f.read())
pl_sd = torch.load(buffer, map_location='cpu')
else:
if extension.lower() == ".safetensors":
shared.log.debug('Model weights loading: type=safetensors mode=mmap')
pl_sd = safetensors.torch.load_file(checkpoint_file, device='cpu')
else:
shared.log.debug('Model weights loading: type=checkpoint mode=direct')
pl_sd = torch.load(f, map_location='cpu')
sd = get_state_dict_from_checkpoint(pl_sd)
del pl_sd
Expand All @@ -244,7 +248,7 @@ def read_state_dict(checkpoint_file, map_location=None): # pylint: disable=unuse
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
if checkpoint_info in checkpoints_loaded:
# use checkpoint cache
shared.log.info("Loading weights from cache")
shared.log.info("Model weights loading: from cache")
return checkpoints_loaded[checkpoint_info]
res = read_state_dict(checkpoint_info.filename)
timer.record("load")
Expand Down

0 comments on commit 0891b30

Please sign in to comment.