
Commit

Implement validation state snapshotting and replay for stable validation loss calculation
stepfunction83 committed Jan 27, 2025
1 parent 23dd55a commit 3d5c644
Showing 2 changed files with 111 additions and 58 deletions.
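Note on the change: the reported validation loss normally jitters from pass to pass because fresh noise and timesteps are drawn every time, so two validation passes over identical batches yield different losses even with frozen weights. This commit snapshots the sampled tensors on the first validation pass and replays them on later passes. A minimal, self-contained sketch of the pattern, using toy names rather than this repository's API:

import torch

def sample_or_replay(latents, num_train_timesteps=1000, state=None):
    # Toy stand-in for the pattern added to get_noise_pred_and_target below:
    # draw fresh noise/timesteps only when no snapshot exists, else replay it.
    if state is None:
        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, num_train_timesteps, (latents.shape[0],))
        state = (noise, timesteps)
    else:
        noise, timesteps = state
    return noise, timesteps, state

latents = torch.zeros(2, 4, 8, 8)
n1, t1, snapshot = sample_or_replay(latents)           # first pass: sample and snapshot
n2, t2, _ = sample_or_replay(latents, state=snapshot)  # later passes: exact replay
assert torch.equal(n1, n2) and torch.equal(t1, t2)     # identical validation inputs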
23 changes: 15 additions & 8 deletions flux_train_network.py
@@ -341,16 +341,23 @@ def get_noise_pred_and_target(
         network,
         weight_dtype,
         train_unet,
-        is_train=True
+        is_train=True,
+        state=None
     ):
-        # Sample noise that we'll add to the latents
-        noise = torch.randn_like(latents)
-
         bsz = latents.shape[0]
 
-        # get noisy model input and timesteps
-        noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
-            args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
-        )
+        if state is None:
+            # Sample noise that we'll add to the latents
+            noise = torch.randn_like(latents)
+
+            # get noisy model input and timesteps
+            noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
+                args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
+            )
+            state = (noise, noisy_model_input, timesteps, sigmas)
+        else:
+            noise, noisy_model_input, timesteps, sigmas = state
 
         # pack latents and get img_ids
         packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input)  # b, c, h*2, w*2 -> b, h*w, c*4
@@ -482,7 +489,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t
             )
             target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
 
-        return model_pred, target, timesteps, weighting
+        return model_pred, target, timesteps, weighting, state
 
     def post_process_loss(self, loss, args, timesteps, noise_scheduler):
        return loss
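For the FLUX path above, the snapshot is the full 4-tuple (noise, noisy_model_input, timesteps, sigmas), so the constructed noisy input is replayed bit-for-bit rather than recomputed. A toy mirror of that contract; the linear interpolation below is an invented stand-in for flux_train_utils.get_noisy_model_input_and_timesteps:

import torch

def toy_noisy_input(latents, state=None):
    # Illustrative only: freeze everything the model consumes, not just the noise.
    if state is None:
        noise = torch.randn_like(latents)
        timesteps = torch.rand(latents.shape[0])      # stand-in timestep draw
        sigmas = timesteps.view(-1, 1, 1, 1)
        noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
        state = (noise, noisy_model_input, timesteps, sigmas)
    else:
        noise, noisy_model_input, timesteps, sigmas = state
    return noisy_model_input, timesteps, state

latents = torch.ones(2, 16, 32, 32)
x1, t1, snap = toy_noisy_input(latents)
x2, t2, _ = toy_noisy_input(latents, state=snap)
assert torch.equal(x1, x2) and torch.equal(t1, t2)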
146 changes: 96 additions & 50 deletions train_network.py
@@ -219,11 +219,16 @@ def get_noise_pred_and_target(
         network,
         weight_dtype,
         train_unet,
-        is_train=True
+        is_train=True,
+        state=None
     ):
-        # Sample noise, sample a random timestep for each image, and add noise to the latents,
-        # with noise offset and/or multires noise if specified
-        noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+        if state is None:
+            # Sample noise, sample a random timestep for each image, and add noise to the latents,
+            # with noise offset and/or multires noise if specified
+            noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+            state = (noise, noisy_latents, timesteps)
+        else:
+            noise, noisy_latents, timesteps = state
 
         # ensure the hidden state will require grad
         if args.gradient_checkpointing:
@@ -275,7 +280,7 @@ def get_noise_pred_and_target(
                 network.set_multiplier(1.0)  # may be overwritten by "network_multipliers" in the next step
                 target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
 
-        return noise_pred, target, timesteps, None
+        return noise_pred, target, timesteps, None, state
 
     def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
         if args.min_snr_gamma:
@@ -330,7 +335,8 @@ def process_batch(
         tokenize_strategy: strategy_base.TokenizeStrategy,
         is_train=True,
         train_text_encoder=True,
-        train_unet=True
+        train_unet=True,
+        state=None
     ) -> torch.Tensor:
         """
         Process a batch for the network
@@ -386,7 +392,7 @@ def process_batch(
                     text_encoder_conds[i] = encoded_text_encoder_conds[i]
 
         # sample noise, call unet, get target
-        noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
+        noise_pred, target, timesteps, weighting, state = self.get_noise_pred_and_target(
             args,
             accelerator,
             noise_scheduler,
@@ -397,7 +403,8 @@ def process_batch(
             network,
             weight_dtype,
             train_unet,
-            is_train=is_train
+            is_train=is_train,
+            state=state
         )
 
         huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
@@ -413,7 +420,7 @@ def process_batch(
 
         loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
 
-        return loss.mean()
+        return loss.mean(), state
 
     def train(self, args):
         session_id = random.randint(0, 2**32)
@@ -1164,13 +1171,17 @@ def load_model_hook(models, input_dir):
        ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
 
        if args.validate_every_n_steps is not None:
+            # TODO: REMOVE HARDCODED TIMESTEP ITERATION COUNT
+            TIMESTEP_ITERATION_COUNT = 4
+
            validation_steps = (
                min(args.max_validation_steps, len(val_dataloader))
                if args.max_validation_steps is not None
                else len(val_dataloader)
            )
+            logger.warning(f"VALIDATION STEPS: {validation_steps}")
            val_progress_bar = tqdm(
-                range(validation_steps), smoothing=0,
+                range(TIMESTEP_ITERATION_COUNT * validation_steps), smoothing=0,
                disable=not accelerator.is_local_main_process,
                desc="validation steps"
            )
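The progress bar above is sized TIMESTEP_ITERATION_COUNT * validation_steps because each validation batch is now evaluated at TIMESTEP_ITERATION_COUNT cached draws (see the validation loop further down). Beyond repeatability, averaging several draws also tightens the estimate; a toy illustration with invented numbers:

import torch

torch.manual_seed(0)
target = torch.zeros(64)

def one_draw_loss():
    # Fresh noise -> a noisy, run-to-run-varying loss estimate.
    return torch.nn.functional.mse_loss(torch.randn_like(target), target).item()

single = torch.tensor([one_draw_loss() for _ in range(1000)])
averaged = torch.tensor([sum(one_draw_loss() for _ in range(4)) / 4 for _ in range(1000)])
print(single.std(), averaged.std())  # averaging 4 draws roughly halves the spread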
@@ -1275,6 +1286,8 @@ def remove_model(old_ckpt_name):
                    param_3rd = params_itr.__next__()
                    logger.info(f"text_encoder [{i}] dtype: {param_3rd.dtype}, device: {t_enc.device}")
 
+        tracking_states = []
+
        clean_memory_on_device(accelerator.device)
 
        for epoch in range(epoch_to_start, num_train_epochs):
@@ -1303,7 +1316,7 @@ def remove_model(old_ckpt_name):
                # temporary, for batch processing
                self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
 
-                loss = self.process_batch(
+                loss, _ = self.process_batch(
                    batch,
                    text_encoders,
                    unet,
@@ -1318,7 +1331,8 @@ def remove_model(old_ckpt_name):
                    tokenize_strategy,
                    is_train=True,
                    train_text_encoder=train_text_encoder,
-                    train_unet=train_unet
+                    train_unet=train_unet,
+                    state=None
                )
 
                accelerator.backward(loss)
@@ -1397,46 +1411,76 @@ def remove_model(old_ckpt_name):
                    and args.validation_at_start
                    and (global_step - 1) % args.validate_every_n_steps == 0  # Note: Should use global step - 1 since the global step is incremented prior to this being run
                )
 
                # Break out validation processing so that it does not need to be repeated
                def process_validation():
                    val_progress_bar.reset()
 
                    for val_step, batch in enumerate(val_dataloader):
-                        if val_step >= validation_steps:
-                            break
-
-                        # temporary, for batch processing
-                        self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
-
-                        loss = self.process_batch(
-                            batch,
-                            text_encoders,
-                            unet,
-                            network,
-                            vae,
-                            noise_scheduler,
-                            vae_dtype,
-                            weight_dtype,
-                            accelerator,
-                            args,
-                            text_encoding_strategy,
-                            tokenize_strategy,
-                            is_train=False,
-                            train_text_encoder=False,
-                            train_unet=False
-                        )
-
-                        current_loss = loss.detach().item()
-                        val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
-                        val_progress_bar.update(1)
-                        val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
-
-                        if is_tracking:
-                            logs = {
-                                "loss/validation/step_current": current_loss,
-                                "val_step": (epoch * validation_steps) + val_step,
-                            }
-                            accelerator.log(logs, step=global_step)
+                        if len(tracking_states) < len(val_dataloader):
+                            batch['states'] = []
+                        for t_step in range(TIMESTEP_ITERATION_COUNT):
+                            if val_step >= validation_steps:
+                                break
+
+                            # temporary, for batch processing
+                            self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
+
+                            if len(tracking_states) < len(val_dataloader):
+                                loss, state = self.process_batch(
+                                    batch,
+                                    text_encoders,
+                                    unet,
+                                    network,
+                                    vae,
+                                    noise_scheduler,
+                                    vae_dtype,
+                                    weight_dtype,
+                                    accelerator,
+                                    args,
+                                    text_encoding_strategy,
+                                    tokenize_strategy,
+                                    is_train=False,
+                                    train_text_encoder=False,
+                                    train_unet=False,
+                                    state=None
+                                )
+                                batch['states'].append(state)
+                                # logger.warning(f'BATCH STATE COUNT: {len(batch["states"])}')
+                            else:
+                                loss, state = self.process_batch(
+                                    batch,
+                                    text_encoders,
+                                    unet,
+                                    network,
+                                    vae,
+                                    noise_scheduler,
+                                    vae_dtype,
+                                    weight_dtype,
+                                    accelerator,
+                                    args,
+                                    text_encoding_strategy,
+                                    tokenize_strategy,
+                                    is_train=False,
+                                    train_text_encoder=False,
+                                    train_unet=False,
+                                    state=tracking_states[val_step][t_step]
+                                )
+                                # logger.warning(f'TRACKING STATE COUNT: {len(tracking_states[val_step])}')
+                            current_loss = loss.detach().item()
+                            val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
+                            val_progress_bar.update(1)
+                            val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
+
+                            if is_tracking:
+                                logs = {
+                                    "loss/validation/step_current": current_loss,
+                                    "val_step": (epoch * validation_steps) + val_step,
+                                }
+                                accelerator.log(logs, step=global_step)
+
+                        if len(tracking_states) < len(val_dataloader):
+                            tracking_states.append(batch['states'])
 
                    if is_tracking:
                        loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
@@ -1445,11 +1489,13 @@ def process_validation():
                            "loss/validation/step_divergence": loss_validation_divergence,
                        }
                        accelerator.log(logs, step=global_step)
 
+                    return
+                # END VALIDATION PROCESSING
 
                if accelerator.sync_gradients and should_validate_step:
                    process_validation()
 
                if global_step >= args.max_train_steps:
                    break
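Putting the pieces together: tracking_states[val_step][t_step] holds one snapshot per validation batch per timestep iteration. The first validation epoch fills the cache (state=None), and every later epoch replays it, so the averaged validation loss is exactly repeatable. A self-contained toy mirror of the caching scheme in process_validation(); only TIMESTEP_ITERATION_COUNT and tracking_states are taken from the diff, everything else is a stand-in:

import torch

TIMESTEP_ITERATION_COUNT = 4  # matches the hardcoded count in the diff
tracking_states = []          # tracking_states[val_step][t_step] -> snapshot

def toy_process_batch(latents, state=None):
    # Stand-in for process_batch(): returns (loss, state) like the new contract.
    if state is None:
        state = (torch.randn_like(latents),)
    (noise,) = state
    return torch.nn.functional.mse_loss(noise, latents), state

def toy_validation(val_batches):
    caching = len(tracking_states) < len(val_batches)
    losses = []
    for val_step, latents in enumerate(val_batches):
        states = [] if caching else tracking_states[val_step]
        for t_step in range(TIMESTEP_ITERATION_COUNT):
            state = None if caching else states[t_step]
            loss, state = toy_process_batch(latents, state)
            if caching:
                states.append(state)
            losses.append(loss.item())
        if caching:
            tracking_states.append(states)
    return sum(losses) / len(losses)

batches = [torch.zeros(2, 4), torch.ones(2, 4)]
assert toy_validation(batches) == toy_validation(batches)  # replay -> identical loss

One caveat visible in the diff: the cache is filled once per run and indexed by batch position, which assumes the validation loader yields batches in a stable order across epochs.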
