
CogView4 Control Block #10809

Open · wants to merge 20 commits into base: main
Conversation

zRzRzRzRzRzRzR (Contributor)

What does this pull request do?

This PR adds a Control module to CogView4, following the Flux Control implementation.
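For context, a usage sketch of what the added control pipeline might look like, modeled on FluxControlPipeline. The class name CogView4ControlPipeline, the control_image argument, and the checkpoint path are assumptions drawn from the Flux analogy, not the final API:

import torch
from diffusers import CogView4ControlPipeline  # hypothetical class name
from diffusers.utils import load_image

# Hypothetical checkpoint path; no control checkpoint has been published yet.
pipe = CogView4ControlPipeline.from_pretrained(
    "THUDM/CogView4-6B", torch_dtype=torch.bfloat16
).to("cuda")

# Conditioning image (e.g. a canny edge map), as in the Flux Control pipelines.
control_image = load_image("canny_edge_map.png")

image = pipe(
    prompt="a red sports car on a mountain road",
    control_image=control_image,
    num_inference_steps=50,
).images[0]
image.save("output.png")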

Who can review?

@arrow

@zRzRzRzRzRzRzR zRzRzRzRzRzRzR changed the title CogView4 Contorl Block CogView4 Control Block Feb 17, 2025
yiyixuxu (Collaborator) left a comment:


Thanks for the PR!
Do we already have a checkpoint for CogView4 Control LoRA, or is this mainly to support training?

@@ -35,7 +36,7 @@ class CogView4PatchEmbed(nn.Module):
     def __init__(
         self,
         in_channels: int = 16,
-        hidden_size: int = 2560,
+        hidden_size: int = 4096,
Collaborator:

Is this not a breaking change?
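For illustration, a minimal sketch of why changing the default is breaking: any caller that instantiated the module without passing hidden_size, and any old code relying on the previous default, silently gets a differently sized module. The import path is an assumption:

# Assumed import path, for illustration only.
from diffusers.models.transformers.transformer_cogview4 import CogView4PatchEmbed

# Before this PR, omitting hidden_size meant 2560; after, it means 4096,
# so existing code that relied on the default breaks silently.
# Passing the value explicitly keeps behavior stable across versions:
patch_embed = CogView4PatchEmbed(in_channels=16, hidden_size=2560)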

"""


def calculate_shift(
Collaborator:

We can add a `# Copied from` here.
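For reference, diffusers uses `# Copied from` markers so the repo's consistency check (`make fix-copies`) keeps duplicated helpers in sync with their source. A sketch of what it could look like here; the exact source path is an assumption, though Flux does define calculate_shift in its pipeline module:

# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    # Linearly interpolate the shift mu between base_shift and max_shift
    # as the image sequence length grows.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu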

>>> import torch
>>> from diffusers import CogView4Pipeline

>>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
Collaborator:

We need to update this pipeline docstring example.
Do we have a checkpoint?
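If the docstring example is updated alongside a published checkpoint, it might read like this; both the class name and the checkpoint id are hypothetical until then:

>>> import torch
>>> from diffusers import CogView4ControlPipeline  # hypothetical class name

>>> # Hypothetical checkpoint id; no control checkpoint exists yet.
>>> pipe = CogView4ControlPipeline.from_pretrained(
...     "THUDM/CogView4-6B-Control", torch_dtype=torch.bfloat16
... )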

self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

def _get_glm_embeds(
Collaborator:

Add a `# Copied from` here as well.

prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds

def encode_prompt(
Collaborator:

Same here.

if timesteps is None
else np.array(timesteps)
)
timesteps = timesteps.astype(np.int64)
Collaborator:

Suggested change:
-    timesteps = timesteps.astype(np.int64)
+    timesteps = timesteps.astype(np.int64).astype(np.float32)
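A small self-contained demo of what the suggested cast does: the timesteps are first truncated onto an integer grid, then converted back to float32 for the scheduler. The linspace endpoints are made up for illustration:

import numpy as np

# Hypothetical schedule endpoints, for illustration only.
timesteps = np.linspace(999.0, 0.0, 8)
timesteps = timesteps.astype(np.int64).astype(np.float32)
print(timesteps.dtype)  # float32
print(timesteps)        # integer-valued floats: [999. 856. 713. 570. 428. 285. 142. 0.]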

self.scheduler.config.get("base_shift", 0.25),
self.scheduler.config.get("max_shift", 0.75),
)
_, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
Collaborator:

Suggested change:
-    _, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
+    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas, mu=mu)

We updated our scheduler to work with CogView4 - is there any reason we still can't use `scheduler.set_timesteps` to set the timesteps?
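A sketch of the flow the reviewer is pointing at, written as a method-body fragment in the style of the surrounding pipeline code, assuming the updated scheduler accepts mu/sigmas directly (retrieve_timesteps forwards these keyword arguments to scheduler.set_timesteps under the hood):

# Let the scheduler own the schedule instead of recomputing it by hand.
self.scheduler.set_timesteps(num_inference_steps, device=device, sigmas=sigmas, mu=mu)
timesteps = self.scheduler.timesteps  # already a torch.Tensor on `device`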

self.scheduler.config.get("max_shift", 0.75),
)
_, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
timesteps = torch.from_numpy(timesteps).to(device)
Collaborator:

Suggested change:
-    timesteps = torch.from_numpy(timesteps).to(device)
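Note: with the earlier suggestion applied, retrieve_timesteps already returns scheduler.timesteps as a torch tensor on `device`, so this manual torch.from_numpy conversion becomes redundant.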
