Skip to content

Commit

Permalink
Merge pull request #53 from sipherxyz/develop
Browse files Browse the repository at this point in the history
Update model download with sha256 validate
  • Loading branch information
tungnguyensipher authored Oct 30, 2024
2 parents 4be544a + e7e9e58 commit 4405434
Show file tree
Hide file tree
Showing 10 changed files with 189 additions and 118 deletions.
14 changes: 5 additions & 9 deletions modules/inpaint/lama/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import folder_paths
import comfy.model_management as model_management

from ...model_utils import download_model
from ...model_utils import download_file


lama = None
Expand All @@ -33,14 +33,10 @@ def pad_tensor_to_modulo(img, mod):
def load_model():
global lama
if lama is None:
files = download_model(
model_path=model_dir,
model_url=model_url,
ext_filter=[".pt"],
download_name="big-lama.pt",
)

lama = torch.jit.load(files[0], map_location="cpu")
model_path = os.path.join(model_dir, "big-lama.pt")
download_file(model_url, model_path, model_sha)

lama = torch.jit.load(model_path, map_location="cpu")
lama.eval()

return lama
Expand Down
4 changes: 3 additions & 1 deletion modules/interrogate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from .blip_node import BlipLoader, BlipCaption
from .blip_node import BlipLoader, BlipCaption, DownloadAndLoadBlip
from .danbooru import DeepDanbooruCaption

NODE_CLASS_MAPPINGS = {
"BLIPLoader": BlipLoader,
"BLIPCaption": BlipCaption,
"DownloadAndLoadBlip": DownloadAndLoadBlip,
"DeepDanbooruCaption": DeepDanbooruCaption,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"BLIPLoader": "BLIP Loader",
"BLIPCaption": "BLIP Caption",
"DownloadAndLoadBlip": "Download and Load BLIP Model",
"DeepDanbooruCaption": "Deep Danbooru Caption",
}

Expand Down
62 changes: 41 additions & 21 deletions modules/interrogate/blip_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,25 @@
import folder_paths
from comfy.model_management import text_encoder_device, text_encoder_offload_device, soft_empty_cache

from ..model_utils import download_model
from ..model_utils import download_file
from ..utils import tensor2pil

blips = {}
blip_size = 384
gpu = text_encoder_device()
cpu = text_encoder_offload_device()
model_url = (
"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth"
)
model_dir = os.path.join(folder_paths.models_dir, "blip")
models = {
"model_base_caption_capfilt_large.pth": {
"url": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth",
"sha": "96ac8749bd0a568c274ebe302b3a3748ab9be614c737f3d8c529697139174086",
},
"model_base_capfilt_large.pth": {
"url": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth",
"sha": "8f5187458d4d47bb87876faf3038d5947eff17475edf52cf47b62e84da0b235f",
},
}


folder_paths.folder_names_and_paths["blip"] = (
[model_dir],
Expand Down Expand Up @@ -90,12 +98,7 @@ def join_caption(caption, prefix, suffix):

def blip_caption(model, image, min_length, max_length):
image = tensor2pil(image)

if "transformers==4.26.1" in packages(True):
print("Using Legacy `transformImaage()`")
tensor = transformImage_legacy(image)
else:
tensor = transformImage(image)
tensor = transformImage(image)

with torch.no_grad():
caption = model.generate(
Expand All @@ -122,8 +125,32 @@ def INPUT_TYPES(s):
CATEGORY = "Art Venture/Captioning"

def load_blip(self, model_name):
model = load_blip(model_name)
return (model,)
return (load_blip(model_name),)


class DownloadAndLoadBlip:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model_name": (list(models.keys()),),
},
}

RETURN_TYPES = ("BLIP_MODEL",)
FUNCTION = "download_and_load_blip"
CATEGORY = "Art Venture/Captioning"

def download_and_load_blip(self, model_name):
if model_name not in folder_paths.get_filename_list("blip"):
model_info = models[model_name]
download_file(
model_info["url"],
os.path.join(model_dir, model_name),
model_info["sha"],
)

return (load_blip(model_name),)


