diff --git a/modules/processing_vae.py b/modules/processing_vae.py
index faaacb21e..4d988e73b 100644
--- a/modules/processing_vae.py
+++ b/modules/processing_vae.py
@@ -58,16 +58,11 @@ def full_vqgan_decode(latents, model):
     if scaling_factor:
         latents = latents * scaling_factor
 
-    vae_name = os.path.splitext(os.path.basename(sd_vae.loaded_vae_file))[0] if sd_vae.loaded_vae_file is not None else "default"
-    vae_stats = f'name="{vae_name}" dtype={model.vqgan.dtype} device={model.vqgan.device}'
-    latents_stats = f'shape={latents.shape} dtype={latents.dtype} device={latents.device}'
-    stats = f'vae {vae_stats} latents {latents_stats}'
-    log_debug(f'VAE config: {model.vqgan.config}')
 
     try:
         decoded = model.vqgan.decode(latents).sample.clamp(0, 1)
     except Exception as e:
-        shared.log.error(f'VAE decode: {stats} {e}')
+        shared.log.error(f'VAE decode: {e}')
         errors.display(e, 'VAE decode')
         decoded = []
 
@@ -85,7 +80,8 @@ def full_vqgan_decode(latents, model):
     t1 = time.time()
     if debug:
         log_debug(f'VAE memory: {shared.mem_mon.read()}')
-    shared.log.debug(f'VAE decode: {stats} time={round(t1-t0, 3)}')
+    vae_name = os.path.splitext(os.path.basename(sd_vae.loaded_vae_file))[0] if sd_vae.loaded_vae_file is not None else "default"
+    shared.log.debug(f'VAE decode: vae="{vae_name}" type="vqgan" dtype={model.vqgan.dtype} device={model.vqgan.device} time={round(t1-t0, 3)}')
     return decoded
 
 
@@ -103,7 +99,6 @@ def full_vae_decode(latents, model):
     base_device = None
     if shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False):
         base_device = sd_models.move_base(model, devices.cpu)
-
     elif shared.opts.diffusers_offload_mode != "sequential":
         sd_models.move_model(model.vae, devices.device)
 
@@ -134,17 +129,12 @@ def full_vae_decode(latents, model):
     if shift_factor:
         latents = latents + shift_factor
 
-    vae_name = os.path.splitext(os.path.basename(sd_vae.loaded_vae_file))[0] if sd_vae.loaded_vae_file is not None else "default"
-    vae_stats = f'vae="{vae_name}" dtype={model.vae.dtype} device={model.vae.device} upcast={upcast} slicing={getattr(model.vae, "use_slicing", None)} tiling={getattr(model.vae, "use_tiling", None)}'
-    latents_stats = f'latents={latents.shape}:{latents.device}:{latents.dtype}'
-    stats = f'{vae_stats} {latents_stats}'
-    log_debug(f'VAE config: {model.vae.config}')
 
     try:
         with devices.inference_context():
             decoded = model.vae.decode(latents, return_dict=False)[0]
     except Exception as e:
-        shared.log.error(f'VAE decode: {stats} {e}')
+        shared.log.error(f'VAE decode: {e}')
         if 'out of memory' not in str(e):
             errors.display(e, 'VAE decode')
         decoded = []
@@ -162,21 +152,24 @@ def full_vae_decode(latents, model):
 
     elif shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False) and base_device is not None:
         sd_models.move_base(model, base_device)
+
     t1 = time.time()
     if debug:
         log_debug(f'VAE memory: {shared.mem_mon.read()}')
-    shared.log.debug(f'Decode: {stats} time={round(t1-t0, 3)}')
+    vae_name = os.path.splitext(os.path.basename(sd_vae.loaded_vae_file))[0] if sd_vae.loaded_vae_file is not None else "default"
+    shared.log.debug(f'Decode: vae="{vae_name}" upcast={upcast} slicing={getattr(model.vae, "use_slicing", None)} tiling={getattr(model.vae, "use_tiling", None)} latents={latents.shape}:{latents.device}:{latents.dtype} time={t1-t0:.3f}')
     return decoded
 
 
 def full_vae_encode(image, model):
-    log_debug(f'VAE encode: name={sd_vae.loaded_vae_file if sd_vae.loaded_vae_file is not None else "baked"} dtype={model.vae.dtype} upcast={model.vae.config.get("force_upcast", None)}')
     if shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False) and hasattr(model, 'unet'):
         log_debug('Moving to CPU: model=UNet')
         unet_device = model.unet.device
         sd_models.move_model(model.unet, devices.cpu)
     if not shared.opts.diffusers_offload_mode == "sequential" and hasattr(model, 'vae'):
         sd_models.move_model(model.vae, devices.device)
+    vae_name = sd_vae.loaded_vae_file if sd_vae.loaded_vae_file is not None else "default"
+    log_debug(f'Encode vae="{vae_name}" dtype={model.vae.dtype} upcast={model.vae.config.get("force_upcast", None)}')
     encoded = model.vae.encode(image.to(model.vae.device, model.vae.dtype)).latent_dist.sample()
     if shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False) and hasattr(model, 'unet'):
         sd_models.move_model(model.unet, unet_device)
@@ -184,7 +177,7 @@ def full_vae_encode(image, model):
 
 
 def taesd_vae_decode(latents):
-    log_debug(f'VAE decode: name=TAESD images={len(latents)} latents={latents.shape} slicing={shared.opts.diffusers_vae_slicing}')
+    t0 = time.time()
     if len(latents) == 0:
         return []
     if shared.opts.diffusers_vae_slicing and len(latents) > 1:
@@ -193,11 +186,13 @@ def taesd_vae_decode(latents):
             decoded[i] = sd_vae_taesd.decode(latents[i])
     else:
         decoded = sd_vae_taesd.decode(latents)
+    t1 = time.time()
+    shared.log.debug(f'Decode: vae="taesd" latents={latents.shape}:{latents.dtype}:{latents.device} time={t1-t0:.3f}')
     return decoded
 
 
 def taesd_vae_encode(image):
-    log_debug(f'VAE encode: name=TAESD image={image.shape}')
+    shared.log.debug(f'Encode: vae="taesd" image={image.shape}')
     encoded = sd_vae_taesd.encode(image)
     return encoded
 
@@ -243,6 +238,8 @@ def vae_decode(latents, model, output_type='np', full_quality=True, width=None,
         decoded = 2.0 * decoded - 1.0 # typical normalized range
 
     if torch.is_tensor(decoded):
+        if len(decoded.shape) == 3 and decoded.shape[0] == 3:
+            decoded = decoded.unsqueeze(0)
         if hasattr(model, 'video_processor'):
             imgs = model.video_processor.postprocess_video(decoded, output_type='pil')
         elif hasattr(model, 'image_processor'):