Skip to content

Commit

Permalink
update flux
Browse files Browse the repository at this point in the history
  • Loading branch information
vladmandic committed Sep 4, 2024
1 parent ce94b5a commit db6a52a
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 5 deletions.
6 changes: 3 additions & 3 deletions modules/model_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def load_flux_quanto(checkpoint_info, diffusers_load_config):
debug(f'Loading FLUX: quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="transformer"')
if not os.path.exists(quantization_map):
repo_id = checkpoint_info.name.replace('Diffusers/', '').replace('Diffusers\\', '').replace('models--', '').replace('--', '/')
quantization_map = hf_hub_download(repo_id, subfolder='transformer', filename='quantization_map.json', **diffusers_load_config)
quantization_map = hf_hub_download(repo_id, subfolder='transformer', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
with open(quantization_map, "r", encoding='utf8') as f:
quantization_map = json.load(f)
state_dict = load_file(os.path.join(repo_path, "transformer", "diffusion_pytorch_model.safetensors"))
Expand All @@ -72,7 +72,7 @@ def load_flux_quanto(checkpoint_info, diffusers_load_config):
debug(f'Loading FLUX: quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="text_encoder_2"')
if not os.path.exists(quantization_map):
repo_id = checkpoint_info.name.replace('Diffusers/', '').replace('Diffusers\\', '').replace('models--', '').replace('--', '/')
quantization_map = hf_hub_download(repo_id, subfolder='text_encoder_2', filename='quantization_map.json', **diffusers_load_config)
quantization_map = hf_hub_download(repo_id, subfolder='text_encoder_2', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
with open(quantization_map, "r", encoding='utf8') as f:
quantization_map = json.load(f)
with open(os.path.join(repo_path, "text_encoder_2", "config.json"), encoding='utf8') as f:
Expand All @@ -97,7 +97,7 @@ def load_flux_quanto(checkpoint_info, diffusers_load_config):
return transformer, text_encoder_2


def load_flux_bnb(checkpoint_info, diffusers_load_config, ): # pylint: disable=unused-argument
def load_flux_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unused-argument
transformer, text_encoder_2 = None, None
if isinstance(checkpoint_info, str):
repo_path = checkpoint_info
Expand Down
4 changes: 2 additions & 2 deletions modules/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,10 +500,10 @@ def after(self, p, processed, *args):
s = ScriptSummary('after')
script_index = args[0] if len(args) > 0 else 0
if script_index == 0:
return None
return processed
script = self.selectable_scripts[script_index-1]
if script is None or not hasattr(script, 'after'):
return None
return processed
parsed = p.per_script_args.get(script.title(), args[script.args_from:script.args_to])
after_processed = script.after(p, processed, *parsed)
if after_processed is not None:
Expand Down
55 changes: 55 additions & 0 deletions scripts/cogvideo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import os
import gradio as gr
import diffusers
from modules import scripts, processing, shared, devices, sd_models


class Script(scripts.Script):
def title(self):
return 'CogVideoX'

def show(self, is_img2img):
return shared.native


def ui(self, _is_img2img):
def video_type_change(video_type):
return [
gr.update(visible=video_type != 'None'),
gr.update(visible=video_type == 'GIF' or video_type == 'PNG'),
gr.update(visible=video_type == 'MP4'),
gr.update(visible=video_type == 'MP4'),
]

with gr.Row():
gr.HTML("<span>&nbsp CogVideoX</span><br>")
with gr.Row():
model = gr.Dropdown(label='Model', choices=['THUDM/CogVideoX-2b', 'THUDM/CogVideoX-5b'], value='THUDM/CogVideoX-2b')
sampler = gr.Dropdown(label='Sampler', choices=['DDIM', 'DPM'], value='DDIM')
with gr.Row():
frames = gr.Slider(label='Frames', minimum=1, maximum=64, step=1, value=16)
guidance = gr.Slider(label='Guidance', minimum=0.0, maximum=14.0, step=0.5, value=6.0)
with gr.Row():
offload = gr.Dropdown(label='Offload', choices=['none', 'balanced', 'model', 'sequential'], value='balanced')
override = gr.Checkbox(label='Override resolution', value=True)
with gr.Row():
video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None')
duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False)
with gr.Row():
loop = gr.Checkbox(label='Loop', value=True, visible=False)
pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False)
interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False)
video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, loop, pad, interpolate])
return [model, sampler, frames, guidance, offload, override, video_type, duration, duration, loop, pad, interpolate]

def run(self, p: processing.StableDiffusionProcessing, model, sampler, frames, guidance, offload, override, video_type, duration, loop, pad, interpolate): # pylint: disable=arguments-differ, unused-argument
shared.log.debug(f'CogVideoX: model={model} sampler={sampler} frames={frames} guidance={guidance} offload={offload} override={override} video_type={video_type} duration={duration} loop={loop} pad={pad} interpolate={interpolate}')
p.extra_generation_params['CogVideoX'] = model
p.do_not_save_grid = True
if 'animatediff' not in p.ops:
p.ops.append('cogvideox')

def after(self, p: processing.StableDiffusionProcessing, processed: processing.Processed, model, sampler, frames, guidance, override_resolution, video_type, duration, loop, pad, interpolate): # pylint: disable=arguments-differ, unused-argument
from modules.images import save_video
if video_type != 'None':
save_video(p, filename=None, images=processed.images, video_type=video_type, duration=duration, loop=loop, pad=pad, interpolate=interpolate)

0 comments on commit db6a52a

Please sign in to comment.