class BlipCaption:
Expand Down Expand Up @@ -173,15 +200,8 @@ def blip_caption(
return ([join_caption("", prefix, suffix)],)

if blip_model is None:
ckpts = folder_paths.get_filename_list("blip")
if len(ckpts) == 0:
ckpts = download_model(
model_path=model_dir,
model_url=model_url,
ext_filter=[".pth"],
download_name="model_base_caption_capfilt_large.pth",
)
blip_model = load_blip(ckpts[0])
downloader = DownloadAndLoadBlip()
blip_model = downloader.download_and_load_blip("model_base_caption_capfilt_large.pth")[0]

device = gpu if device_mode != "CPU" else cpu
blip_model = blip_model.to(device)
Expand Down
19 changes: 8 additions & 11 deletions modules/interrogate/danbooru.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,33 @@
from comfy.model_management import text_encoder_device, text_encoder_offload_device, soft_empty_cache

from ..image_utils import resize_image
from ..model_utils import download_model
from ..model_utils import download_file
from ..utils import is_junction, tensor2pil
from .blip_node import join_caption

danbooru = None
blip_size = 384
gpu = text_encoder_device()
cpu = text_encoder_offload_device()
model_dir = os.path.join(folder_paths.models_dir, "blip")
model_url = "https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt"
model_sha = "3841542cda4dd037da12a565e854b3347bb2eec8fbcd95ea3941b2c68990a355"
re_special = re.compile(r"([\\()])")


def load_danbooru(device_mode):
global danbooru
if danbooru is None:
blip_dir = os.path.join(folder_paths.models_dir, "blip")
if not os.path.exists(blip_dir) and not is_junction(blip_dir):
os.makedirs(blip_dir, exist_ok=True)
if not os.path.exists(model_dir) and not is_junction(model_dir):
os.makedirs(model_dir, exist_ok=True)

files = download_model(
model_path=blip_dir,
model_url=model_url,
ext_filter=[".pt"],
download_name="model-resnet_custom_v3.pt",
)
model_path = os.path.join(model_dir, "model-resnet_custom_v3.pt")
download_file(model_url, model_path, model_sha)

from .models.deepbooru_model import DeepDanbooruModel

danbooru = DeepDanbooruModel()
danbooru.load_state_dict(torch.load(files[0], map_location="cpu"))
danbooru.load_state_dict(torch.load(model_path, map_location="cpu"))
danbooru.eval()

if device_mode != "CPU":
Expand Down
4 changes: 3 additions & 1 deletion modules/isnet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from .segmenter import ISNetLoader, ISNetSegment
from .segmenter import ISNetLoader, ISNetSegment, DownloadISNetModel

NODE_CLASS_MAPPINGS = {
"ISNetLoader": ISNetLoader,
"ISNetSegment": ISNetSegment,
"DownloadISNetModel": DownloadISNetModel,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"ISNetLoader": "ISNet Loader",
"ISNetSegment": "ISNet Segment",
"DownloadISNetModel": "Download and Load ISNet Model",
}

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
69 changes: 46 additions & 23 deletions modules/isnet/segmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,30 @@
import comfy.model_management as model_management
import comfy.utils

from ..model_utils import download_model
from ..utils import pil2tensor, tensor2pil, numpy2pil
from ..model_utils import download_file
from ..utils import pil2tensor, tensor2pil
from ..logger import logger


isnets = {}
cache_size = [1024, 1024]
gpu = model_management.get_torch_device()
cpu = torch.device("cpu")
model_dir = os.path.join(folder_paths.models_dir, "isnet")
model_url = "https://huggingface.co/NimaBoscarino/IS-Net_DIS-general-use/resolve/main/isnet-general-use.pth"
cache_size = [1024, 1024]
models = {
"isnet-general-use.pth": {
"url": "https://huggingface.co/NimaBoscarino/IS-Net_DIS-general-use/resolve/main/isnet-general-use.pth",
"sha": "9e1aafea58f0b55d0c35077e0ceade6ba1ba2bce372fd4f8f77215391f3fac13",
},
"isnetis.pth": {
"url": "https://github.com/Sanster/models/releases/download/isnetis/isnetis.pth",
"sha": "90a970badbd99ca7839b4e0beb09a36565d24edba7e4a876de23c761981e79e0",
},
"RMBG-1.4.bin": {
"url": "https://huggingface.co/briaai/RMBG-1.4/resolve/main/pytorch_model.bin",
"sha": "59569acdb281ac9fc9f78f9d33b6f9f17f68e25086b74f9025c35bb5f2848967",
},
}

folder_paths.folder_names_and_paths["isnet"] = (
[model_dir],
Expand Down Expand Up @@ -134,23 +147,40 @@ def INPUT_TYPES(s):
return {
"required": {
"model_name": (folder_paths.get_filename_list("isnet"),),
"model_override": ("STRING", {"default": "None"}),
},
}

RETURN_TYPES = ("ISNET_MODEL",)
FUNCTION = "load_isnet"
CATEGORY = "Art Venture/Segmentation"

def load_isnet(self, model_name, model_override="None"):
if model_override != "None":
if model_override not in folder_paths.get_filename_list("isnet"):
logger.warning(f"Model override {model_override} not found. Use {model_name} instead.")
else:
model_name = model_override
def load_isnet(self, model_name):
return (load_isnet_model(model_name),)


class DownloadISNetModel:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model_name": (list(models.keys()),),
},
}

RETURN_TYPES = ("ISNET_MODEL",)
FUNCTION = "download_isnet"
CATEGORY = "Art Venture/Segmentation"

def download_isnet(self, model_name):
if model_name not in folder_paths.get_filename_list("isnet"):
model_info = models[model_name]
download_file(
model_info["url"],
os.path.join(model_dir, model_name),
model_info["sha"],
)

model = load_isnet_model(model_name)
return (model,)
return (load_isnet_model(model_name),)


class ISNetSegment:
Expand Down Expand Up @@ -179,15 +209,8 @@ def segment_isnet(self, images: torch.Tensor, threshold, device_mode="AUTO", ena
return (images, masks)

if isnet_model is None:
ckpts = folder_paths.get_filename_list("isnet")
if len(ckpts) == 0:
ckpts = download_model(
model_path=model_dir,
model_url=model_url,
ext_filter=[".pth"],
download_name="isnet-general-use.pth",
)
isnet_model = load_isnet_model(ckpts[0])
downloader = DownloadISNetModel()
isnet_model = downloader.download_isnet("isnet-general-use.pth")[0]

device = gpu if device_mode != "CPU" else cpu
isnet_model = isnet_model.to(device)
Expand All @@ -198,7 +221,7 @@ def segment_isnet(self, images: torch.Tensor, threshold, device_mode="AUTO", ena
for image in images:
mask = predict(isnet_model, image, device)
mask_im = tensor2pil(mask.permute(1, 2, 0))
cropped = Image.new("RGBA", mask_im.size, (0,0,0,0))
cropped = Image.new("RGBA", mask_im.size, (0, 0, 0, 0))
cropped.paste(tensor2pil(image), mask=mask_im)

masks.append(mask)
Expand Down
Loading

0 comments on commit 4405434

Please sign in to comment.