From 3d5c644ef89ed5baefc727c13e9be42705a948c5 Mon Sep 17 00:00:00 2001
From: stepfunction83 <32859451+stepfunction83@users.noreply.github.com>
Date: Sun, 26 Jan 2025 20:27:56 -0500
Subject: [PATCH] Implement validation state snapshotting and replay for
 stable validation loss calculation

---
 flux_train_network.py |  23 ++++---
 train_network.py      | 146 +++++++++++++++++++++++++++---------------
 2 files changed, 111 insertions(+), 58 deletions(-)

diff --git a/flux_train_network.py b/flux_train_network.py
index c35f6a697..da6f37a33 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -341,16 +341,23 @@ def get_noise_pred_and_target(
         network,
         weight_dtype,
         train_unet,
-        is_train=True
+        is_train=True,
+        state=None
     ):
-        # Sample noise that we'll add to the latents
-        noise = torch.randn_like(latents)
+        bsz = latents.shape[0]
 
-        # get noisy model input and timesteps
-        noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
-            args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
-        )
+        if state is None:
+            # Sample noise that we'll add to the latents
+            noise = torch.randn_like(latents)
+
+            # get noisy model input and timesteps
+            noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
+                args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
+            )
+            state = (noise, noisy_model_input, timesteps, sigmas)
+        else:
+            noise, noisy_model_input, timesteps, sigmas = state
 
         # pack latents and get img_ids
         packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input)  # b, c, h*2, w*2 -> b, h*w, c*4
@@ -482,7 +489,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t
                 )
                 target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
 
-        return model_pred, target, timesteps, weighting
+        return model_pred, target, timesteps, weighting, state
 
     def post_process_loss(self, loss, args, timesteps, noise_scheduler):
         return loss
diff --git a/train_network.py b/train_network.py
index 18803144c..8d503dd4e 100644
--- a/train_network.py
+++ b/train_network.py
@@ -219,11 +219,16 @@ def get_noise_pred_and_target(
         network,
         weight_dtype,
         train_unet,
-        is_train=True
+        is_train=True,
+        state=None
     ):
-        # Sample noise, sample a random timestep for each image, and add noise to the latents,
-        # with noise offset and/or multires noise if specified
-        noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+        if state is None:
+            # Sample noise, sample a random timestep for each image, and add noise to the latents,
+            # with noise offset and/or multires noise if specified
+            noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+            state = (noise, noisy_latents, timesteps)
+        else:
+            noise, noisy_latents, timesteps = state
 
         # ensure the hidden state will require grad
         if args.gradient_checkpointing:
@@ -275,7 +280,7 @@ def get_noise_pred_and_target(
                     network.set_multiplier(1.0)  # may be overwritten by "network_multipliers" in the next step
                 target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
 
-        return noise_pred, target, timesteps, None
+        return noise_pred, target, timesteps, None, state
 
     def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
         if args.min_snr_gamma:
@@ -330,7 +335,8 @@ def process_batch(
         tokenize_strategy: strategy_base.TokenizeStrategy,
         is_train=True,
         train_text_encoder=True,
-        train_unet=True
+        train_unet=True,
+        state=None
     ) -> torch.Tensor:
         """
         Process a batch for the network
@@ -386,7 +392,7 @@ def process_batch(
                     text_encoder_conds[i] = encoded_text_encoder_conds[i]
 
         # sample noise, call unet, get target
-        noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
+        noise_pred, target, timesteps, weighting, state = self.get_noise_pred_and_target(
             args,
             accelerator,
             noise_scheduler,
@@ -397,7 +403,8 @@ def process_batch(
             network,
             weight_dtype,
             train_unet,
-            is_train=is_train
+            is_train=is_train,
+            state=state
         )
 
         huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
@@ -413,7 +420,7 @@ def process_batch(
 
         loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
 
-        return loss.mean()
+        return loss.mean(), state
 
     def train(self, args):
         session_id = random.randint(0, 2**32)
@@ -1164,13 +1171,17 @@ def load_model_hook(models, input_dir):
         ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
 
         if args.validate_every_n_steps is not None:
+            # TODO: REMOVE HARDCODED TIMESTEP ITERATION COUNT
+            TIMESTEP_ITERATION_COUNT = 4
+
             validation_steps = (
                 min(args.max_validation_steps, len(val_dataloader))
                 if args.max_validation_steps is not None
                 else len(val_dataloader)
             )
+            logger.warning(f"VALIDATION STEPS: {validation_steps}")
             val_progress_bar = tqdm(
-                range(validation_steps), smoothing=0,
+                range(TIMESTEP_ITERATION_COUNT * validation_steps), smoothing=0,
                 disable=not accelerator.is_local_main_process,
                 desc="validation steps"
             )
@@ -1275,6 +1286,8 @@ def remove_model(old_ckpt_name):
                 param_3rd = params_itr.__next__()
                 logger.info(f"text_encoder [{i}] dtype: {param_3rd.dtype}, device: {t_enc.device}")
 
+        tracking_states = []
+
         clean_memory_on_device(accelerator.device)
 
         for epoch in range(epoch_to_start, num_train_epochs):
@@ -1303,7 +1316,7 @@ def remove_model(old_ckpt_name):
                 # temporary, for batch processing
                 self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
 
-                loss = self.process_batch(
+                loss, _ = self.process_batch(
                     batch,
                     text_encoders,
                     unet,
@@ -1318,7 +1331,8 @@ def remove_model(old_ckpt_name):
                     tokenize_strategy,
                     is_train=True,
                     train_text_encoder=train_text_encoder,
-                    train_unet=train_unet
+                    train_unet=train_unet,
+                    state=None
                 )
 
                 accelerator.backward(loss)
@@ -1397,46 +1411,76 @@ def remove_model(old_ckpt_name):
                     and args.validation_at_start
                     and (global_step - 1) % args.validate_every_n_steps == 0  # Note: Should use global step - 1 since the global step is incremented prior to this being run
                 )
-
+
                 # Break out validation processing so that it does not need to be repeated
                 def process_validation():
                     val_progress_bar.reset()
+
                     for val_step, batch in enumerate(val_dataloader):
-                        if val_step >= validation_steps:
-                            break
-
-                        # temporary, for batch processing
-                        self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
-
-                        loss = self.process_batch(
-                            batch,
-                            text_encoders,
-                            unet,
-                            network,
-                            vae,
-                            noise_scheduler,
-                            vae_dtype,
-                            weight_dtype,
-                            accelerator,
-                            args,
-                            text_encoding_strategy,
-                            tokenize_strategy,
-                            is_train=False,
-                            train_text_encoder=False,
-                            train_unet=False
-                        )
-
-                        current_loss = loss.detach().item()
-                        val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
-                        val_progress_bar.update(1)
-                        val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
-
-                        if is_tracking:
-                            logs = {
-                                "loss/validation/step_current": current_loss,
-                                "val_step": (epoch * validation_steps) + val_step,
-                            }
-                            accelerator.log(logs, step=global_step)
+                        if len(tracking_states) < len(val_dataloader):
+                            batch['states'] = []
+                        for t_step in range(TIMESTEP_ITERATION_COUNT):
+                            if val_step >= validation_steps:
+                                break
+
+                            # temporary, for batch processing
+                            self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
+
+                            if len(tracking_states) < len(val_dataloader):
+                                loss, state = self.process_batch(
+                                    batch,
+                                    text_encoders,
+                                    unet,
+                                    network,
+                                    vae,
+                                    noise_scheduler,
+                                    vae_dtype,
+                                    weight_dtype,
+                                    accelerator,
+                                    args,
+                                    text_encoding_strategy,
+                                    tokenize_strategy,
+                                    is_train=False,
+                                    train_text_encoder=False,
+                                    train_unet=False,
+                                    state=None
+                                )
+                                batch['states'].append(state)
+                                # logger.warning(f'BATCH STATE COUNT: {len(batch["states"])}')
+                            else:
+                                loss, state = self.process_batch(
+                                    batch,
+                                    text_encoders,
+                                    unet,
+                                    network,
+                                    vae,
+                                    noise_scheduler,
+                                    vae_dtype,
+                                    weight_dtype,
+                                    accelerator,
+                                    args,
+                                    text_encoding_strategy,
+                                    tokenize_strategy,
+                                    is_train=False,
+                                    train_text_encoder=False,
+                                    train_unet=False,
+                                    state=tracking_states[val_step][t_step]
+                                )
+                                # logger.warning(f'TRACKING STATE COUNT: {len(tracking_states[val_step])}')
+                            current_loss = loss.detach().item()
+                            val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
+                            val_progress_bar.update(1)
+                            val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
+
+                            if is_tracking:
+                                logs = {
+                                    "loss/validation/step_current": current_loss,
+                                    "val_step": (epoch * validation_steps) + val_step,
+                                }
+                                accelerator.log(logs, step=global_step)
+
+                        if len(tracking_states) < len(val_dataloader):
+                            tracking_states.append(batch['states'])
 
                     if is_tracking:
                         loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
                         logs = {
                             "loss/validation/step_divergence": loss_validation_divergence,
                         }
                         accelerator.log(logs, step=global_step)
+
+                    return
                 # END VALIDATION PROCESSING
 
                 if accelerator.sync_gradients and should_validate_step:
-                    process_validation()
-
+                    process_validation()
+
                 if global_step >= args.max_train_steps:
                     break
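
Below is a minimal, self-contained sketch of the snapshot-and-replay idea the patch implements. It is not this repository's API: TinyDenoiser, make_state, and validation_loss are hypothetical stand-ins. The point it illustrates is that the random noise/timestep draws are cached on the first validation pass and replayed on every later pass, so changes in the reported validation loss reflect parameter updates rather than resampled randomness.

# Hedged sketch, assuming hypothetical names (TinyDenoiser, make_state, validation_loss);
# it mirrors the patch's idea only: cache per-batch random draws once, replay them later.
import torch

TIMESTEP_ITERATION_COUNT = 4  # mirrors the patch's hardcoded value: several fixed draws per batch


class TinyDenoiser(torch.nn.Module):
    """Stand-in for the real model; ignores the timestep for brevity."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


def make_state(latents: torch.Tensor, num_timesteps: int = 1000):
    """Sample the per-batch randomness once: Gaussian noise and one timestep per sample."""
    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, num_timesteps, (latents.shape[0],), device=latents.device)
    return noise, timesteps


def validation_loss(model, val_batches, tracking_states):
    """Average validation loss; snapshot states on the first call, replay them afterwards."""
    first_pass = len(tracking_states) < len(val_batches)
    losses = []
    for val_step, latents in enumerate(val_batches):
        batch_states = [] if first_pass else tracking_states[val_step]
        for t_step in range(TIMESTEP_ITERATION_COUNT):
            if first_pass:
                state = make_state(latents)   # fresh draw, remembered for later passes
                batch_states.append(state)
            else:
                state = batch_states[t_step]  # replay the cached draw
            noise, timesteps = state
            with torch.no_grad():
                pred = model(latents + noise, timesteps)
            losses.append(torch.nn.functional.mse_loss(pred, noise).item())
        if first_pass:
            tracking_states.append(batch_states)
    return sum(losses) / len(losses)


if __name__ == "__main__":
    torch.manual_seed(0)
    model = TinyDenoiser(dim=8)
    val_batches = [torch.randn(2, 8) for _ in range(3)]
    tracking_states = []
    # Two successive runs report comparable losses because the second replays the cached draws.
    print(validation_loss(model, val_batches, tracking_states))
    print(validation_loss(model, val_batches, tracking_states))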