add better memory management, allowing swap between LTX Video and SD 3.5 Medium on 4090

bghira committed Jan 21, 2025
1 parent 2648bb4 commit 72f7ea6
Showing 4 changed files with 43 additions and 15 deletions.
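
Taken together, the changes implement a swap-to-fit pattern: before a pipeline is activated, every other cached pipeline is moved to the meta device, deleted, and garbage-collected, so two large models (LTX Video and SD 3.5 Medium) can alternate on a single 24 GB 4090. A minimal sketch of the pattern follows, with illustrative names only; the real logic lives in delete_pipes and get_pipe below.

```python
import gc
import torch

pipelines = {}  # model_id -> loaded diffusers pipeline (illustrative cache)

def evict_all_but(keep_model: str) -> None:
    """Drop every cached pipeline except keep_model, then release VRAM."""
    for model_id in list(pipelines.keys()):
        if model_id != keep_model:
            pipelines[model_id].to("meta")  # detach weights from real storage
            del pipelines[model_id]
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def activate(model_id: str, device: str = "cuda"):
    evict_all_but(model_id)  # free VRAM before the survivor moves over
    pipe = pipelines[model_id]
    pipe.to(device)
    return pipe
```
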
20 changes: 19 additions & 1 deletion discord_tron_client/classes/image_manipulation/diffusion.py
@@ -31,6 +31,7 @@
 # )
 from PIL import Image
 import torch, gc, logging, diffusers, transformers, os
+from torch import OutOfMemoryError

 logger = logging.getLogger("DiffusionPipelineManager")
 logger.setLevel("DEBUG")
@@ -106,6 +107,8 @@ def clear_pipeline(self, model_id: str) -> None:
         if model_id in self.pipelines:
             try:
                 del self.pipelines[model_id]
+                if self.pipeline_runner.get("model") == model_id:
+                    self.pipeline_runner["model"] = None
                 self.clear_cuda_cache()
             except Exception as e:
                 logger.error(f"Error when deleting pipe: {e}")
@@ -407,7 +410,9 @@ def get_pipe(
                 logger.info(
                     f"Moving pipe to CUDA early, because no offloading is being used."
                 )
+                self.delete_pipes(keep_model=model_id)
                 self.pipelines[model_id].to(self.device)
+
                 if config.enable_compile() and hasattr(
                     self.pipelines[model_id], "unet"
                 ):
Expand Down Expand Up @@ -435,7 +440,11 @@ def get_pipe(
)
else:
logger.info(f"Keeping existing pipeline. Not creating any new ones.")
logger.info(f"Moving pipeline back to {self.device}")
self.delete_pipes(keep_model=model_id)
self.pipelines[model_id].to(self.device)
logger.info(f"Moved pipeline back to {self.device}")

self.last_pipe_type[model_id] = pipe_type
self.last_pipe_scheduler[model_id] = self.pipelines[model_id].config[
"scheduler"
@@ -456,7 +465,7 @@ def get_pipe(
     def delete_pipes(self, keep_model: str = None):
         total_allowed_concurrent = hardware.get_concurrent_pipe_count()
         # Loop by a range of 0 through len(self.pipelines):
-        for model_id in list(self.pipelines.keys()):
+        for model_id in set(self.pipelines.keys()):
             if len(self.pipelines) > total_allowed_concurrent and (
                 keep_model is None or keep_model != model_id
             ):
@@ -465,13 +474,22 @@ def delete_pipes(self, keep_model: str = None):
                 )
                 self.pipelines[model_id].to("meta")
                 del self.pipelines[model_id]
+                if self.pipeline_runner.get("model") == model_id:
+                    self.pipeline_runner["model"] = None
+                if model_id in self.last_pipe_type:
+                    del self.last_pipe_type[model_id]
+                if model_id in self.last_pipe_scheduler:
+                    del self.last_pipe_scheduler[model_id]
+                if model_id in self.pipeline_versions:
+                    del self.pipeline_versions[model_id]
                 if model_id in self.last_pipe_scheduler:
                     del self.last_pipe_scheduler[model_id]
                 if model_id in self.last_pipe_type:
                     del self.last_pipe_type[model_id]
                 self.clear_cuda_cache()

     def clear_cuda_cache(self):
+        return None
         gc.collect()
         if config.get_cuda_cache_clear_toggle():
             logger.info("Clearing the CUDA cache...")
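
Note that clear_cuda_cache is now short-circuited by the early return None, leaving eviction to delete_pipes itself. For reference, a sketch of the collect-then-empty sequence the toggled branch performs when it does run; the ipc_collect step is an assumption for illustration, not part of the diff.

```python
import gc
import torch

def clear_cuda_cache(enabled: bool = True) -> None:
    gc.collect()  # drop unreachable Python references first
    if enabled and torch.cuda.is_available():
        torch.cuda.empty_cache()  # hand cached allocator blocks back to the driver
        torch.cuda.ipc_collect()  # assumption: also reclaim CUDA IPC handles
```
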
25 changes: 12 additions & 13 deletions discord_tron_client/classes/image_manipulation/pipeline.py
@@ -70,11 +70,7 @@ def __init__(
         self.websocket = websocket
         self.model_config = model_config
         self.prompt_manager = None
-        self.pipeline_runner = {
-            "model": None,
-            "runner": None,
-        }

     async def reset_bar(self, discord_msg, websocket):
         # An object to manage a progress bar for Discord.
         main_loop = asyncio.get_event_loop()
@@ -301,11 +297,12 @@ def _run_pipeline(
             )
             # text2img workflow
             if (
-                self.pipeline_runner["model"] is not None
-                and self.pipeline_runner["runner"] is not None
-                and self.pipeline_runner["model"] == user_model
+                getattr(self.pipeline_manager, 'pipeline_runner', {}).get("model") is not None
+                and getattr(self.pipeline_manager, 'pipeline_runner', {}).get("runner") is not None
+                and getattr(self.pipeline_manager, 'pipeline_runner', {}).get("model") == user_model
             ):
-                pipeline_runner = self.pipeline_runner["runner"]
+                logging.info("Using preserved pipeline_runner.")
+                pipeline_runner = getattr(self.pipeline_manager, 'pipeline_runner', {}).get("runner")
             elif (
                 type(pipe) is diffusers.StableDiffusionXLPipeline
                 or "ptx0/s1" in user_model
@@ -385,11 +382,13 @@ def _run_pipeline(
             logging.debug(f"Received type of pipeline: {type(pipe)}")
             pipeline_runner = runner_map["text2img"](pipeline=pipe)
         if (
-            self.pipeline_runner["model"] is None
-            or self.pipeline_runner["model"] != user_model
+            getattr(self.pipeline_manager, 'pipeline_runner', {}).get("model") is None
+            or getattr(self.pipeline_manager, 'pipeline_runner', {}).get("model") != user_model
         ):
-            self.pipeline_runner["model"] = user_model
-            self.pipeline_runner["runner"] = pipeline_runner
+            if not hasattr(self.pipeline_manager, 'pipeline_runner'):
+                setattr(self.pipeline_manager, 'pipeline_runner', {})
+            self.pipeline_manager.pipeline_runner["model"] = user_model
+            self.pipeline_manager.pipeline_runner["runner"] = pipeline_runner
         if image is None:
             preprocessed_images = pipeline_runner(
                 prompt=positive_prompt,
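
The runner cache moves from the per-worker object onto the shared pipeline manager, so a preserved pipeline_runner survives across requests while the same model stays active. A condensed sketch of the lookup-or-create logic, using hypothetical helper names:

```python
def get_runner(manager, user_model: str, make_runner):
    """Reuse the preserved runner only if it was built for this exact model."""
    cache = getattr(manager, "pipeline_runner", {})
    if cache.get("model") == user_model and cache.get("runner") is not None:
        return cache["runner"]
    runner = make_runner()  # e.g. runner_map["text2img"](pipeline=pipe)
    if not hasattr(manager, "pipeline_runner"):
        manager.pipeline_runner = {}
    manager.pipeline_runner["model"] = user_model
    manager.pipeline_runner["runner"] = runner
    return runner
```
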
@@ -17,6 +17,7 @@ def __init__(self, **kwargs):
         self.diffusion_manager = None
         if "pipeline" in kwargs:
             self.pipeline = kwargs["pipeline"]
+            print(f"initializing pipeline runner using {self.pipeline}")
         if "pipeline_manager" in kwargs:
             self.pipeline_manager = kwargs["pipeline_manager"]
         else:
@@ -150,6 +151,7 @@ def load_adapter(
         else:
             # lycoris_wrapper.to(self.pipeline.transformer.device)
             lycoris_wrapper.apply_to()
+            logging.info("Moving Lycoris to GPU")
             lycoris_wrapper.to(device=self.pipeline_manager.device, dtype=self.pipeline_manager.torch_dtype)
         self.loaded_adapters[clean_adapter_name] = {
             "adapter_type": adapter_type,
@@ -169,8 +171,12 @@ def clear_adapters(self):
             if config.get("adapter_type") == "lycoris":
                 lycoris_wrapper = config.get("lycoris_wrapper")
                 if not lycoris_wrapper:
+                    logging.error(f"Failed to clear adapter {clean_adapter_name}")
                     continue
+                logging.debug(f"Restoring lycoris wrapper for {clean_adapter_name}")
                 lycoris_wrapper.restore()
+                lycoris_wrapper.to("meta")
+                logging.debug("Sent lycoris to the abyss, meta tensors.")
             self.loaded_adapters[clean_adapter_name] = None
         self.pipeline.unload_lora_weights()
         self.loaded_adapters = {}
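
clear_adapters now restores the patched base weights before parking the LyCORIS wrapper on the meta device, which is what actually releases the adapter's VRAM. A sketch of that restore-then-park step, assuming an nn.Module-style wrapper exposing restore() as in the hunk above:

```python
def unload_lycoris(wrapper) -> None:
    if wrapper is None:
        return  # nothing was loaded for this adapter slot
    wrapper.restore()   # put the original weights back into the pipeline
    wrapper.to("meta")  # move the adapter's own tensors off the GPU
```
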
@@ -76,4 +76,9 @@ def __call__(self, **args):
         )

         # Call the pipeline with arguments and return the images
-        return self.pipeline(**args).images
+        self.pipeline.to(self.pipeline_manager.device)
+        print(f"device: {self.pipeline.transformer.device}, {self.pipeline.vae.device}, {self.pipeline.text_encoder.device}")
+        result = self.pipeline(**args).images
+        self.clear_adapters()
+
+        return result

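With this change, __call__ follows a fixed request lifecycle: move the pipeline onto the accelerator, run inference, then unload adapters so per-request LoRA/LyCORIS state cannot leak into the next job. A sketch of that lifecycle with hypothetical names:

```python
def run_once(pipe, device, clear_adapters, **args):
    pipe.to(device)               # ensure every submodule is on the accelerator
    images = pipe(**args).images  # run inference
    clear_adapters()              # restore base weights before the next request
    return images
```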