cleanup vae decode
Signed-off-by: Vladimir Mandic <[email protected]>
vladmandic committed Jan 16, 2025
1 parent 56afd1d commit f07aca6
Showing 1 changed file with 15 additions and 18 deletions.
33 changes: 15 additions & 18 deletions modules/processing_vae.py
@@ -58,16 +58,11 @@ def full_vqgan_decode(latents, model):
if scaling_factor:
latents = latents * scaling_factor

vae_name = os.path.splitext(os.path.basename(sd_vae.loaded_vae_file))[0] if sd_vae.loaded_vae_file is not None else "default"
vae_stats = f'name="{vae_name}" dtype={model.vqgan.dtype} device={model.vqgan.device}'
latents_stats = f'shape={latents.shape} dtype={latents.dtype} device={latents.device}'
stats = f'vae {vae_stats} latents {latents_stats}'

log_debug(f'VAE config: {model.vqgan.config}')
try:
decoded = model.vqgan.decode(latents).sample.clamp(0, 1)
except Exception as e:
shared.log.error(f'VAE decode: {stats} {e}')
shared.log.error(f'VAE decode: {e}')
errors.display(e, 'VAE decode')
decoded = []

@@ -85,7 +80,8 @@ def full_vqgan_decode(latents, model):
t1 = time.time()
if debug:
log_debug(f'VAE memory: {shared.mem_mon.read()}')
shared.log.debug(f'VAE decode: {stats} time={round(t1-t0, 3)}')
vae_name = os.path.splitext(os.path.basename(sd_vae.loaded_vae_file))[0] if sd_vae.loaded_vae_file is not None else "default"
shared.log.debug(f'VAE decode: vae="{vae_name}" type="vqgan" dtype={model.vqgan.dtype} device={model.vqgan.device} time={round(t1-t0, 3)}')
return decoded


@@ -103,7 +99,6 @@ def full_vae_decode(latents, model):
base_device = None
if shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False):
base_device = sd_models.move_base(model, devices.cpu)

elif shared.opts.diffusers_offload_mode != "sequential":
sd_models.move_model(model.vae, devices.device)

@@ -134,17 +129,12 @@ def full_vae_decode(latents, model):
if shift_factor:
latents = latents + shift_factor

vae_name = os.path.splitext(os.path.basename(sd_vae.loaded_vae_file))[0] if sd_vae.loaded_vae_file is not None else "default"
vae_stats = f'vae="{vae_name}" dtype={model.vae.dtype} device={model.vae.device} upcast={upcast} slicing={getattr(model.vae, "use_slicing", None)} tiling={getattr(model.vae, "use_tiling", None)}'
latents_stats = f'latents={latents.shape}:{latents.device}:{latents.dtype}'
stats = f'{vae_stats} {latents_stats}'

log_debug(f'VAE config: {model.vae.config}')
try:
with devices.inference_context():
decoded = model.vae.decode(latents, return_dict=False)[0]
except Exception as e:
shared.log.error(f'VAE decode: {stats} {e}')
shared.log.error(f'VAE decode: {e}')
if 'out of memory' not in str(e):
errors.display(e, 'VAE decode')
decoded = []
@@ -162,29 +152,32 @@ def full_vae_decode(latents, model):

elif shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False) and base_device is not None:
sd_models.move_base(model, base_device)

t1 = time.time()
if debug:
log_debug(f'VAE memory: {shared.mem_mon.read()}')
shared.log.debug(f'Decode: {stats} time={round(t1-t0, 3)}')
vae_name = os.path.splitext(os.path.basename(sd_vae.loaded_vae_file))[0] if sd_vae.loaded_vae_file is not None else "default"
shared.log.debug(f'Decode: vae="{vae_name}" upcast={upcast} slicing={getattr(model.vae, "use_slicing", None)} tiling={getattr(model.vae, "use_tiling", None)} latents={latents.shape}:{latents.device}:{latents.dtype} time={t1-t0:.3f}')
return decoded


def full_vae_encode(image, model):
log_debug(f'VAE encode: name={sd_vae.loaded_vae_file if sd_vae.loaded_vae_file is not None else "baked"} dtype={model.vae.dtype} upcast={model.vae.config.get("force_upcast", None)}')
if shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False) and hasattr(model, 'unet'):
log_debug('Moving to CPU: model=UNet')
unet_device = model.unet.device
sd_models.move_model(model.unet, devices.cpu)
if not shared.opts.diffusers_offload_mode == "sequential" and hasattr(model, 'vae'):
sd_models.move_model(model.vae, devices.device)
vae_name = sd_vae.loaded_vae_file if sd_vae.loaded_vae_file is not None else "default"
log_debug(f'Encode vae="{vae_name}" dtype={model.vae.dtype} upcast={model.vae.config.get("force_upcast", None)}')
encoded = model.vae.encode(image.to(model.vae.device, model.vae.dtype)).latent_dist.sample()
if shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False) and hasattr(model, 'unet'):
sd_models.move_model(model.unet, unet_device)
return encoded


def taesd_vae_decode(latents):
log_debug(f'VAE decode: name=TAESD images={len(latents)} latents={latents.shape} slicing={shared.opts.diffusers_vae_slicing}')
t0 = time.time()
if len(latents) == 0:
return []
if shared.opts.diffusers_vae_slicing and len(latents) > 1:
@@ -193,11 +186,13 @@ def taesd_vae_decode(latents):
decoded[i] = sd_vae_taesd.decode(latents[i])
else:
decoded = sd_vae_taesd.decode(latents)
t1 = time.time()
shared.log.debug(f'Decode: vae="taesd" latents={latents.shape}:{latents.dtype}:{latents.device} time={t1-t0:.3f}')
return decoded


def taesd_vae_encode(image):
log_debug(f'VAE encode: name=TAESD image={image.shape}')
shared.log.debug(f'Encode: vae="taesd" image={image.shape}')
encoded = sd_vae_taesd.encode(image)
return encoded

@@ -243,6 +238,8 @@ def vae_decode(latents, model, output_type='np', full_quality=True, width=None,
decoded = 2.0 * decoded - 1.0 # typical normalized range

if torch.is_tensor(decoded):
if len(decoded.shape) == 3 and decoded.shape[0] == 3:
decoded = decoded.unsqueeze(0)
if hasattr(model, 'video_processor'):
imgs = model.video_processor.postprocess_video(decoded, output_type='pil')
elif hasattr(model, 'image_processor'):
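
One addition worth a closer look is the shape guard at the end of vae_decode: a single decoded image can come back as a 3-channel CHW tensor, and the new check promotes it to a one-image NCHW batch before it is handed to the image or video post-processor. Below is a minimal standalone sketch of that guard, using plain torch with an illustrative tensor in place of the real decoder output:

    import torch

    # Stand-in for a single decoded image in channels-first (C, H, W) layout.
    decoded = torch.rand(3, 512, 512)

    # Guard added in vae_decode: promote a 3D CHW tensor to a 4D NCHW batch
    # so the downstream post-processor always receives a batch dimension.
    if len(decoded.shape) == 3 and decoded.shape[0] == 3:
        decoded = decoded.unsqueeze(0)

    print(decoded.shape)  # torch.Size([1, 3, 512, 512])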
