Skip to content

Commit

Permalink
add metadata cache
Browse files Browse the repository at this point in the history
  • Loading branch information
vladmandic committed Jun 13, 2023
1 parent 1afb7c6 commit eb47acf
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 7 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# defaults
__pycache__
/cache.json
/metadata.json
/config.json
/ui-config.json
/params.txt
/styles.csv
/ui-config.json
/user.css
/webui-user.bat
/webui-user.sh
Expand Down
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Change Log for SD.Next

## Update for 06/13/2023

- new cache for models/lora/lyco metadata: `metadata.json`
drastically reduces disk access on app startup

## Update for 06/12/2023

- updated ui labels and hints to improve clarity and provide some extra info
Expand All @@ -11,6 +16,7 @@
as some extensions are loading packages directly from their preload sequence
which was preventing some optimizations to take effect
- updated **settings** tab functionality, thanks @gegell
with real-time monitor for all new and/or updated settings
- **launcher** will now warn if application owned files are modified
you are free to add any user files, but do not modify app files unless you're sure in what you're doing
- add more profiling for scripts/extensions so you can see what takes time
Expand Down
26 changes: 21 additions & 5 deletions modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
checkpoint_aliases = {}
checkpoints_loaded = collections.OrderedDict()
skip_next_load = False
sd_metadata_file = os.path.join(paths.data_path, "metadata.json")
sd_metadata = None


class CheckpointInfo:
Expand Down Expand Up @@ -237,22 +239,36 @@ def get_state_dict_from_checkpoint(pl_sd):

def read_metadata_from_safetensors(filename):
import json
global sd_metadata # pylint: disable=global-statement
if sd_metadata is None:
if not os.path.isfile(sd_metadata_file):
sd_metadata = {}
else:
with open(sd_metadata_file, "r", encoding="utf8") as file:
sd_metadata = json.load(file)
res = sd_metadata.get(filename, None)
if res is not None:
return res

res = {}
with open(filename, mode="rb") as file:
metadata_len = file.read(8)
metadata_len = int.from_bytes(metadata_len, "little")
json_start = file.read(2)
assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"
json_data = json_start + file.read(metadata_len-2)
json_obj = json.loads(json_data)
res = {}
for k, v in json_obj.get("__metadata__", {}).items():
res[k] = v
if isinstance(v, str) and v[0:1] == '{':
try:
res[k] = json.loads(v)
except Exception:
pass
return res
sd_metadata[filename] = res
with open(sd_metadata_file, "w", encoding="utf8") as file:
json.dump(sd_metadata, file, indent=4)
return res


def read_state_dict(checkpoint_file, map_location=None): # pylint: disable=unused-argument
Expand Down Expand Up @@ -558,7 +574,7 @@ def reload_model_weights(sd_model=None, info=None, reuse_dict=False):
load_dict = shared.opts.sd_model_dict != model_data.sd_dict
global skip_next_load # pylint: disable=global-statement
if skip_next_load:
shared.log.debug('Reload model weights skip')
shared.log.debug('Load model weights skip')
skip_next_load = False
return
from modules import lowvram, sd_hijack
Expand All @@ -568,7 +584,7 @@ def reload_model_weights(sd_model=None, info=None, reuse_dict=False):
shared.log.debug(f'Model dict: existing={sd_model is not None} target={checkpoint_info.filename} info={info}')
else:
model_data.sd_dict = 'None'
shared.log.debug(f'Reload model weights: existing={sd_model is not None} target={checkpoint_info.filename} info={info}')
shared.log.debug(f'Load model weights: existing={sd_model is not None} target={checkpoint_info.filename} info={info}')
if not sd_model:
sd_model = model_data.sd_model
if sd_model is None: # previous model load failed
Expand Down Expand Up @@ -606,7 +622,7 @@ def reload_model_weights(sd_model=None, info=None, reuse_dict=False):
try:
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
except Exception:
shared.log.error("Failed to load checkpoint, restoring previous")
shared.log.error("Load model failed: restoring previous")
load_model_weights(sd_model, current_checkpoint_info, None, timer)
finally:
sd_hijack.model_hijack.hijack(sd_model)
Expand Down
2 changes: 1 addition & 1 deletion wiki
Submodule wiki updated from 80fca5 to 383fa4

0 comments on commit eb47acf

Please sign in to comment.