Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SD3 Lora page filter - detection not implemented #16299

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions extensions-builtin/Lora/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class SdVersion(enum.Enum):
SD1 = 2
SD2 = 3
SDXL = 4
SD3 = 5


class NetworkOnDisk:
Expand Down Expand Up @@ -59,6 +60,7 @@ def read_metadata():
self.sd_version = self.detect_version()

def detect_version(self):
# TODO: SdVersion.SD3 detection
if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):
return SdVersion.SDXL
elif str(self.metadata.get('ss_v2', "")) == "True":
Expand Down
3 changes: 2 additions & 1 deletion extensions-builtin/Lora/scripts/lora_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def before_ui():
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
"lora_bundled_ti_to_infotext": shared.OptionInfo(True, "Add Lora name as TI hashes for bundled Textual Inversion").info('"Add Textual Inversion hashes to infotext" needs to be enabled'),
"lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL", "SD3"]}),
"TEMP_setting_sd3_lora_filter": shared.OptionInfo(["SD1", "Unknown"], "For SD3 model also show Lora of other sd version", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL", "Unknown"]}).info('Temporary setting until SD3 Lora detection is implemented'),
"lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
"lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"),
"lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"),
Expand Down
2 changes: 1 addition & 1 deletion extensions-builtin/Lora/ui_edit_user_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def generate_random_prompt_from_tags(self, tags):
def create_extra_default_items_in_left_column(self):

# this would be a lot better as gr.Radio but I can't make it work
self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True)
self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'SD3', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True)

def create_editor(self):
self.create_default_editor_elems()
Expand Down
21 changes: 19 additions & 2 deletions extensions-builtin/Lora/ui_extra_networks_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import network
import networks

from modules import shared, ui_extra_networks
from modules import shared, ui_extra_networks, sd_models_types
from modules.ui_extra_networks import quote_js
from ui_edit_user_metadata import LoraUserMetadataEditor

Expand Down Expand Up @@ -62,8 +62,14 @@ def create_item(self, name, index=None, enable_filter=True):

if shared.opts.lora_show_all or not enable_filter or not shared.sd_model:
pass
elif shared.sd_model.is_sd3:
# TODO: add proper SD3 filtering when detection is implemented
# TODO: move after Unknown block when implemented
if sd_version is network.SdVersion.SD3 or sd_version.name in shared.opts.TEMP_setting_sd3_lora_filter:
return item
return None
elif sd_version == network.SdVersion.Unknown:
model_version = network.SdVersion.SDXL if shared.sd_model.is_sdxl else network.SdVersion.SD2 if shared.sd_model.is_sd2 else network.SdVersion.SD1
model_version = self.sd_to_lora_version(shared.sd_model)
if model_version.name in shared.opts.lora_hide_unknown_for_versions:
return None
elif shared.sd_model.is_sdxl and sd_version != network.SdVersion.SDXL:
Expand All @@ -88,3 +94,14 @@ def allowed_directories_for_previews(self):

def create_user_metadata_editor(self, ui, tabname):
return LoraUserMetadataEditor(ui, tabname, self)

@staticmethod
def sd_to_lora_version(sd_model: sd_models_types.WebuiSdModel):
if sd_model.is_sd1:
return network.SdVersion.SD1
elif sd_model.is_sd2:
return network.SdVersion.SD2
elif sd_model.is_sdxl:
return network.SdVersion.SDXL
elif sd_model.is_sd3:
return network.SdVersion.SD3