Skip to content

Commit

Permalink
fix: [Community pipeline] Fix flattened elements on image (#10774)
Browse files Browse the repository at this point in the history
* feat: new community mixture_tiling_sdxl pipeline for SDXL mixture-of-diffusers support

* fix use of variable latents to tile_latents

* removed references to modules that are not being used in this pipeline

* make style, make quality

* fixfeat: added _get_crops_coords_list function to pipeline to automatically define ctop,cleft coord to focus on image generation, helps to better harmonize the image and corrects the problem of flattened elements.
  • Loading branch information
elismasilva authored Feb 12, 2025
1 parent 5105b5a commit 051ebc3
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 15 deletions.
16 changes: 8 additions & 8 deletions examples/community/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| IADB Pipeline | Implementation of [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) | [IADB Pipeline](#iadb-pipeline) | - | [Thomas Chambon](https://github.com/tchambon)
| Zero1to3 Pipeline | Implementation of [Zero-1-to-3: Zero-shot One Image to 3D Object](https://arxiv.org/abs/2303.11328) | [Zero1to3 Pipeline](#zero1to3-pipeline) | - | [Xin Kong](https://github.com/kxhit) |
| Stable Diffusion XL Long Weighted Prompt Pipeline | A pipeline support unlimited length of prompt and negative prompt, use A1111 style of prompt weighting | [Stable Diffusion XL Long Weighted Prompt Pipeline](#stable-diffusion-xl-long-weighted-prompt-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1LsqilswLR40XLLcp6XFOl5nKb_wOe26W?usp=sharing) | [Andrew Zhu](https://xhinker.medium.com/) |
| Stable Diffusion Mixture Tiling Pipeline SD 1.5 | A pipeline generates cohesive images by integrating multiple diffusion processes, each focused on a specific image region and considering boundary effects for smooth blending | [Stable Diffusion Mixture Tiling Pipeline SD 1.5](#stable-diffusion-mixture-tiling-sd-15) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/albarji/mixture-of-diffusers) | [Álvaro B Jiménez](https://github.com/albarji/) |
| Stable Diffusion Mixture Tiling Pipeline SDXL | A pipeline generates cohesive images by integrating multiple diffusion processes, each focused on a specific image region and considering boundary effects for smooth blending | [Stable Diffusion Mixture Tiling Pipeline SDXL](#stable-diffusion-mixture-tiling-sdxl) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/elismasilva/mixture-of-diffusers-sdxl-tiling) | [Eliseu Silva](https://github.com/DEVAIEXP/) |
| Stable Diffusion Mixture Tiling Pipeline SD 1.5 | A pipeline generates cohesive images by integrating multiple diffusion processes, each focused on a specific image region and considering boundary effects for smooth blending | [Stable Diffusion Mixture Tiling Pipeline SD 1.5](#stable-diffusion-mixture-tiling-pipeline-sd-15) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/albarji/mixture-of-diffusers) | [Álvaro B Jiménez](https://github.com/albarji/) |
| Stable Diffusion Mixture Canvas Pipeline SD 1.5 | A pipeline generates cohesive images by integrating multiple diffusion processes, each focused on a specific image region and considering boundary effects for smooth blending. Works by defining a list of Text2Image region objects that detail the region of influence of each diffuser. | [Stable Diffusion Mixture Canvas Pipeline SD 1.5](#stable-diffusion-mixture-canvas-pipeline-sd-15) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/albarji/mixture-of-diffusers) | [Álvaro B Jiménez](https://github.com/albarji/) |
| Stable Diffusion Mixture Tiling Pipeline SDXL | A pipeline generates cohesive images by integrating multiple diffusion processes, each focused on a specific image region and considering boundary effects for smooth blending | [Stable Diffusion Mixture Tiling Pipeline SDXL](#stable-diffusion-mixture-tiling-pipeline-sdxl) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/elismasilva/mixture-of-diffusers-sdxl-tiling) | [Eliseu Silva](https://github.com/DEVAIEXP/) |
| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_fabric.ipynb)| [Shauray Singh](https://shauray8.github.io/about_shauray/) |
| sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
| sketch inpaint xl - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion XL Pipeline](#stable-diffusion-xl-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
Expand Down Expand Up @@ -2404,7 +2405,7 @@ pipe_images = mixing_pipeline(

![image_mixing_result](https://huggingface.co/datasets/TheDenk/images_mixing/resolve/main/boromir_gigachad.png)

### Stable Diffusion Mixture Tiling SD 1.5
### Stable Diffusion Mixture Tiling Pipeline SD 1.5

This pipeline uses the Mixture. Refer to the [Mixture](https://arxiv.org/abs/2302.02412) paper for more details.

Expand Down Expand Up @@ -2435,7 +2436,7 @@ image = pipeline(

![mixture_tiling_results](https://huggingface.co/datasets/kadirnar/diffusers_readme_images/resolve/main/mixture_tiling.png)

### Stable Diffusion Mixture Canvas
### Stable Diffusion Mixture Canvas Pipeline SD 1.5

This pipeline uses the Mixture. Refer to the [Mixture](https://arxiv.org/abs/2302.02412) paper for more details.

Expand Down Expand Up @@ -2470,7 +2471,7 @@ output = pipeline(
![Input_Image](https://huggingface.co/datasets/kadirnar/diffusers_readme_images/resolve/main/input_image.png)
![mixture_canvas_results](https://huggingface.co/datasets/kadirnar/diffusers_readme_images/resolve/main/canvas.png)

### Stable Diffusion Mixture Tiling SDXL
### Stable Diffusion Mixture Tiling Pipeline SDXL

This pipeline uses the Mixture. Refer to the [Mixture](https://arxiv.org/abs/2302.02412) paper for more details.

Expand Down Expand Up @@ -2516,14 +2517,13 @@ image = pipe(
tile_col_overlap=256,
guidance_scale_tiles=[[7, 7, 7]], # or guidance_scale=7 if is the same for all prompts
height=1024,
width=3840,
target_size=(1024, 3840),
width=3840,
generator=generator,
num_inference_steps=30,
)["images"][0]
```

![mixture_tiling_results](https://huggingface.co/datasets/elismasilva/results/resolve/main/mixture_sdxl.png)
![mixture_tiling_results](https://huggingface.co/datasets/elismasilva/results/resolve/main/mixture_of_diffusers_sdxl_1.png)

### TensorRT Inpainting Stable Diffusion Pipeline

Expand Down
66 changes: 59 additions & 7 deletions examples/community/mixture_tiling_sdxl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -151,6 +151,51 @@ def _tile2latent_exclusive_indices(
return row_segment[0], row_segment[1], col_segment[0], col_segment[1]


def _get_crops_coords_list(num_rows, num_cols, output_width):
"""
Generates a list of lists of `crops_coords_top_left` tuples for focusing on
different horizontal parts of an image, and repeats this list for the specified
number of rows in the output structure.
This function calculates `crops_coords_top_left` tuples to create horizontal
focus variations (like left, center, right focus) based on `output_width`
and `num_cols` (which represents the number of horizontal focus points/columns).
It then repeats the *list* of these horizontal focus tuples `num_rows` times to
create the final list of lists output structure.
Args:
num_rows (int): The desired number of rows in the output list of lists.
This determines how many times the list of horizontal
focus variations will be repeated.
num_cols (int): The number of horizontal focus points (columns) to generate.
This determines how many horizontal focus variations are
created based on dividing the `output_width`.
output_width (int): The desired width of the output image.
Returns:
list[list[tuple[int, int]]]: A list of lists of tuples. Each inner list
contains `num_cols` tuples of `(ctop, cleft)`,
representing horizontal focus points. The outer list
contains `num_rows` such inner lists.
"""
crops_coords_list = []
if num_cols <= 0:
crops_coords_list = []
elif num_cols == 1:
crops_coords_list = [(0, 0)]
else:
section_width = output_width / num_cols
for i in range(num_cols):
cleft = int(round(i * section_width))
crops_coords_list.append((0, cleft))

result_list = []
for _ in range(num_rows):
result_list.append(list(crops_coords_list))

return result_list


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Expand Down Expand Up @@ -757,10 +802,10 @@ def __call__(
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
crops_coords_top_left: Optional[List[List[Tuple[int, int]]]] = None,
target_size: Optional[Tuple[int, int]] = None,
negative_original_size: Optional[Tuple[int, int]] = None,
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
negative_crops_coords_top_left: Optional[List[List[Tuple[int, int]]]] = None,
negative_target_size: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
tile_height: Optional[int] = 1024,
Expand Down Expand Up @@ -826,7 +871,7 @@ def __call__(
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
crops_coords_top_left (`List[List[Tuple[int, int]]]`, *optional*, defaults to (0, 0)):
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
Expand All @@ -840,7 +885,7 @@ def __call__(
micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
negative_crops_coords_top_left (`List[List[Tuple[int, int]]]`, *optional*, defaults to (0, 0)):
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
Expand Down Expand Up @@ -883,6 +928,8 @@ def __call__(

original_size = original_size or (height, width)
target_size = target_size or (height, width)
negative_original_size = negative_original_size or (height, width)
negative_target_size = negative_target_size or (height, width)

self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
Expand Down Expand Up @@ -914,6 +961,11 @@ def __call__(

device = self._execution_device

# update crops coords list
crops_coords_top_left = _get_crops_coords_list(grid_rows, grid_cols, tile_width)
if negative_original_size is not None and negative_target_size is not None:
negative_crops_coords_top_left = _get_crops_coords_list(grid_rows, grid_cols, tile_width)

# update height and width tile size and tile overlap size
height = tile_height + (grid_rows - 1) * (tile_height - tile_row_overlap)
width = tile_width + (grid_cols - 1) * (tile_width - tile_col_overlap)
Expand Down Expand Up @@ -1020,15 +1072,15 @@ def __call__(
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
add_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left,
crops_coords_top_left[row][col],
target_size,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
if negative_original_size is not None and negative_target_size is not None:
negative_add_time_ids = self._get_add_time_ids(
negative_original_size,
negative_crops_coords_top_left,
negative_crops_coords_top_left[row][col],
negative_target_size,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
Expand Down

0 comments on commit 051ebc3

Please sign in to comment.