Commit

Added training code; loss and results are stable
ksikiric committed Feb 12, 2025
1 parent a774fb1 commit ba2d028
Showing 11 changed files with 1,192 additions and 22 deletions.
5 changes: 3 additions & 2 deletions src/maxdiffusion/checkpointing/checkpointing_utils.py
@@ -33,6 +33,7 @@

STABLE_DIFFUSION_CHECKPOINT = "STABLE_DIFFUSION_CHECKPOINT"
STABLE_DIFFUSION_XL_CHECKPOINT = "STABLE_DIFUSSION_XL_CHECKPOINT"
+FLUX_CHECKPOINT = "FLUX_CHECKPOINT"


def create_orbax_checkpoint_manager(
@@ -66,7 +67,7 @@ def create_orbax_checkpoint_manager(
"text_encoder_state",
"tokenizer_config",
)
-  if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT:
+  if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT or checkpoint_type == FLUX_CHECKPOINT:
item_names += (
"text_encoder_2_state",
"text_encoder_2_config",
@@ -117,7 +118,7 @@ def load_stable_diffusion_configs(
"tokenizer_config": orbax.checkpoint.args.JsonRestore(),
}

-  if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT:
+  if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT or checkpoint_type == FLUX_CHECKPOINT:
restore_args["text_encoder_2_config"] = orbax.checkpoint.args.JsonRestore()

return (checkpoint_manager.restore(step, args=orbax.checkpoint.args.Composite(**restore_args)), None)
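With FLUX_CHECKPOINT added, a Flux run tracks the same checkpoint items as SDXL, including the second text encoder entries. Below is a minimal sketch of creating a checkpoint manager for Flux, mirroring the call made in FluxCheckpointer.__init__ further down; the checkpoint directory is a placeholder and the dataset type is only an example value.

# Sketch only: create an Orbax checkpoint manager for a Flux run.
# The directory path below is a hypothetical placeholder.
from maxdiffusion.checkpointing.checkpointing_utils import (
    FLUX_CHECKPOINT,
    create_orbax_checkpoint_manager,
)

checkpoint_manager = create_orbax_checkpoint_manager(
    "gs://my-bucket/flux-checkpoints",  # placeholder checkpoint_dir
    enable_checkpointing=True,
    save_interval_steps=1,
    checkpoint_type=FLUX_CHECKPOINT,
    dataset_type="grain",  # example; matches config.dataset_type
)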
210 changes: 210 additions & 0 deletions src/maxdiffusion/checkpointing/flux_checkpointer.py
@@ -0,0 +1,210 @@
"""
Copyright 2024 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from abc import ABC
from contextlib import nullcontext
import os
import json
import functools
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
import orbax.checkpoint as ocp
import grain.python as grain
from maxdiffusion import (
    max_utils,
    FlaxAutoencoderKL,
    max_logging,
)
from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from ..pipelines.flux.flux_pipeline import FluxPipeline

from transformers import (CLIPTokenizer, FlaxCLIPTextModel, T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer)

from maxdiffusion.checkpointing.checkpointing_utils import (
    create_orbax_checkpoint_manager,
    load_stable_diffusion_configs,
)
from maxdiffusion.models.flux.util import load_flow_model

FLUX_CHECKPOINT = "FLUX_CHECKPOINT"
_CHECKPOINT_FORMAT_DIFFUSERS = "CHECKPOINT_FORMAT_DIFFUSERS"
_CHECKPOINT_FORMAT_ORBAX = "CHECKPOINT_FORMAT_ORBAX"


class FluxCheckpointer(ABC):

  def __init__(self, config, checkpoint_type):
    self.config = config
    self.checkpoint_type = checkpoint_type
    self.checkpoint_format = None

    self.rng = jax.random.PRNGKey(self.config.seed)
    self.devices_array = max_utils.create_device_mesh(config)
    self.mesh = Mesh(self.devices_array, self.config.mesh_axes)
    self.total_train_batch_size = self.config.total_train_batch_size

    self.checkpoint_manager = create_orbax_checkpoint_manager(
        self.config.checkpoint_dir,
        enable_checkpointing=True,
        save_interval_steps=1,
        checkpoint_type=checkpoint_type,
        dataset_type=config.dataset_type,
    )

  def _create_optimizer(self, config, learning_rate):
    learning_rate_scheduler = max_utils.create_learning_rate_schedule(
        learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps
    )
    tx = max_utils.create_optimizer(config, learning_rate_scheduler)
    return tx, learning_rate_scheduler

  def create_flux_state(self, pipeline, params, checkpoint_item_name, is_training):
    transformer = pipeline.flux

    tx, learning_rate_scheduler = None, None
    if is_training:
      learning_rate = self.config.learning_rate
      tx, learning_rate_scheduler = self._create_optimizer(self.config, learning_rate)

    transformer_eval_params = transformer.init_weights(
        rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True
    )

    transformer_params = load_flow_model(self.config.flux_name, transformer_eval_params, "cpu")

    weights_init_fn = functools.partial(
        pipeline.flux.init_weights, rngs=self.rng, max_sequence_length=self.config.max_sequence_length
    )
    flux_state, state_mesh_shardings = max_utils.setup_initial_state(
        model=pipeline.flux,
        tx=tx,
        config=self.config,
        mesh=self.mesh,
        weights_init_fn=weights_init_fn,
        model_params=None,
        checkpoint_manager=self.checkpoint_manager,
        checkpoint_item=checkpoint_item_name,
        training=is_training,
    )
    # Unless training Flux from scratch, replace the randomly initialized params with the pretrained weights.
    if not self.config.train_new_flux:
      flux_state = flux_state.replace(params=transformer_params)
    flux_state = jax.device_put(flux_state, state_mesh_shardings)
    return flux_state, state_mesh_shardings, learning_rate_scheduler

  def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False):
    # Currently VAE training is not supported.
    weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=self.rng)
    return max_utils.setup_initial_state(
        model=pipeline.vae,
        tx=None,
        config=self.config,
        mesh=self.mesh,
        weights_init_fn=weights_init_fn,
        model_params=params,
        checkpoint_manager=self.checkpoint_manager,
        checkpoint_item=checkpoint_item_name,
        training=is_training,
    )

  def restore_data_iterator_state(self, data_iterator):
    if (
        self.config.dataset_type == "grain"
        and data_iterator is not None
        and (self.checkpoint_manager.directory / str(self.checkpoint_manager.latest_step()) / "iter").exists()
    ):
      max_logging.log("Restoring data iterator from checkpoint")
      restored = self.checkpoint_manager.restore(
          self.checkpoint_manager.latest_step(),
          args=ocp.args.Composite(iter=grain.PyGrainCheckpointRestore(data_iterator.local_iterator)),
      )
      data_iterator.local_iterator = restored["iter"]
    else:
      max_logging.log("data iterator checkpoint not found")
    return data_iterator

  def _get_pipeline_class(self):
    return FluxPipeline

  def _set_checkpoint_format(self, checkpoint_format):
    self.checkpoint_format = checkpoint_format

  def save_checkpoint(self, train_step, pipeline, train_states):
    items = {
        "config": ocp.args.JsonSave({"model_name": self.config.model_name}),
    }
    items["flux_state"] = ocp.args.PyTreeSave(train_states["flux_state"])
    self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

  def load_params(self, step=None):
    self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX

  def load_checkpoint(self, step=None, scheduler_class=None):
    clip_encoder = FlaxCLIPTextModel.from_pretrained(self.config.clip_model_name_or_path, dtype=self.config.weights_dtype)
    clip_tokenizer = CLIPTokenizer.from_pretrained(self.config.clip_model_name_or_path, max_length=77, use_fast=True)

    t5_encoder = FlaxT5EncoderModel.from_pretrained(self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype)
    t5_tokenizer = AutoTokenizer.from_pretrained(
        self.config.t5xxl_model_name_or_path, max_length=self.config.max_sequence_length, use_fast=True
    )

    # Cast both text encoders to bfloat16 and replicate their weights across devices.
    encoders_sharding = PositionalSharding(self.devices_array).replicate()
    partial_device_put_replicated = functools.partial(max_utils.device_put_replicated, sharding=encoders_sharding)
    clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_encoder.params)
    clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_encoder.params)
    t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), t5_encoder.params)
    t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params)

    vae, vae_params = FlaxAutoencoderKL.from_pretrained(
        self.config.pretrained_model_name_or_path, subfolder="vae", from_pt=True, use_safetensors=True, dtype="bfloat16"
    )

    flash_block_sizes = max_utils.get_flash_block_sizes(self.config)
    # Loading from pretrained here causes a crash when trying to compile the model:
    # Failed to load HSACO: HIP_ERROR_NoBinaryForGpu
    transformer = FluxTransformer2DModel.from_config(
        self.config.pretrained_model_name_or_path,
        subfolder="transformer",
        mesh=self.mesh,
        split_head_dim=self.config.split_head_dim,
        attention_kernel=self.config.attention,
        flash_block_sizes=flash_block_sizes,
        dtype=self.config.activations_dtype,
        weights_dtype=self.config.weights_dtype,
        precision=max_utils.get_precision(self.config),
    )

    return (
        FluxPipeline(
            t5_encoder,
            clip_encoder,
            vae,
            t5_tokenizer,
            clip_tokenizer,
            transformer,
            None,
            dtype=self.config.activations_dtype,
            mesh=self.mesh,
            config=self.config,
            rng=self.rng,
        ),
        vae_params,
    )
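
For orientation, here is a sketch of how this checkpointer might be driven by a training entry point, based only on the methods above. How `config` is parsed from the YAML files below is assumed and omitted, and the "flux_state" item name simply follows the key used in save_checkpoint.

# Sketch only: wiring FluxCheckpointer into a training setup step.
# `config` is assumed to be the parsed training config (e.g. base_flux_dev.yml).
from maxdiffusion.checkpointing.flux_checkpointer import FLUX_CHECKPOINT, FluxCheckpointer


def setup_flux_training(config):
  checkpointer = FluxCheckpointer(config, FLUX_CHECKPOINT)

  # Builds the T5/CLIP encoders, tokenizers, VAE and Flux transformer.
  pipeline, vae_params = checkpointer.load_checkpoint()

  # Creates the sharded train state; pretrained Flux weights are kept
  # unless train_new_flux is set in the config.
  flux_state, state_shardings, lr_scheduler = checkpointer.create_flux_state(
      pipeline, None, checkpoint_item_name="flux_state", is_training=True
  )
  return pipeline, vae_params, flux_state, state_shardings, lr_scheduler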

21 changes: 11 additions & 10 deletions src/maxdiffusion/configs/base_flux_dev.yml
@@ -54,7 +54,7 @@ precision: "DEFAULT"
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash
+attention: 'dot_product' # Supported attention: dot_product, flash

flash_block_sizes: {}
# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
@@ -73,7 +73,7 @@ norm_num_groups: 32

# If train_new_flux, the flux transformer weights will be randomly initialized to train the model from scratch
# else they will be loaded from pretrained_model_name_or_path
-train_new_unet: False
+train_new_flux: False

# train text_encoder - Currently not supported for SDXL
train_text_encoder: False
@@ -111,7 +111,7 @@ diffusion_scheduler_config: {
base_output_directory: ""

# Hardware
-hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
+hardware: 'gpu' # Supported hardware types are 'tpu', 'gpu'

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
@@ -173,7 +173,7 @@ hf_train_files: ''
hf_access_token: ''
image_column: 'image'
caption_column: 'text'
-resolution: 1024
+resolution: 512
center_crop: False
random_flip: False
# If cache_latents_text_encoder_outputs is True
@@ -189,17 +189,17 @@ checkpoint_every: -1
enable_single_replica_ckpt_restoring: False

# Training loop
-learning_rate: 4.e-7
+learning_rate: 1.e-5
scale_lr: False
max_train_samples: -1
# max_train_steps takes priority over num_train_epochs.
-max_train_steps: 200
+max_train_steps: 1500
num_train_epochs: 1
seed: 0
output_dir: 'sdxl-model-finetuned'
per_device_batch_size: 1

-warmup_steps_fraction: 0.0
+warmup_steps_fraction: 0.1
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.

# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
@@ -209,7 +209,7 @@ learning_rate_schedule_steps: -1 # By default the length of the schedule is set
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
-adam_weight_decay: 1.e-2 # AdamW Weight decay
+adam_weight_decay: 0 # AdamW Weight decay
max_grad_norm: 1.0

enable_profiler: False
@@ -219,14 +219,15 @@ skip_first_n_steps_for_profiler: 5
profiler_steps: 10

# Generation parameters
prompt: "A magical castle in the middle of a forest, artistic drawing"
prompt_2: "A magical castle in the middle of a forest, artistic drawing"
prompt: "A Cubone, the lonely Pokémon, sits clutching its signature bone, its face hidden by a skull helmet."
prompt_2: "A Cubone, the lonely Pokémon, sits clutching its signature bone, its face hidden by a skull helmet."
negative_prompt: "purple, red"
do_classifier_free_guidance: True
guidance_scale: 3.5
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
num_inference_steps: 50
+save_final_checkpoint: False

# SDXL Lightning parameters
lightning_from_pt: True
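To make the training-loop values above concrete: with max_train_steps: 1500 and warmup_steps_fraction: 0.1, warmup covers roughly 150 steps, and learning_rate_schedule_steps: -1 lets the schedule length default to the step count. The actual schedule shape comes from max_utils.create_learning_rate_schedule; the warmup-cosine form below is only an illustrative stand-in consistent with those numbers.

# Illustration only: an optax schedule consistent with the updated config values.
# The real schedule is whatever max_utils.create_learning_rate_schedule builds.
import optax

max_train_steps = 1500
warmup_steps = int(0.1 * max_train_steps)  # warmup_steps_fraction: 0.1 -> 150 steps

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=1e-5,              # learning_rate: 1.e-5
    warmup_steps=warmup_steps,
    decay_steps=max_train_steps,  # learning_rate_schedule_steps: -1 -> defaults to max_train_steps
)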
6 changes: 3 additions & 3 deletions src/maxdiffusion/configs/base_flux_schnell.yml
@@ -53,7 +53,7 @@ precision: "DEFAULT"
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash
+attention: 'dot_product' # Supported attention: dot_product, flash
flash_block_sizes: {
"block_q" : 256,
"block_kv_compute" : 256,
@@ -119,7 +119,7 @@ diffusion_scheduler_config: {
base_output_directory: ""

# Hardware
-hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
+hardware: 'gpu' # Supported hardware types are 'tpu', 'gpu'

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
@@ -234,7 +234,7 @@ do_classifier_free_guidance: True
guidance_scale: 0.0
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
-num_inference_steps: 4
+num_inference_steps: 50

# SDXL Lightning parameters
lightning_from_pt: True