Training added on top of flux_impl #147

Status: Draft. Wants to merge 2 commits into base branch main.
5 changes: 3 additions & 2 deletions src/maxdiffusion/checkpointing/checkpointing_utils.py
@@ -33,6 +33,7 @@

STABLE_DIFFUSION_CHECKPOINT = "STABLE_DIFFUSION_CHECKPOINT"
STABLE_DIFFUSION_XL_CHECKPOINT = "STABLE_DIFUSSION_XL_CHECKPOINT"
FLUX_CHECKPOINT = "FLUX_CHECKPOINT"


def create_orbax_checkpoint_manager(
@@ -66,7 +67,7 @@ def create_orbax_checkpoint_manager(
"text_encoder_state",
"tokenizer_config",
)
if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT:
if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT or checkpoint_type == FLUX_CHECKPOINT:
item_names += (
"text_encoder_2_state",
"text_encoder_2_config",
@@ -117,7 +118,7 @@ def load_stable_diffusion_configs(
"tokenizer_config": orbax.checkpoint.args.JsonRestore(),
}

if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT:
if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT or checkpoint_type == FLUX_CHECKPOINT:
restore_args["text_encoder_2_config"] = orbax.checkpoint.args.JsonRestore()

return (checkpoint_manager.restore(step, args=orbax.checkpoint.args.Composite(**restore_args)), None)
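
For orientation, a minimal sketch (not part of the diff) of how the new FLUX_CHECKPOINT type feeds into create_orbax_checkpoint_manager, mirroring the call made in flux_checkpointer.py below; the checkpoint directory and dataset_type values are placeholders:

```python
# Hypothetical usage sketch; argument order mirrors the call in flux_checkpointer.py.
from maxdiffusion.checkpointing.checkpointing_utils import (
    FLUX_CHECKPOINT,
    create_orbax_checkpoint_manager,
)

checkpoint_manager = create_orbax_checkpoint_manager(
    "gs://my-bucket/flux-checkpoints",  # placeholder checkpoint_dir
    enable_checkpointing=True,
    save_interval_steps=1,
    checkpoint_type=FLUX_CHECKPOINT,    # now also registers the text_encoder_2_* items, as above
    dataset_type="grain",               # placeholder; "grain" also enables iterator checkpointing
)
```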
210 changes: 210 additions & 0 deletions src/maxdiffusion/checkpointing/flux_checkpointer.py
@@ -0,0 +1,210 @@
"""
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from abc import ABC
from contextlib import nullcontext
import os
import json
import functools
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
import orbax.checkpoint as ocp
import grain.python as grain
from maxdiffusion import (
max_utils,
FlaxAutoencoderKL,
max_logging,
)
from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from ..pipelines.flux.flux_pipeline import FluxPipeline

from transformers import (CLIPTokenizer, FlaxCLIPTextModel, T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer)

from maxdiffusion.checkpointing.checkpointing_utils import (
create_orbax_checkpoint_manager,
load_stable_diffusion_configs,
)
from maxdiffusion.models.flux.util import load_flow_model

FLUX_CHECKPOINT = "FLUX_CHECKPOINT"
_CHECKPOINT_FORMAT_DIFFUSERS = "CHECKPOINT_FORMAT_DIFFUSERS"
_CHECKPOINT_FORMAT_ORBAX = "CHECKPOINT_FORMAT_ORBAX"


class FluxCheckpointer(ABC):

def __init__(self, config, checkpoint_type):
self.config = config
self.checkpoint_type = checkpoint_type
self.checkpoint_format = None

self.rng = jax.random.PRNGKey(self.config.seed)
self.devices_array = max_utils.create_device_mesh(config)
self.mesh = Mesh(self.devices_array, self.config.mesh_axes)
self.total_train_batch_size = self.config.total_train_batch_size

self.checkpoint_manager = create_orbax_checkpoint_manager(
self.config.checkpoint_dir,
enable_checkpointing=True,
save_interval_steps=1,
checkpoint_type=checkpoint_type,
dataset_type=config.dataset_type,
)

def _create_optimizer(self, config, learning_rate):

learning_rate_scheduler = max_utils.create_learning_rate_schedule(
learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps
)
tx = max_utils.create_optimizer(config, learning_rate_scheduler)
return tx, learning_rate_scheduler

def create_flux_state(self, pipeline, params, checkpoint_item_name, is_training):
transformer = pipeline.flux

tx, learning_rate_scheduler = None, None
if is_training:
learning_rate = self.config.learning_rate

tx, learning_rate_scheduler = self._create_optimizer(self.config, learning_rate)

transformer_eval_params = transformer.init_weights(
rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True
)

transformer_params = load_flow_model(self.config.flux_name, transformer_eval_params, "cpu")

weights_init_fn = functools.partial(pipeline.flux.init_weights, rngs=self.rng, max_sequence_length=self.config.max_sequence_length)
flux_state, state_mesh_shardings = max_utils.setup_initial_state(
model=pipeline.flux,
tx=tx,
config=self.config,
mesh=self.mesh,
weights_init_fn=weights_init_fn,
model_params=None,
checkpoint_manager=self.checkpoint_manager,
checkpoint_item=checkpoint_item_name,
training=is_training,
)
if not self.config.train_new_flux:
flux_state = flux_state.replace(params=transformer_params)
flux_state = jax.device_put(flux_state, state_mesh_shardings)
return flux_state, state_mesh_shardings, learning_rate_scheduler

def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False):

# Currently VAE training is not supported.
weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=self.rng)
return max_utils.setup_initial_state(
model=pipeline.vae,
tx=None,
config=self.config,
mesh=self.mesh,
weights_init_fn=weights_init_fn,
model_params=params,
checkpoint_manager=self.checkpoint_manager,
checkpoint_item=checkpoint_item_name,
training=is_training,
)

def restore_data_iterator_state(self, data_iterator):
if (
self.config.dataset_type == "grain"
and data_iterator is not None
and (self.checkpoint_manager.directory / str(self.checkpoint_manager.latest_step()) / "iter").exists()
):
max_logging.log("Restoring data iterator from checkpoint")
restored = self.checkpoint_manager.restore(
self.checkpoint_manager.latest_step(),
args=ocp.args.Composite(iter=grain.PyGrainCheckpointRestore(data_iterator.local_iterator)),
)
data_iterator.local_iterator = restored["iter"]
else:
max_logging.log("data iterator checkpoint not found")
return data_iterator

def _get_pipeline_class(self):
return FluxPipeline

def _set_checkpoint_format(self, checkpoint_format):
self.checkpoint_format = checkpoint_format

def save_checkpoint(self, train_step, pipeline, train_states):
items = {
"config": ocp.args.JsonSave({"model_name": self.config.model_name}),
}

items["flux_state"] = ocp.args.PyTreeSave(train_states["flux_state"])

self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

def load_params(self, step=None):

self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX

def load_checkpoint(self, step=None, scheduler_class=None):
clip_encoder = FlaxCLIPTextModel.from_pretrained(
self.config.clip_model_name_or_path, dtype=self.config.weights_dtype
)
clip_tokenizer = CLIPTokenizer.from_pretrained(
self.config.clip_model_name_or_path, max_length=77, use_fast=True
)

t5_encoder = FlaxT5EncoderModel.from_pretrained(self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype)
t5_tokenizer = AutoTokenizer.from_pretrained(
self.config.t5xxl_model_name_or_path, max_length=self.config.max_sequence_length, use_fast=True
)
encoders_sharding = PositionalSharding(self.devices_array).replicate()
partial_device_put_replicated = functools.partial(max_utils.device_put_replicated, sharding=encoders_sharding)
clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_encoder.params)
clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_encoder.params)
t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), t5_encoder.params)
t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params)



vae, vae_params = FlaxAutoencoderKL.from_pretrained(
self.config.pretrained_model_name_or_path, subfolder="vae", from_pt=True, use_safetensors=True, dtype="bfloat16"
)

flash_block_sizes = max_utils.get_flash_block_sizes(self.config)
# loading from pretrained here causes a crash when trying to compile the model
# Failed to load HSACO: HIP_ERROR_NoBinaryForGpu
transformer = FluxTransformer2DModel.from_config(
self.config.pretrained_model_name_or_path,
subfolder="transformer",
mesh=self.mesh,
split_head_dim=self.config.split_head_dim,
attention_kernel=self.config.attention,
flash_block_sizes=flash_block_sizes,
dtype=self.config.activations_dtype,
weights_dtype=self.config.weights_dtype,
precision=max_utils.get_precision(self.config),
)

return FluxPipeline(t5_encoder,
clip_encoder,
vae,
t5_tokenizer,
clip_tokenizer,
transformer,
None,
dtype=self.config.activations_dtype,
mesh=self.mesh,
config=self.config,
rng=self.rng), vae_params
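
A hedged sketch (not part of the diff) of how a training script might drive this checkpointer; `config` stands for a maxdiffusion config object exposing the fields referenced above (seed, checkpoint_dir, learning_rate, dataset_type, and so on), and the "flux_state" item name is assumed to match what save_checkpoint writes:

```python
# Hypothetical driver; method signatures are taken from FluxCheckpointer above.
from maxdiffusion.checkpointing.flux_checkpointer import FLUX_CHECKPOINT, FluxCheckpointer

checkpointer = FluxCheckpointer(config, FLUX_CHECKPOINT)

# Build the pipeline: CLIP/T5 encoders and tokenizers, VAE, and the Flux transformer.
pipeline, vae_params = checkpointer.load_checkpoint()

# Create the sharded train state for the transformer, with the optimizer attached.
flux_state, flux_state_shardings, lr_scheduler = checkpointer.create_flux_state(
    pipeline, params=None, checkpoint_item_name="flux_state", is_training=True
)

# ... run training steps on flux_state, then persist it with Orbax.
checkpointer.save_checkpoint(train_step=1000, pipeline=pipeline, train_states={"flux_state": flux_state})
```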

19 changes: 10 additions & 9 deletions src/maxdiffusion/configs/base_flux_dev.yml
@@ -73,7 +73,7 @@ norm_num_groups: 32

# If train_new_unet, unet weights will be randomly initialized to train the unet from scratch
# else they will be loaded from pretrained_model_name_or_path
train_new_unet: False
train_new_flux: False

# train text_encoder - Currently not supported for SDXL
train_text_encoder: False
@@ -111,7 +111,7 @@ diffusion_scheduler_config: {
base_output_directory: ""

# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
hardware: 'gpu' # Supported hardware types are 'tpu', 'gpu'

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
@@ -173,7 +173,7 @@ hf_train_files: ''
hf_access_token: ''
image_column: 'image'
caption_column: 'text'
resolution: 1024
resolution: 512
center_crop: False
random_flip: False
# If cache_latents_text_encoder_outputs is True
@@ -189,17 +189,17 @@ checkpoint_every: -1
enable_single_replica_ckpt_restoring: False

# Training loop
learning_rate: 4.e-7
learning_rate: 1.e-5
scale_lr: False
max_train_samples: -1
# max_train_steps takes priority over num_train_epochs.
max_train_steps: 200
max_train_steps: 1500
num_train_epochs: 1
seed: 0
output_dir: 'sdxl-model-finetuned'
per_device_batch_size: 1

warmup_steps_fraction: 0.0
warmup_steps_fraction: 0.1
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.

# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
@@ -209,7 +209,7 @@ learning_rate_schedule_steps: -1 # By default the length of the schedule is set
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
adam_weight_decay: 1.e-2 # AdamW Weight decay
adam_weight_decay: 0 # AdamW Weight decay
max_grad_norm: 1.0

enable_profiler: False
@@ -219,14 +219,15 @@ skip_first_n_steps_for_profiler: 5
profiler_steps: 10

# Generation parameters
prompt: "A magical castle in the middle of a forest, artistic drawing"
prompt_2: "A magical castle in the middle of a forest, artistic drawing"
prompt: "A Cubone, the lonely Pokémon, sits clutching its signature bone, its face hidden by a skull helmet."
prompt_2: "A Cubone, the lonely Pokémon, sits clutching its signature bone, its face hidden by a skull helmet."
negative_prompt: "purple, red"
do_classifier_free_guidance: True
guidance_scale: 3.5
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
num_inference_steps: 50
save_final_checkpoint: False

# SDXL Lightning parameters
lightning_from_pt: True
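
The training-loop values above (learning_rate 1.e-5, warmup_steps_fraction 0.1, max_train_steps 1500, the AdamW betas and eps, weight decay 0) are consumed by max_utils.create_learning_rate_schedule and max_utils.create_optimizer in the checkpointer. The following is an illustrative optax sketch of what those numbers imply, not the actual max_utils implementation; in particular the cosine decay shape is an assumption:

```python
# Illustrative only: approximates the config values above with plain optax.
import optax

max_train_steps = 1500
warmup_steps = int(0.1 * max_train_steps)  # warmup_steps_fraction: 0.1 -> 150 steps

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=1e-5,              # learning_rate: 1.e-5
    warmup_steps=warmup_steps,
    decay_steps=max_train_steps,  # learning_rate_schedule_steps: -1 -> defaults to max_train_steps
)
tx = optax.adamw(learning_rate=schedule, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.0)
```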
4 changes: 2 additions & 2 deletions src/maxdiffusion/configs/base_flux_schnell.yml
@@ -119,7 +119,7 @@ diffusion_scheduler_config: {
base_output_directory: ""

# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
hardware: 'gpu' # Supported hardware types are 'tpu', 'gpu'

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
Expand Down Expand Up @@ -234,7 +234,7 @@ do_classifier_free_guidance: True
guidance_scale: 0.0
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
num_inference_steps: 4
num_inference_steps: 50

# SDXL Lightning parameters
lightning_from_pt: True
55 changes: 55 additions & 0 deletions src/maxdiffusion/maxdiffusion_utils.py
@@ -255,6 +255,61 @@ def calculate_unet_tflops(config, pipeline, batch_size, rngs, train):
/ jax.local_device_count()
)

def get_dummy_flux_inputs(config, pipeline, batch_size):
"""Returns randomly initialized flux inputs."""
latents, latents_ids = pipeline.prepare_latents(
batch_size=batch_size,
num_channels_latents=pipeline.flux.in_channels // 4,
height=config.resolution,
width=config.resolution,
vae_scale_factor=pipeline.vae_scale_factor,
dtype=config.activations_dtype,
rng=pipeline.rng
)
guidance_vec = jnp.asarray([config.guidance_scale] * batch_size, dtype=config.activations_dtype)

timesteps = jnp.ones((batch_size,), dtype=config.weights_dtype)
t5_hidden_states_shape = (
batch_size,
config.max_sequence_length,
4096,
)
t5_hidden_states = jnp.zeros(t5_hidden_states_shape, dtype=config.weights_dtype)
t5_ids = jnp.zeros((batch_size, t5_hidden_states.shape[1], 3), dtype=config.weights_dtype)

clip_hidden_states_shape = (
batch_size,
768,
)
clip_hidden_states = jnp.zeros(clip_hidden_states_shape, dtype=config.weights_dtype)

return (latents, timesteps, latents_ids, guidance_vec, t5_hidden_states, t5_ids, clip_hidden_states)


def calculate_flux_tflops(config, pipeline, batch_size, rngs, train):
"""
Calculates jflux tflops.
batch_size should be per_device_batch_size * jax.local_device_count() or attention's shard_map won't
cache the compilation when flash is enabled.
"""

(latents, timesteps, latents_ids, guidance_vec, t5_hidden_states, t5_ids, clip_hidden_states) = get_dummy_flux_inputs(config, pipeline, batch_size)
return (
max_utils.calculate_model_tflops(
pipeline.flux,
rngs,
train,
hidden_states=latents,
img_ids=latents_ids,
encoder_hidden_states=t5_hidden_states,
txt_ids=t5_ids,
pooled_projections=clip_hidden_states,
timestep=timesteps,
guidance=guidance_vec,
)
/ jax.local_device_count()
)


def tokenize_captions(examples, caption_column, tokenizer, input_ids_key="input_ids", p_encode=None):
"""Tokenize captions for sd1.x,sd2.x models."""
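
Per the docstring of calculate_flux_tflops, the helper expects a batch size of per_device_batch_size times the local device count so the flash-attention shard_map compilation is cached. A hedged call sketch follows (not part of the diff; `config` and `pipeline` are assumed from the checkpointer sketch earlier, and the exact rngs structure expected by max_utils.calculate_model_tflops may differ):

```python
# Hypothetical call site for the new helper.
import jax
from maxdiffusion import maxdiffusion_utils

batch_size = config.per_device_batch_size * jax.local_device_count()
rngs = jax.random.PRNGKey(config.seed)
per_device_tflops = maxdiffusion_utils.calculate_flux_tflops(
    config, pipeline, batch_size, rngs, train=True
)
```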
5 changes: 5 additions & 0 deletions src/maxdiffusion/pipelines/flux/__init__.py
@@ -0,0 +1,5 @@
_import_structure = { "pipeline_jflux" : "JfluxPipeline" }

from .flux_pipeline import (
FluxPipeline,
)
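
With this package __init__ in place, downstream code can import the pipeline directly (a trivial usage sketch, not part of the diff):

```python
from maxdiffusion.pipelines.flux import FluxPipeline
```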