forked from vladmandic/automatic
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request vladmandic#3671 from vladmandic/dev
merge dev to master
- Loading branch information
Showing
84 changed files
with
100,005 additions
and
610 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,249 @@ | ||
from typing import Union | ||
import os | ||
import re | ||
import logging | ||
from tqdm.rich import tqdm | ||
import torch | ||
import PIL | ||
import faiss | ||
import numpy as np | ||
import pandas as pd | ||
import transformers | ||
|
||
|
||
class ImageDB: | ||
# TODO index: quantize and train faiss index | ||
# TODO index: clip batch processing | ||
def __init__(self, | ||
name:str='db', | ||
fmt:str='json', | ||
cache_dir:str=None, | ||
dtype:torch.dtype=torch.float16, | ||
device:torch.device=torch.device('cpu'), | ||
model:str='openai/clip-vit-large-patch14', # 'facebook/dinov2-small' | ||
debug:bool=False, | ||
pbar:bool=True, | ||
): | ||
self.format = fmt | ||
self.name = name | ||
self.cache_dir = cache_dir | ||
self.processor: transformers.AutoImageProcessor = None | ||
self.model: transformers.AutoModel = None | ||
self.tokenizer = transformers.AutoTokenizer = None | ||
self.device: torch.device = device | ||
self.dtype: torch.dtype = dtype | ||
self.dimension = 768 if 'clip' in model else 384 | ||
self.debug = debug | ||
self.pbar = pbar | ||
self.repo = model | ||
self.df = pd.DataFrame([], columns=['filename', 'timestamp', 'metadata']) # image/metadata database | ||
self.index = faiss.IndexFlatL2(self.dimension) # embed database | ||
self.log = logging.getLogger(__name__) | ||
self.err = logging.getLogger(__name__).error | ||
self.log = logging.getLogger(__name__).info if self.debug else logging.getLogger(__name__).debug | ||
# self.init() | ||
# self.load() | ||
|
||
def __str__(self): | ||
return f'db: name="{self.name}" format={self.format} device={self.device} dtype={self.dtype} dimension={self.dimension} model="{self.repo}" records={len(self.df)} index={self.index.ntotal}' | ||
|
||
def init(self): # initialize models | ||
if self.processor is None or self.model is None: | ||
if 'clip' in self.repo: | ||
self.processor = transformers.CLIPImageProcessor.from_pretrained(self.repo, cache_dir=self.cache_dir) | ||
self.tokenizer = transformers.CLIPTokenizer.from_pretrained(self.repo, cache_dir=self.cache_dir) | ||
self.model = transformers.CLIPModel.from_pretrained(self.repo, cache_dir=self.cache_dir).to(device=self.device, dtype=self.dtype) | ||
elif 'dino' in self.repo: | ||
self.processor = transformers.AutoImageProcessor.from_pretrained(self.repo, cache_dir=self.cache_dir) | ||
self.model = transformers.AutoModel.from_pretrained(self.repo, cache_dir=self.cache_dir).to(device=self.device, dtype=self.dtype) | ||
else: | ||
self.err(f'db: model="{self.repo}" unknown') | ||
self.log(f'db: load model="{self.repo}" cache="{self.cache_dir}" device={self.device} dtype={self.dtype}') | ||
|
||
def load(self): # load db to disk | ||
if self.format == 'json' and os.path.exists(f'{self.name}.json'): | ||
self.df = pd.read_json(f'{self.name}.json') | ||
elif self.format == 'csv' and os.path.exists(f'{self.name}.csv'): | ||
self.df = pd.read_csv(f'{self.name}.csv') | ||
elif self.format == 'pickle' and os.path.exists(f'{self.name}.pkl'): | ||
self.df = pd.read_pickle(f'{self.name}.parquet') | ||
if os.path.exists(f'{self.name}.index'): | ||
self.index = faiss.read_index(f'{self.name}.index') | ||
if self.index.ntotal != len(self.df): | ||
self.err(f'db: index={self.index.ntotal} data={len(self.df)} mismatch') | ||
self.index = faiss.IndexFlatL2(self.dimension) | ||
self.df = pd.DataFrame([], columns=['filename', 'timestamp', 'metadata']) | ||
self.log(f'db: load data={len(self.df)} name={self.name} format={self.format} name={self.name}') | ||
|
||
def save(self): # save db to disk | ||
if self.format == 'json': | ||
self.df.to_json(f'{self.name}.json') | ||
elif self.format == 'csv': | ||
self.df.to_csv(f'{self.name}.csv') | ||
elif self.format == 'pickle': | ||
self.df.to_pickle(f'{self.name}.pkl') | ||
faiss.write_index(self.index, f'{self.name}.index') | ||
self.log(f'db: save data={len(self.df)} name={self.name} format={self.format} name={self.name}') | ||
|
||
def normalize(self, embed) -> np.ndarray: # normalize embed before using it | ||
embed = embed.detach().float().cpu().numpy() | ||
faiss.normalize_L2(embed) | ||
return embed | ||
|
||
def embedding(self, query: Union[PIL.Image.Image | str]) -> np.ndarray: # calculate embed for prompt or image | ||
if self.processor is None or self.model is None: | ||
self.err('db: model not loaded') | ||
if isinstance(query, str) and os.path.exists(query): | ||
query = PIL.Image.open(query).convert('RGB') | ||
self.model = self.model.to(self.device) | ||
with torch.no_grad(): | ||
if 'clip' in self.repo: | ||
if isinstance(query, str): | ||
processed = self.tokenizer(text=query, padding=True, return_tensors="pt").to(device=self.device) | ||
results = self.model.get_text_features(**processed) | ||
else: | ||
processed = self.processor(images=query, return_tensors="pt").to(device=self.device, dtype=self.dtype) | ||
results = self.model.get_image_features(**processed) | ||
elif 'dino' in self.repo: | ||
processed = self.processor(images=query, return_tensors="pt").to(device=self.device, dtype=self.dtype) | ||
results = self.model(**processed) | ||
results = results.last_hidden_state.mean(dim=1) | ||
else: | ||
self.err(f'db: model="{self.repo}" unknown') | ||
return None | ||
return self.normalize(results) | ||
|
||
def add(self, embed, filename=None, metadata=None): # add embed to db | ||
rec = pd.DataFrame([{'filename': filename, 'timestamp': pd.Timestamp.now(), 'metadata': metadata}]) | ||
if len(self.df) > 0: | ||
self.df = pd.concat([self.df, rec], ignore_index=True) | ||
else: | ||
self.df = rec | ||
self.index.add(embed) | ||
|
||
def search(self, filename: str = None, metadata: str = None, embed: np.ndarray = None, k=10, d=1.0): # search by filename/metadata/prompt-embed/image-embed | ||
def dct(record: pd.DataFrame, mode: str, distance: float = None): | ||
if distance is not None: | ||
return {'type': mode, 'filename': record[1]['filename'], 'metadata': record[1]['metadata'], 'distance': round(distance, 2)} | ||
else: | ||
return {'type': mode, 'filename': record[1]['filename'], 'metadata': record[1]['metadata']} | ||
|
||
if self.index.ntotal == 0: | ||
return | ||
self.log(f'db: search k={k} d={d}') | ||
if embed is not None: | ||
distances, indexes = self.index.search(embed, k) | ||
records = self.df.iloc[indexes[0]] | ||
for record, distance in zip(records.iterrows(), distances[0]): | ||
if d <= 0 or distance <= d: | ||
yield dct(record, distance=distance, mode='embed') | ||
if filename is not None: | ||
records = self.df[self.df['filename'].str.contains(filename, na=False, case=False)] | ||
for record in records.iterrows(): | ||
yield dct(record, mode='filename') | ||
if metadata is not None: | ||
records = self.df[self.df['metadata'].str.contains(filename, na=False, case=False)] | ||
for record in records.iterrows(): | ||
yield dct(record, mode='metadata') | ||
|
||
def decode(self, s: bytes): # decode byte-encoded exif metadata | ||
remove_prefix = lambda text, prefix: text[len(prefix):] if text.startswith(prefix) else text # pylint: disable=unnecessary-lambda-assignment | ||
for encoding in ['utf-8', 'utf-16', 'ascii', 'latin_1', 'cp1252', 'cp437']: # try different encodings | ||
try: | ||
s = remove_prefix(s, b'UNICODE') | ||
s = remove_prefix(s, b'ASCII') | ||
s = remove_prefix(s, b'\x00') | ||
val = s.decode(encoding, errors="strict") | ||
val = re.sub(r'[\x00-\x09\n\s\s+]', '', val).strip() # remove remaining special characters, new line breaks, and double empty spaces | ||
if len(val) == 0: # remove empty strings | ||
val = None | ||
return val | ||
except Exception: | ||
pass | ||
return None | ||
|
||
def metadata(self, image: PIL.Image.Image): # get exif metadata from image | ||
exif = image._getexif() # pylint: disable=protected-access | ||
if exif is None: | ||
return '' | ||
for k, v in exif.items(): | ||
if k == 37510: # comment | ||
return self.decode(v) | ||
return '' | ||
|
||
def image(self, filename: str, image=None): # add file/image to db | ||
try: | ||
if image is None: | ||
image = PIL.Image.open(filename) | ||
image.load() | ||
embed = self.embedding(image.convert('RGB')) | ||
metadata = self.metadata(image) | ||
image.close() | ||
self.add(embed, filename=filename, metadata=metadata) | ||
except Exception as _e: | ||
# self.err(f'db: {str(_e)}') | ||
pass | ||
|
||
def folder(self, folder: str): # add all files from folder to db | ||
files = [] | ||
for root, _subdir, _files in os.walk(folder): | ||
for f in _files: | ||
files.append(os.path.join(root, f)) | ||
if self.pbar: | ||
for f in tqdm(files): | ||
self.image(filename=f) | ||
else: | ||
for f in files: | ||
self.image(filename=f) | ||
|
||
def offload(self): # offload model to cpu | ||
if self.model is not None: | ||
self.model = self.model.to('cpu') | ||
|
||
|
||
if __name__ == '__main__': | ||
import time | ||
import argparse | ||
logging.basicConfig(level=logging.INFO) | ||
parser = argparse.ArgumentParser(description = 'image-search') | ||
group = parser.add_mutually_exclusive_group(required=True) | ||
group.add_argument('--search', action='store_true', help='run search') | ||
group.add_argument('--index', action='store_true', help='run indexing') | ||
parser.add_argument('--db', default='db', help='database name') | ||
parser.add_argument('--model', default='openai/clip-vit-large-patch14', help='huggingface model') | ||
parser.add_argument('--cache', default='/mnt/models/huggingface', help='cache folder') | ||
parser.add_argument('input', nargs='*', default=os.getcwd()) | ||
args = parser.parse_args() | ||
|
||
db = ImageDB( | ||
name=args.db, | ||
model=args.model, # 'facebook/dinov2-small' | ||
cache_dir=args.cache, | ||
dtype=torch.bfloat16, | ||
device=torch.device('cuda'), | ||
debug=True, | ||
pbar=True, | ||
) | ||
db.init() | ||
db.load() | ||
print(db) | ||
|
||
if args.index: | ||
t0 = time.time() | ||
if len(args.input) > 0: | ||
for fn in args.input: | ||
if os.path.isfile(fn): | ||
db.image(filename=fn) | ||
elif os.path.isdir(fn): | ||
db.folder(folder=fn) | ||
t1 = time.time() | ||
print('index', t1-t0) | ||
db.save() | ||
db.offload() | ||
|
||
if args.search: | ||
for ref in args.input: | ||
emb = db.embedding(ref) | ||
res = db.search(filename=ref, metadata=ref, embed=emb, k=10, d=0) | ||
for r in res: | ||
print(ref, r) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
{ | ||
"_class_name": "StableDiffusionXLImg2ImgPipeline", | ||
"_diffusers_version": "0.19.0.dev0", | ||
"force_zeros_for_empty_prompt": false, | ||
"add_watermarker": null, | ||
"requires_aesthetics_score": true, | ||
"scheduler": [ | ||
"diffusers", | ||
"EulerDiscreteScheduler" | ||
], | ||
"text_encoder": [ | ||
null, | ||
null | ||
], | ||
"text_encoder_2": [ | ||
"transformers", | ||
"CLIPTextModelWithProjection" | ||
], | ||
"tokenizer": [ | ||
null, | ||
null | ||
], | ||
"tokenizer_2": [ | ||
"transformers", | ||
"CLIPTokenizer" | ||
], | ||
"unet": [ | ||
"diffusers", | ||
"UNet2DConditionModel" | ||
], | ||
"vae": [ | ||
"diffusers", | ||
"AutoencoderKL" | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
{ | ||
"_class_name": "EulerDiscreteScheduler", | ||
"_diffusers_version": "0.19.0.dev0", | ||
"beta_end": 0.012, | ||
"beta_schedule": "scaled_linear", | ||
"beta_start": 0.00085, | ||
"clip_sample": false, | ||
"interpolation_type": "linear", | ||
"num_train_timesteps": 1000, | ||
"prediction_type": "epsilon", | ||
"sample_max_value": 1.0, | ||
"set_alpha_to_one": false, | ||
"skip_prk_steps": true, | ||
"steps_offset": 1, | ||
"timestep_spacing": "leading", | ||
"trained_betas": null, | ||
"use_karras_sigmas": false | ||
} |
Oops, something went wrong.