Skip to content

Commit

Permalink
Rebased on flux_lora and aligned flux_pipeline with changes in genera…
Browse files Browse the repository at this point in the history
…te_flux.py
  • Loading branch information
ksikiric committed Feb 13, 2025
1 parent 853c150 commit f56234e
Showing 1 changed file with 32 additions and 13 deletions.
45 changes: 32 additions & 13 deletions src/maxdiffusion/pipelines/flux/flux_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from functools import partial
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, Callable

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -152,8 +152,8 @@ def prepare_latents(

def prepare_latent_image_ids(self, height, width):
latent_image_ids = jnp.zeros((height, width, 3))
latent_image_ids = latent_image_ids.at[..., 1].set(latent_image_ids[..., 1] + jnp.arange(height)[:, None])
latent_image_ids = latent_image_ids.at[..., 2].set(latent_image_ids[..., 2] + jnp.arange(width)[None, :])
latent_image_ids = latent_image_ids.at[..., 1].set(jnp.arange(height)[:, None])
latent_image_ids = latent_image_ids.at[..., 2].set(jnp.arange(width)[None, :])

latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

Expand All @@ -165,7 +165,6 @@ def get_clip_prompt_embeds(
self, prompt: Union[str, List[str]], num_images_per_prompt: int, tokenizer: CLIPTokenizer, text_encoder: FlaxCLIPTextModel
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = tokenizer(
prompt,
padding="max_length",
Expand All @@ -180,8 +179,7 @@ def get_clip_prompt_embeds(

prompt_embeds = text_encoder(text_input_ids, params=text_encoder.params, train=False)
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=-1)
prompt_embeds = np.reshape(prompt_embeds, (batch_size * num_images_per_prompt, -1))
prompt_embeds = jnp.tile(prompt_embeds, (num_images_per_prompt, 1))
return prompt_embeds


Expand Down Expand Up @@ -260,7 +258,8 @@ def _generate(
txt_ids,
vec,
guidance_vec,
timesteps,
c_ts,
p_ts
):

def loop_body(
Expand Down Expand Up @@ -292,9 +291,6 @@ def loop_body(
latents = jnp.array(latents, dtype=latents_dtype)
return latents, state, c_ts, p_ts

c_ts = timesteps[:-1]
p_ts = timesteps[1:]

loop_body_p = partial(
loop_body,
transformer=self.flux,
Expand All @@ -308,10 +304,28 @@ def loop_body(
vae_decode_p = partial(self.vae_decode, vae=self.vae, state=vae_params, config=self._config)

with self.mesh, nn_partitioning.axis_rules(self._config.logical_axis_rules):
latents, _, _, _ = jax.lax.fori_loop(0, len(timesteps) - 1, loop_body_p, (latents, flux_params, c_ts, p_ts))
latents, _, _, _ = jax.lax.fori_loop(0, len(c_ts), loop_body_p, (latents, flux_params, c_ts, p_ts))
image = vae_decode_p(latents)
return image

def do_time_shift(self, mu: float, sigma: float, t: Array):
    """Apply the exponential time shift to ``t``.

    Evaluates ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)``.

    Args:
        mu: Shift parameter; only ``exp(mu)`` enters the formula.
        sigma: Exponent applied to the ``1 / t - 1`` term.
        t: Value(s) to shift. The arithmetic is applied as written, so any
            operand supporting ``/``, ``-`` and ``**`` works.

    Returns:
        The shifted value(s), of whatever type the arithmetic on ``t`` yields.
    """
    shift = math.exp(mu)
    warped = (1 / t - 1) ** sigma
    return shift / (shift + warped)


def get_lin_function(self, x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
    """Build the straight line passing through ``(x1, y1)`` and ``(x2, y2)``.

    Args:
        x1: Abscissa of the first point.
        y1: Ordinate of the first point.
        x2: Abscissa of the second point.
        y2: Ordinate of the second point.

    Returns:
        A one-argument callable mapping ``x`` to ``slope * x + intercept``.
    """
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x):
        return slope * x + intercept

    return line

def time_shift(self, latents, timesteps):
    """Shift ``timesteps`` using a ``mu`` estimated from the latents' size.

    ``mu`` is read off a line built by ``get_lin_function`` from the config
    values ``max_sequence_length`` (as ``x1``), ``base_shift`` (as ``y1``) and
    ``max_shift`` (as ``y2``), evaluated at ``latents.shape[1]``. The result is
    passed to ``do_time_shift`` with ``sigma`` fixed at ``1.0``.

    Args:
        latents: Object whose ``shape[1]`` is the point at which ``mu`` is
            estimated.
        timesteps: Timestep value(s) handed to ``do_time_shift``.

    Returns:
        Whatever ``do_time_shift`` returns for the estimated ``mu``.
    """
    cfg = self._config
    # Linear estimate of mu between two configured points.
    estimate_mu = self.get_lin_function(
        x1=cfg.max_sequence_length,
        y1=cfg.base_shift,
        y2=cfg.max_shift,
    )
    return self.do_time_shift(estimate_mu(latents.shape[1]), 1.0, timesteps)

def __call__(
self,
timesteps: int,
Expand Down Expand Up @@ -364,7 +378,11 @@ def __call__(
rng=self.rng,
)

#timesteps = jnp.asarray([1.0] * global_batch_size, dtype=jnp.bfloat16)
if self._config.time_shift:
timesteps = self.time_shift(latents, timesteps)
c_ts = timesteps[:-1]
p_ts = timesteps[1:]

guidance = jnp.asarray([self._config.guidance_scale] * global_batch_size, dtype=jnp.bfloat16)

images = self._generate(
Expand All @@ -376,7 +394,8 @@ def __call__(
text_ids,
pooled_prompt_embeds,
guidance,
timesteps,
c_ts,
p_ts
)

images = images
Expand Down

0 comments on commit f56234e

Please sign in to comment.