diff --git a/discord_tron_client/classes/image_manipulation/diffusion.py b/discord_tron_client/classes/image_manipulation/diffusion.py index be89a07..fd68c05 100644 --- a/discord_tron_client/classes/image_manipulation/diffusion.py +++ b/discord_tron_client/classes/image_manipulation/diffusion.py @@ -31,6 +31,7 @@ # ) from PIL import Image import torch, gc, logging, diffusers, transformers, os +from torch import OutOfMemoryError logger = logging.getLogger("DiffusionPipelineManager") logger.setLevel("DEBUG") @@ -106,6 +107,8 @@ def clear_pipeline(self, model_id: str) -> None: if model_id in self.pipelines: try: del self.pipelines[model_id] + if self.pipeline_runner.get("model") == model_id: + self.pipeline_runner["model"] = None self.clear_cuda_cache() except Exception as e: logger.error(f"Error when deleting pipe: {e}") @@ -407,7 +410,9 @@ def get_pipe( logger.info( f"Moving pipe to CUDA early, because no offloading is being used." ) + self.delete_pipes(keep_model=model_id) self.pipelines[model_id].to(self.device) + if config.enable_compile() and hasattr( self.pipelines[model_id], "unet" ): @@ -435,7 +440,11 @@ def get_pipe( ) else: logger.info(f"Keeping existing pipeline. Not creating any new ones.") + logger.info(f"Moving pipeline back to {self.device}") + self.delete_pipes(keep_model=model_id) self.pipelines[model_id].to(self.device) + logger.info(f"Moved pipeline back to {self.device}") + self.last_pipe_type[model_id] = pipe_type self.last_pipe_scheduler[model_id] = self.pipelines[model_id].config[ "scheduler" @@ -456,7 +465,7 @@ def get_pipe( def delete_pipes(self, keep_model: str = None): total_allowed_concurrent = hardware.get_concurrent_pipe_count() # Loop by a range of 0 through len(self.pipelines): - for model_id in list(self.pipelines.keys()): + for model_id in set(self.pipelines.keys()): if len(self.pipelines) > total_allowed_concurrent and ( keep_model is None or keep_model != model_id ): @@ -465,6 +474,14 @@ def delete_pipes(self, keep_model: str = None): ) self.pipelines[model_id].to("meta") del self.pipelines[model_id] + if self.pipeline_runner.get("model") == model_id: + self.pipeline_runner["model"] = None + if model_id in self.last_pipe_type: + del self.last_pipe_type[model_id] + if model_id in self.last_pipe_scheduler: + del self.last_pipe_scheduler[model_id] + if model_id in self.pipeline_versions: + del self.pipeline_versions[model_id] if model_id in self.last_pipe_scheduler: del self.last_pipe_scheduler[model_id] if model_id in self.last_pipe_type: @@ -472,6 +489,7 @@ def delete_pipes(self, keep_model: str = None): self.clear_cuda_cache() def clear_cuda_cache(self): + return None gc.collect() if config.get_cuda_cache_clear_toggle(): logger.info("Clearing the CUDA cache...") diff --git a/discord_tron_client/classes/image_manipulation/pipeline.py b/discord_tron_client/classes/image_manipulation/pipeline.py index e4281bc..ed2fd1b 100644 --- a/discord_tron_client/classes/image_manipulation/pipeline.py +++ b/discord_tron_client/classes/image_manipulation/pipeline.py @@ -70,11 +70,7 @@ def __init__( self.websocket = websocket self.model_config = model_config self.prompt_manager = None - self.pipeline_runner = { - "model": None, - "runner": None, - } - + async def reset_bar(self, discord_msg, websocket): # An object to manage a progress bar for Discord. main_loop = asyncio.get_event_loop() @@ -301,11 +297,12 @@ def _run_pipeline( ) # text2img workflow if ( - self.pipeline_runner["model"] is not None - and self.pipeline_runner["runner"] is not None - and self.pipeline_runner["model"] == user_model + getattr(self.pipeline_manager, 'pipeline_runner', {}).get("model") is not None + and getattr(self.pipeline_manager, 'pipeline_runner', {}).get("runner") is not None + and getattr(self.pipeline_manager, 'pipeline_runner', {}).get("model") == user_model ): - pipeline_runner = self.pipeline_runner["runner"] + logging.info("Using preserved pipeline_runner.") + pipeline_runner = getattr(self.pipeline_manager, 'pipeline_runner', {}).get("runner") elif ( type(pipe) is diffusers.StableDiffusionXLPipeline or "ptx0/s1" in user_model @@ -385,11 +382,13 @@ def _run_pipeline( logging.debug(f"Received type of pipeline: {type(pipe)}") pipeline_runner = runner_map["text2img"](pipeline=pipe) if ( - self.pipeline_runner["model"] is None - or self.pipeline_runner["model"] != user_model + getattr(self.pipeline_manager, 'pipeline_runner', {}).get("model") is None + or getattr(self.pipeline_manager, 'pipeline_runner', {}).get("model") != user_model ): - self.pipeline_runner["model"] = user_model - self.pipeline_runner["runner"] = pipeline_runner + if not hasattr(self.pipeline_manager, 'pipeline_runner'): + setattr(self.pipeline_manager, 'pipeline_runner', {}) + self.pipeline_manager.pipeline_runner["model"] = user_model + self.pipeline_manager.pipeline_runner["runner"] = pipeline_runner if image is None: preprocessed_images = pipeline_runner( prompt=positive_prompt, diff --git a/discord_tron_client/classes/image_manipulation/pipeline_runners/base_runner.py b/discord_tron_client/classes/image_manipulation/pipeline_runners/base_runner.py index 6ea6354..331fd73 100644 --- a/discord_tron_client/classes/image_manipulation/pipeline_runners/base_runner.py +++ b/discord_tron_client/classes/image_manipulation/pipeline_runners/base_runner.py @@ -17,6 +17,7 @@ def __init__(self, **kwargs): self.diffusion_manager = None if "pipeline" in kwargs: self.pipeline = kwargs["pipeline"] + print(f"initializing pipeline runner using {self.pipeline}") if "pipeline_manager" in kwargs: self.pipeline_manager = kwargs["pipeline_manager"] else: @@ -150,6 +151,7 @@ def load_adapter( else: # lycoris_wrapper.to(self.pipeline.transformer.device) lycoris_wrapper.apply_to() + logging.info("Moving Lycoris to GPU") lycoris_wrapper.to(device=self.pipeline_manager.device, dtype=self.pipeline_manager.torch_dtype) self.loaded_adapters[clean_adapter_name] = { "adapter_type": adapter_type, @@ -169,8 +171,12 @@ def clear_adapters(self): if config.get("adapter_type") == "lycoris": lycoris_wrapper = config.get("lycoris_wrapper") if not lycoris_wrapper: + logging.error(f"Failed to clear adapter {clean_adapter_name}") continue + logging.debug(f"Restoring lycoris wrapper for {clean_adapter_name}") lycoris_wrapper.restore() + lycoris_wrapper.to("meta") + logging.debug("Sent lycoris to the abyss, meta tensors.") self.loaded_adapters[clean_adapter_name] = None self.pipeline.unload_lora_weights() self.loaded_adapters = {} diff --git a/discord_tron_client/classes/image_manipulation/pipeline_runners/sd3_runner.py b/discord_tron_client/classes/image_manipulation/pipeline_runners/sd3_runner.py index 8dac23d..0ab0212 100644 --- a/discord_tron_client/classes/image_manipulation/pipeline_runners/sd3_runner.py +++ b/discord_tron_client/classes/image_manipulation/pipeline_runners/sd3_runner.py @@ -76,4 +76,9 @@ def __call__(self, **args): ) # Call the pipeline with arguments and return the images - return self.pipeline(**args).images + self.pipeline.to(self.pipeline_manager.device) + print(f"device: {self.pipeline.transformer.device}, {self.pipeline.vae.device}, {self.pipeline.text_encoder.device}") + result = self.pipeline(**args).images + self.clear_adapters() + + return result \ No newline at end of file