Revert "Create basic Flux calc for test and validation loss"
This reverts commit 0b50630.
stepfunction83 committed Jan 25, 2025
1 parent 0b50630 commit 99338a2
Showing 2 changed files with 104 additions and 181 deletions.
258 changes: 104 additions & 154 deletions flux_train.py
@@ -19,7 +19,6 @@
import time
from typing import List, Optional, Tuple, Union
import toml
import random

from tqdm import tqdm

@@ -45,15 +44,14 @@

import library.config_util as config_util

from contextlib import nullcontext

# import library.sdxl_train_util as sdxl_train_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments


def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
@@ -579,180 +577,132 @@ def grad_hook(parameter: torch.Tensor):
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)

loss_recorder = train_util.LossRecorder()
epoch = 0 # avoid error when max_train_steps is 0
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

### PLACEHOLDERS ###
test_step_freq = 10
val_step_freq = 25
test_set_count = 5
val_set_count = 5
test_val_repeat_count = 2

logger.warning('CREATING TEST AND VALIDATION SETS')
test_set, val_set = train_util.create_test_val_set(train_dataloader, test_set_count, val_set_count)

# TODO: Get arguments for step_freq values
# TODO: Get arguments for test_set_count, test_noise_iter
for m in training_models:
m.train()

def calculate_loss(step=step, batch=batch, state=None, accumulate_loss: bool=True, accelerator=accelerator):
for step, batch in enumerate(train_dataloader):
current_step.value = global_step

if state is not None:
noise, noisy_model_input, timesteps, sigmas = state

with accelerator.accumulate(*training_models) if accumulate_loss else nullcontext(): # Use the gradient-accumulation context only when the loss should be accumulated; otherwise use a null context so the test/validation samples do not affect training.
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
else:
with torch.no_grad():
# encode images to latents. images are [-1, 1]
latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype)

# If NaN is found in the latents, warn and replace it with zeros
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)

text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list
else:
# not cached or training, so get from text encoders
tokens_and_masks = batch["input_ids_list"]
with torch.no_grad():
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
text_encoder_conds = text_encoding_strategy.encode_tokens(
flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
)
if args.full_fp16:
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
if args.blockwise_fused_optimizers:
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step

# TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps
with accelerator.accumulate(*training_models):
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
else:
with torch.no_grad():
# encode images to latents. images are [-1, 1]
latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype)

# If NaN is found in the latents, warn and replace it with zeros
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)

text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list
else:
# not cached or training, so get from text encoders
tokens_and_masks = batch["input_ids_list"]
with torch.no_grad():
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
text_encoder_conds = text_encoding_strategy.encode_tokens(
flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
)
if args.full_fp16:
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]

bsz = latents.shape[0]
# TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps

# get noisy model input and timesteps
if state is None: # Only calculate if not using stored values for validation
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)

bsz = latents.shape[0]

# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype
)

# pack latents and get img_ids
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)

# get guidance: ensure args.guidance_scale is float
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)

# call model
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
if not args.apply_t5_attn_mask:
t5_attn_mask = None

if args.bypass_flux_guidance:
flux_utils.bypass_flux_guidance(flux)

with accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = flux(
img=packed_noisy_model_input,
img_ids=img_ids,
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
# pack latents and get img_ids
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)

# unpack latents
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
# get guidance: ensure args.guidance_scale is float
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)

if args.bypass_flux_guidance:
flux_utils.restore_flux_guidance(flux)
# call model
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
if not args.apply_t5_attn_mask:
t5_attn_mask = None

if args.bypass_flux_guidance:
flux_utils.bypass_flux_guidance(flux)

with accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = flux(
img=packed_noisy_model_input,
img_ids=img_ids,
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)

# apply model prediction type
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
# unpack latents
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)

# flow matching loss: this is different from SD3
target = noise - latents
if args.bypass_flux_guidance:
flux_utils.restore_flux_guidance(flux)

# calculate loss
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
if weighting is not None:
loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
# apply model prediction type
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)

loss_weights = batch["loss_weights"] # per-sample loss weights
loss = loss * loss_weights
loss = loss.mean()
# flow matching loss: this is different from SD3
target = noise - latents

state = (noise, noisy_model_input, timesteps, sigmas)
# calculate loss
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
if weighting is not None:
loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])

return loss, state

loss_recorder = train_util.LossRecorder()
epoch = 0 # avoid error when max_train_steps is 0
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
loss_weights = batch["loss_weights"] # per-sample loss weights
loss = loss * loss_weights
loss = loss.mean()

for m in training_models:
m.train()

for step, batch in enumerate(train_dataloader):
if step in val_set['steps']: # Skip validation steps, don't increment global step
logger.warning('SKIPPING BATCH IN VALIDATION SET')
continue

current_step.value = global_step
# backward
accelerator.backward(loss)

if args.blockwise_fused_optimizers:
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
params_to_clip.extend(m.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

# CALCULATE LOSS ON TEST SET AT TEST SET FREQUENCY
if global_step==0:
test_fixed_states = []
test_losses = []
if test_step_freq > 0 and global_step % test_step_freq == 0:
test_loss, test_fixed_states = train_util.calc_test_val_loss(dataset=test_set, loss_func=calculate_loss, repeat_count=test_val_repeat_count, fixed_states=test_fixed_states, test=True)
test_losses.append(test_loss)

# CALCULATE LOSS ON VALIDATION SET AT TEST SET FREQUENCY
if global_step==0:
val_fixed_states = []
val_losses = []
if val_step_freq > 0 and global_step % val_step_freq == 0:
val_loss, val_fixed_states = train_util.calc_test_val_loss(dataset=val_set, loss_func=calculate_loss, repeat_count=test_val_repeat_count, fixed_states=val_fixed_states, test=False)
val_losses.append(val_loss)

# STANDARD LOSS CALCULATION
loss, _ = calculate_loss(step, batch, accumulate_loss=True) # Accumulate loss for normal training batches (unlike the test/validation passes)

# backward
accelerator.backward(loss)

if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
params_to_clip.extend(m.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
if args.blockwise_fused_optimizers:
for i in range(1, len(optimizers)):
lr_schedulers[i].step()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
if args.blockwise_fused_optimizers:
for i in range(1, len(optimizers)):
lr_schedulers[i].step()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
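For reference, the reverted flux_train.py change factored the per-batch forward pass into a calculate_loss helper that enters the gradient-accumulation context only for real training batches, so held-out test/validation batches never contribute gradients. A minimal sketch of that pattern with Hugging Face Accelerate — accelerator, training_models, and forward_and_loss are illustrative placeholders assumed to exist in scope, not the exact sd-scripts API:

from contextlib import nullcontext

def calculate_loss(batch, accumulate_loss: bool = True):
    # Training batches run under accelerator.accumulate(...) so gradients
    # accumulate as usual; evaluation batches use a null context instead.
    ctx = accelerator.accumulate(*training_models) if accumulate_loss else nullcontext()
    with ctx:
        loss = forward_and_loss(batch)  # placeholder for the Flux forward pass + loss
    return loss

For evaluation-only calls one would typically also wrap the forward pass in torch.no_grad() to avoid building the autograd graph; the reverted code instead relied on never calling backward() for those batches.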
27 changes: 0 additions & 27 deletions library/train_util.py
@@ -6398,30 +6398,3 @@ def add(self, *, epoch: int, step: int, loss: float) -> None:
@property
def moving_average(self) -> float:
return self.loss_total / len(self.loss_list)

def calc_test_val_loss(dataset, loss_func, repeat_count, fixed_states=[], test=True):
test_val_ind = 'TEST' if test else 'VALIDATION'
# logger.warning(f'CALCULATING {test_val_ind} LOSS')
losses = []
for step, batch in enumerate(dataset['batches'] * repeat_count):
if len(fixed_states) < len(dataset['batches']) * repeat_count: # While still building up fixed states, compute the noise state as normal and store it
loss, state = loss_func(step, batch, None, accumulate_loss=False)
fixed_states.append(state)
else: # Otherwise, recall the stored values and use those instead so the test loss is consistently calculated for each sample
state = fixed_states[step]
loss, _ = loss_func(step, batch, state, accumulate_loss=False)
losses.append(loss.detach().item())
avg_loss = sum(losses) / len(losses)
logger.info(f'AVERAGE {test_val_ind} LOSS: {avg_loss:.6f}')
return avg_loss, fixed_states

def create_test_val_set(dataloader, test_set_count, val_set_count):
test_set = {'steps':list(range(test_set_count)), 'batches':[]}
val_set = {'steps':list(range(test_set_count,test_set_count+val_set_count)), 'batches':[]}
for step, batch in enumerate(dataloader):
if step in test_set['steps']:
test_set['batches'].append(batch)
if step in val_set['steps']:
val_set['batches'].append(batch)
if step >= test_set_count + val_set_count:
return test_set, val_set
