CogView4 Control Block #10809
base: main
Conversation
…diffusers into cogview4_control
Thanks for the PR!
Do we already have a checkpoint for a CogView4 control LoRA, or is this mainly to support training?
@@ -35,7 +36,7 @@ class CogView4PatchEmbed(nn.Module):
     def __init__(
         self,
         in_channels: int = 16,
-        hidden_size: int = 2560,
+        hidden_size: int = 4096,
Is this not a breaking change?
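One way to avoid the break would be to keep the old default and have new configs pass the larger width explicitly. A minimal sketch of that idea (the module body here is illustrative, not the PR's actual code):

import torch.nn as nn

class CogView4PatchEmbed(nn.Module):
    def __init__(self, in_channels: int = 16, hidden_size: int = 2560):
        # Keep the old default so existing instantiations are unaffected.
        super().__init__()
        # Illustrative projection only; the real module body differs.
        self.proj = nn.Linear(in_channels * 2 * 2, hidden_size)

# Checkpoints that need the larger width opt in via their config:
embed = CogView4PatchEmbed(in_channels=16, hidden_size=4096)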
""" | ||
|
||
|
||
def calculate_shift( |
We can add a # Copied from comment here.
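For reference, diffusers marks duplicated helpers with a # Copied from annotation so tooling keeps them in sync. A sketch of what that could look like here, assuming the helper mirrors the Flux one (the source path and defaults are taken from Flux and may need adjusting):

# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    # Linearly interpolate the shift `mu` between base and max as the
    # image token sequence length grows.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu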
>>> import torch
>>> from diffusers import CogView4Pipeline

>>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
We need to update the pipeline in this example. Do we have a checkpoint?
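A sketch of the corrected docstring example, assuming the new class is named CogView4ControlPipeline and can load the base checkpoint (both are assumptions until a control checkpoint is published):

>>> import torch
>>> from diffusers import CogView4ControlPipeline  # assumed class name

>>> pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)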
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

def _get_glm_embeds(
Add a # Copied from comment here as well.
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds

def encode_prompt(
Same here.
    if timesteps is None
    else np.array(timesteps)
)
timesteps = timesteps.astype(np.int64)
Suggested change:
- timesteps = timesteps.astype(np.int64)
+ timesteps = timesteps.astype(np.int64).astype(np.float32)
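A standalone sketch of what the suggestion does (the linspace construction is an assumption about the surrounding code): values are snapped to integer timesteps, then cast back to float32 because the downstream model consumes floating-point timesteps.

import numpy as np

num_inference_steps = 50
num_train_timesteps = 1000

# Assumed construction of a descending timestep schedule.
timesteps = np.linspace(num_train_timesteps, 1.0, num_inference_steps)

# Round to integer timesteps, then cast back to float32 so the
# transformer receives float inputs rather than int64.
timesteps = timesteps.astype(np.int64).astype(np.float32)
print(timesteps.dtype)  # float32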
self.scheduler.config.get("base_shift", 0.25), | ||
self.scheduler.config.get("max_shift", 0.75), | ||
) | ||
_, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu) |
Suggested change:
- _, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas, mu=mu)
We updated our scheduler to work with CogView4 - is there any reason we still cannot use scheduler.set_timesteps to set the timesteps?
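A sketch of what that would look like, assuming the scheduler is a FlowMatchEulerDiscreteScheduler with dynamic shifting enabled (the sigma schedule and mu value below are illustrative):

import numpy as np
from diffusers import FlowMatchEulerDiscreteScheduler

# Dynamic shifting must be enabled for `mu` to take effect.
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=True)

num_inference_steps = 50
sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps)
mu = 0.5  # illustrative; normally computed from the image sequence length

scheduler.set_timesteps(num_inference_steps, device="cpu", sigmas=sigmas, mu=mu)
timesteps = scheduler.timesteps  # already a torch.Tensor on `device`; no torch.from_numpy needed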
self.scheduler.config.get("max_shift", 0.75), | ||
) | ||
_, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu) | ||
timesteps = torch.from_numpy(timesteps).to(device) |
Suggested change (remove this line):
- timesteps = torch.from_numpy(timesteps).to(device)
What does this pull request do?
This PR adds a Control module to CogView4, following the Flux Control implementation.
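For context, a hypothetical usage sketch mirroring how the Flux control pipeline is used; the class name CogView4ControlPipeline, the control_image argument, and reuse of the base checkpoint are all assumptions:

import torch
from diffusers import CogView4ControlPipeline  # assumed class name
from diffusers.utils import load_image

pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Conditioning image, e.g. a Canny edge map (path is illustrative).
control_image = load_image("canny_edge_map.png")

image = pipe(
    prompt="a photo of a cat on a windowsill",
    control_image=control_image,
    num_inference_steps=50,
).images[0]
image.save("cogview4_control_output.png")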
Who can review?
@arrow