Skip to content

Commit

Permalink
Support remaining samplers, #39
Browse files Browse the repository at this point in the history
  • Loading branch information
AlUlkesh committed Oct 9, 2023
1 parent faf17b4 commit fba7d9e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 11 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ You can also install it manually by running the following command from within th

git clone https://github.com/AlUlkesh/sd_save_intermediate_images/ extensions/sd_save_intermediate_images

## Limitations
Does not work with DDIM, PLMS or UNIPC
## Samplers
Works with all a1111 samplers

## Output

Expand Down
38 changes: 29 additions & 9 deletions scripts/sd_save_intermediate_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@
from modules import scripts
from modules import script_callbacks
from modules.processing import Processed, process_images, fix_seed, create_infotext
try:
from modules.sd_samplers_kdiffusion import KDiffusionSampler
from modules.sd_samplers_common import sample_to_image
except ImportError:
from modules.sd_samplers import KDiffusionSampler, sample_to_image
from modules.sd_samplers_kdiffusion import KDiffusionSampler
from modules.sd_samplers_timesteps import CompVisSampler as TimestepsSampler, samplers_timesteps
from modules.sd_samplers_common import sample_to_image
from modules.images import save_image, FilenameGenerator, get_next_sequence_number
from modules.shared import opts, state, cmd_opts

Expand All @@ -27,7 +25,9 @@
# replace: \1, ssii_add_last_frames, ssii_add_first_frames
# plus debug

orig_callback_state = KDiffusionSampler.callback_state
orig_callback_state_KDiffusionSampler = KDiffusionSampler.callback_state
orig_callback_state_TimestepsSampler = TimestepsSampler.callback_state
orig_callback_state = None
ui_config_backup = os.path.join(scripts.basedir(), "ui-config_backup.json")
video_bat_mode = ""
ui_items = {
Expand Down Expand Up @@ -368,6 +368,12 @@ def hr_active_check(p):
hr_active = False
return hr_active

def is_TimestepsSampler(sampler_name):
for sampler in samplers_timesteps:
if sampler[0] == sampler_name:
return True
return False

class Script(scripts.Script):
def title(self):
return "Save intermediate images during the sampling process"
Expand Down Expand Up @@ -580,6 +586,12 @@ def callback_state(self, d):
"""
callback_state runs after each processing step
"""
logger.debug(f"sampler_name: {p.sampler_name}")
if is_TimestepsSampler(p.sampler_name):
orig_callback_state = orig_callback_state_TimestepsSampler
else:
orig_callback_state = orig_callback_state_KDiffusionSampler

current_step = d["i"]

hr = hr_check(p)
Expand Down Expand Up @@ -652,9 +664,15 @@ def callback_state(self, d):
if ssii_intermediate_type == "According to Live preview subject setting" and index == 0:
image = state.current_image
elif ssii_intermediate_type == "Noisy":
image = sample_to_image(d["x"], index=index)
if d["x"] is None:
image = state.current_image
else:
image = sample_to_image(d["x"], index=index)
else:
image = sample_to_image(d["denoised"], index=index)
if d["denoised"] is None:
image = state.current_image
else:
image = sample_to_image(d["denoised"], index=index)

logger.debug(f"ssii_intermediate_type, ssii_every_n, ssii_start_at_n, ssii_stop_at_n: {ssii_intermediate_type}, {ssii_every_n}, {ssii_start_at_n}, {ssii_stop_at_n}")
logger.debug(f"Step, abs_step, hr, hr_active: {current_step}, {abs_step}, {hr}, {hr_active}")
Expand Down Expand Up @@ -753,10 +771,12 @@ def callback_state(self, d):
return orig_callback_state(self, d)

setattr(KDiffusionSampler, "callback_state", callback_state)
setattr(TimestepsSampler, "callback_state", callback_state)

def postprocess(self, p, processed, ssii_is_active, ssii_final_save, ssii_intermediate_type, ssii_every_n, ssii_start_at_n, ssii_stop_at_n, ssii_mode, ssii_video_format, ssii_mp4_parms, ssii_video_fps, ssii_add_first_frames, ssii_add_last_frames, ssii_smooth, ssii_seconds, ssii_lores, ssii_hires, ssii_ffmpeg_bat, ssii_bat_only, ssii_debug):
logger.debug(f"func: {sys._getframe(0).f_code.co_name}")
setattr(KDiffusionSampler, "callback_state", orig_callback_state)
setattr(KDiffusionSampler, "callback_state", orig_callback_state_KDiffusionSampler)
setattr(TimestepsSampler, "callback_state", orig_callback_state_TimestepsSampler)

# Make video for last batch_count
make_video(p, ssii_is_active, ssii_final_save, ssii_intermediate_type, ssii_every_n, ssii_start_at_n, ssii_stop_at_n, ssii_mode, ssii_video_format, ssii_mp4_parms, ssii_video_fps, ssii_add_first_frames, ssii_add_last_frames, ssii_smooth, ssii_seconds, ssii_lores, ssii_hires, ssii_ffmpeg_bat, ssii_bat_only, ssii_debug)
Expand Down

0 comments on commit fba7d9e

Please sign in to comment.