Refactor validation processing to reduce duplicate code.
stepfunction83 committed Jan 26, 2025
1 parent f2d8806 commit 23dd55a
Showing 1 changed file with 9 additions and 49 deletions.
train_network.py
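The change is a standard extract-function refactor: the validation loop that previously appeared twice, once at the per-step validation site and once at the per-epoch site, is now defined a single time as a local process_validation() function and called from both places. Below is a minimal, self-contained sketch of the pattern; the loop skeleton and the run_validation_batch helper are hypothetical stand-ins, not the script's actual API.

# Sketch of the extract-function pattern used in this commit: define the
# shared validation loop once as a closure over the training loop's state,
# then call it from both the per-step and per-epoch validation sites.
def train(train_batches, val_batches, validate_every_n_steps, num_epochs):
    global_step = 0

    def run_validation_batch(batch):
        # Hypothetical stand-in for process_batch(..., is_train=False, ...).
        return sum(batch) / len(batch)

    def process_validation():
        # The shared loop: previously duplicated at both call sites.
        losses = [run_validation_batch(batch) for batch in val_batches]
        print(f"step {global_step}: validation loss {sum(losses) / len(losses):.4f}")

    for _ in range(num_epochs):
        for batch in train_batches:
            global_step += 1
            if (global_step - 1) % validate_every_n_steps == 0:
                process_validation()  # per-step validation site
        process_validation()  # per-epoch site reuses the same function

train([[0.9, 1.1], [0.8, 1.2]], [[1.0, 1.0]], validate_every_n_steps=2, num_epochs=1)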
@@ -1389,14 +1389,17 @@ def remove_model(old_ckpt_name):
maximum_norm
)
accelerator.log(logs, step=global_step)


# VALIDATION PER STEP
should_validate_step = (
args.validate_every_n_steps is not None
and args.validation_at_start
and (global_step - 1) % args.validate_every_n_steps == 0 # Note: Should use global step - 1 since the global step is incremented prior to this being run
)
if accelerator.sync_gradients and should_validate_step:

# Break out validation processing so that it does not need to be repeated
def process_validation():
val_progress_bar.reset()
for val_step, batch in enumerate(val_dataloader):
if val_step >= validation_steps:
@@ -1442,6 +1445,10 @@ def remove_model(old_ckpt_name):
"loss/validation/step_divergence": loss_validation_divergence,
}
accelerator.log(logs, step=global_step)
# END VALIDATION PROCESSING

if accelerator.sync_gradients and should_validate_step:
process_validation()

if global_step >= args.max_train_steps:
break
@@ -1454,54 +1461,7 @@
)

if should_validate_epoch and len(val_dataloader) > 0:
val_progress_bar.reset()
for val_step, batch in enumerate(val_dataloader):
if val_step >= validation_steps:
break

# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)

loss = self.process_batch(
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=False,
train_text_encoder=False,
train_unet=False
)

current_loss = loss.detach().item()
val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
val_progress_bar.update(1)
val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average })

if is_tracking:
logs = {
"loss/validation/epoch_current": current_loss,
"epoch": epoch + 1,
"val_step": (epoch * validation_steps) + val_step
}
accelerator.log(logs, step=global_step)

if is_tracking:
avr_loss: float = val_epoch_loss_recorder.moving_average
loss_validation_divergence = val_step_loss_recorder.moving_average - avr_loss
logs = {
"loss/validation/epoch_average": avr_loss,
"loss/validation/epoch_divergence": loss_validation_divergence,
"epoch": epoch + 1
}
accelerator.log(logs, step=global_step)
process_validation()

# END OF EPOCH
if is_tracking:
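The inline note in the first hunk flags an off-by-one: global_step is incremented before the validation check runs, so the first check sees global_step == 1. Subtracting one makes validation fire on the very first step and then every n steps thereafter. A small sketch of the difference, with a hypothetical interval:

# global_step is incremented before the check, so the first check sees 1.
n = 4
with_minus_one = [s for s in range(1, 13) if (s - 1) % n == 0]
without = [s for s in range(1, 13) if s % n == 0]
print(with_minus_one)  # [1, 5, 9]   -> validates immediately, then every n steps
print(without)         # [4, 8, 12]  -> first validation only after n full steps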

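Both call sites run process_batch with is_train=False, train_text_encoder=False, and train_unet=False, i.e. the validation pass computes a loss without updating any weights. A hedged PyTorch sketch of that semantic, assuming a generic model and loss function rather than the script's actual process_batch signature:

import torch

def validation_loss(model, loss_fn, batch):
    # Sketch of an is_train=False pass: eval mode, no gradient tracking,
    # no parameter updates. The real process_batch takes many more arguments
    # (strategies, schedulers, dtypes); this shows only the intent.
    model.eval()
    with torch.no_grad():
        inputs, targets = batch
        loss = loss_fn(model(inputs), targets)
    model.train()  # restore training mode before the next optimizer step
    return loss.detach().item()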