From 49124e85753e87110276f27edd22b6f19b0e3f85 Mon Sep 17 00:00:00 2001 From: BetaDoggo Date: Mon, 13 Jan 2025 15:28:56 -0500 Subject: [PATCH 1/2] Improve compatibility with custom tag lists --- ai_diffusion/ui/autocomplete.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/ai_diffusion/ui/autocomplete.py b/ai_diffusion/ui/autocomplete.py index 4c96b19b8..6b88cb1f5 100644 --- a/ai_diffusion/ui/autocomplete.py +++ b/ai_diffusion/ui/autocomplete.py @@ -175,19 +175,21 @@ def _reload_tag_model(self): with tag_path.open("r", encoding="utf-8") as f: csv_reader = csv.reader(f) - # skip header line - next(csv_reader) for tag, type_str, count, _aliases in csv_reader: - tag = tag.replace("_", " ") - tag_type = TagType(int(type_str)) - count = int(count) - count_str = str(count) - if count > 1_000_000: - count_str = f"{count/1_000_000:.0f}m" - elif count > 1_000: - count_str = f"{count/1_000:.0f}k" - meta = f"{tag_name} {count_str}" - all_tags.append(TagItem(tag, tag_type, count, meta)) + if type_str.isdigit(): # skip header rows if they exist + tag = tag.replace("_", " ") + try: + tag_type = TagType(int(type_str)) + except: # default to general category if category is not recognised + tag_type = TagType(0) + count = int(count) + count_str = str(count) + if count > 1_000_000: + count_str = f"{count/1_000_000:.0f}m" + elif count > 1_000: + count_str = f"{count/1_000:.0f}k" + meta = f"{tag_name} {count_str}" + all_tags.append(TagItem(tag, tag_type, count, meta)) sorted_tags = sorted(all_tags, key=lambda x: x.count, reverse=True) seen = set() From 4c4b99970b042002f4cdbe06027fe6e9c3276bd1 Mon Sep 17 00:00:00 2001 From: BetaDoggo Date: Wed, 15 Jan 2025 11:50:31 -0500 Subject: [PATCH 2/2] fix formatting --- ai_diffusion/ui/autocomplete.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ai_diffusion/ui/autocomplete.py b/ai_diffusion/ui/autocomplete.py index 6b88cb1f5..362011a87 100644 --- a/ai_diffusion/ui/autocomplete.py +++ b/ai_diffusion/ui/autocomplete.py @@ -176,11 +176,11 @@ def _reload_tag_model(self): with tag_path.open("r", encoding="utf-8") as f: csv_reader = csv.reader(f) for tag, type_str, count, _aliases in csv_reader: - if type_str.isdigit(): # skip header rows if they exist + if type_str.isdigit(): # skip header rows if they exist tag = tag.replace("_", " ") try: tag_type = TagType(int(type_str)) - except: # default to general category if category is not recognised + except: # default to general category if category is not recognised tag_type = TagType(0) count = int(count) count_str = str(count)