Added training code, loss and results are stable
Showing 11 changed files with 1,192 additions and 22 deletions.
@@ -0,0 +1,210 @@
"""
Copyright 2024 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from abc import ABC
from contextlib import nullcontext
import os
import json
import functools
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
import orbax.checkpoint as ocp
import grain.python as grain
from maxdiffusion import (
    max_utils,
    FlaxAutoencoderKL,
    max_logging,
)
from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from ..pipelines.flux.flux_pipeline import FluxPipeline

from transformers import (CLIPTokenizer, FlaxCLIPTextModel, T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer)

from maxdiffusion.checkpointing.checkpointing_utils import (
    create_orbax_checkpoint_manager,
    load_stable_diffusion_configs,
)
from maxdiffusion.models.flux.util import load_flow_model

FLUX_CHECKPOINT = "FLUX_CHECKPOINT"
_CHECKPOINT_FORMAT_DIFFUSERS = "CHECKPOINT_FORMAT_DIFFUSERS"
_CHECKPOINT_FORMAT_ORBAX = "CHECKPOINT_FORMAT_ORBAX"

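# FluxCheckpointer owns the device mesh, RNG, and Orbax checkpoint manager, and
# exposes helpers that build sharded train states for the Flux transformer and VAE.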
class FluxCheckpointer(ABC):

  def __init__(self, config, checkpoint_type):
    self.config = config
    self.checkpoint_type = checkpoint_type
    self.checkpoint_format = None

    self.rng = jax.random.PRNGKey(self.config.seed)
    self.devices_array = max_utils.create_device_mesh(config)
    self.mesh = Mesh(self.devices_array, self.config.mesh_axes)
    self.total_train_batch_size = self.config.total_train_batch_size

    self.checkpoint_manager = create_orbax_checkpoint_manager(
        self.config.checkpoint_dir,
        enable_checkpointing=True,
        save_interval_steps=1,
        checkpoint_type=checkpoint_type,
        dataset_type=config.dataset_type,
    )

  def _create_optimizer(self, config, learning_rate):
    learning_rate_scheduler = max_utils.create_learning_rate_schedule(
        learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps
    )
    tx = max_utils.create_optimizer(config, learning_rate_scheduler)
    return tx, learning_rate_scheduler

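  # Builds the Flux transformer train state: creates the optimizer when training,
  # loads pretrained flow weights on CPU, and shards the resulting state across the mesh.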
  def create_flux_state(self, pipeline, params, checkpoint_item_name, is_training):
    transformer = pipeline.flux

    tx, learning_rate_scheduler = None, None
    if is_training:
      learning_rate = self.config.learning_rate
      tx, learning_rate_scheduler = self._create_optimizer(self.config, learning_rate)

    transformer_eval_params = transformer.init_weights(
        rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True
    )

    transformer_params = load_flow_model(self.config.flux_name, transformer_eval_params, "cpu")

    weights_init_fn = functools.partial(
        pipeline.flux.init_weights, rngs=self.rng, max_sequence_length=self.config.max_sequence_length
    )
    flux_state, state_mesh_shardings = max_utils.setup_initial_state(
        model=pipeline.flux,
        tx=tx,
        config=self.config,
        mesh=self.mesh,
        weights_init_fn=weights_init_fn,
        model_params=None,
        checkpoint_manager=self.checkpoint_manager,
        checkpoint_item=checkpoint_item_name,
        training=is_training,
    )
    if not self.config.train_new_flux:
      flux_state = flux_state.replace(params=transformer_params)
    flux_state = jax.device_put(flux_state, state_mesh_shardings)
    return flux_state, state_mesh_shardings, learning_rate_scheduler

  def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False):
    # Currently VAE training is not supported.
    weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=self.rng)
    return max_utils.setup_initial_state(
        model=pipeline.vae,
        tx=None,
        config=self.config,
        mesh=self.mesh,
        weights_init_fn=weights_init_fn,
        model_params=params,
        checkpoint_manager=self.checkpoint_manager,
        checkpoint_item=checkpoint_item_name,
        training=is_training,
    )

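  # Restores the Grain data iterator from the latest checkpoint when an "iter"
  # entry exists; otherwise returns the iterator unchanged.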
  def restore_data_iterator_state(self, data_iterator):
    if (
        self.config.dataset_type == "grain"
        and data_iterator is not None
        and (self.checkpoint_manager.directory / str(self.checkpoint_manager.latest_step()) / "iter").exists()
    ):
      max_logging.log("Restoring data iterator from checkpoint")
      restored = self.checkpoint_manager.restore(
          self.checkpoint_manager.latest_step(),
          args=ocp.args.Composite(iter=grain.PyGrainCheckpointRestore(data_iterator.local_iterator)),
      )
      data_iterator.local_iterator = restored["iter"]
    else:
      max_logging.log("data iterator checkpoint not found")
    return data_iterator

  def _get_pipeline_class(self):
    return FluxPipeline

  def _set_checkpoint_format(self, checkpoint_format):
    self.checkpoint_format = checkpoint_format

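  # Persists the Flux train state (plus a small JSON config record) for the given
  # step through the Orbax checkpoint manager.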
  def save_checkpoint(self, train_step, pipeline, train_states):
    items = {
        "config": ocp.args.JsonSave({"model_name": self.config.model_name}),
    }
    items["flux_state"] = ocp.args.PyTreeSave(train_states["flux_state"])
    self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))

  def load_params(self, step=None):
    self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX

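  # Assembles a FluxPipeline from pretrained components: CLIP and T5 encoders and
  # tokenizers replicated across devices in bfloat16, the VAE, and a Flux transformer
  # built from config rather than from pretrained weights.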
  def load_checkpoint(self, step=None, scheduler_class=None):
    clip_encoder = FlaxCLIPTextModel.from_pretrained(
        self.config.clip_model_name_or_path, dtype=self.config.weights_dtype
    )
    clip_tokenizer = CLIPTokenizer.from_pretrained(
        self.config.clip_model_name_or_path, max_length=77, use_fast=True
    )

    t5_encoder = FlaxT5EncoderModel.from_pretrained(self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype)
    t5_tokenizer = AutoTokenizer.from_pretrained(
        self.config.t5xxl_model_name_or_path, max_length=self.config.max_sequence_length, use_fast=True
    )
    encoders_sharding = PositionalSharding(self.devices_array).replicate()
    partial_device_put_replicated = functools.partial(max_utils.device_put_replicated, sharding=encoders_sharding)
    clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_encoder.params)
    clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_encoder.params)
    t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), t5_encoder.params)
    t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, t5_encoder.params)

    vae, vae_params = FlaxAutoencoderKL.from_pretrained(
        self.config.pretrained_model_name_or_path, subfolder="vae", from_pt=True, use_safetensors=True, dtype="bfloat16"
    )

    flash_block_sizes = max_utils.get_flash_block_sizes(self.config)
    # Loading from pretrained here causes a crash when trying to compile the model:
    # Failed to load HSACO: HIP_ERROR_NoBinaryForGpu
    transformer = FluxTransformer2DModel.from_config(
        self.config.pretrained_model_name_or_path,
        subfolder="transformer",
        mesh=self.mesh,
        split_head_dim=self.config.split_head_dim,
        attention_kernel=self.config.attention,
        flash_block_sizes=flash_block_sizes,
        dtype=self.config.activations_dtype,
        weights_dtype=self.config.weights_dtype,
        precision=max_utils.get_precision(self.config),
    )

    return (
        FluxPipeline(
            t5_encoder,
            clip_encoder,
            vae,
            t5_tokenizer,
            clip_tokenizer,
            transformer,
            None,
            dtype=self.config.activations_dtype,
            mesh=self.mesh,
            config=self.config,
            rng=self.rng,
        ),
        vae_params,
    )
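
For orientation, below is a minimal sketch of how a training script might drive this checkpointer. The module path, the pyconfig-based config loading, and the config file name are assumptions based on how maxdiffusion is typically used, not something this diff shows; a real trainer would normally subclass FluxCheckpointer rather than instantiate it directly.

# Hypothetical driver; module path, pyconfig usage, and config path are assumptions.
from maxdiffusion import pyconfig
from maxdiffusion.checkpointing.flux_checkpointer import FluxCheckpointer, FLUX_CHECKPOINT  # assumed module path

pyconfig.initialize(["", "src/maxdiffusion/configs/base_flux_dev.yml"])  # assumed config file
config = pyconfig.config

# A real trainer would subclass FluxCheckpointer; it is instantiated directly here for brevity.
checkpointer = FluxCheckpointer(config, checkpoint_type=FLUX_CHECKPOINT)

# Build the pipeline (encoders, VAE, transformer), then the sharded Flux train state.
pipeline, vae_params = checkpointer.load_checkpoint()
flux_state, flux_shardings, lr_schedule = checkpointer.create_flux_state(
    pipeline, params=None, checkpoint_item_name="flux_state", is_training=True
)

# ... run training steps on flux_state, then persist it under the same item name
# that save_checkpoint expects.
checkpointer.save_checkpoint(train_step=0, pipeline=pipeline, train_states={"flux_state": flux_state})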