Skip to content

Commit

Permalink
fix hires and corrections with batch processing
Browse files Browse the repository at this point in the history
Signed-off-by: Vladimir Mandic <[email protected]>
  • Loading branch information
vladmandic committed Dec 31, 2024
1 parent 86ac38d commit 05d5ac0
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 49 deletions.
8 changes: 6 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Change Log for SD.Next

## Update for 2024-12-30
## Update for 2024-12-31

NYE refresh release with quite a few optimizations and bug fixes...

- **LoRA**:
- **Sana** support
Expand Down Expand Up @@ -40,7 +42,9 @@
- do not show disabled networks
- enable debug logging by default
- image width/height calculation when doing img2img

- corrections with batch processing
- hires with refiner prompt and batch processing

## Update for 2024-12-24

### Highlights for 2024-12-24
Expand Down
13 changes: 7 additions & 6 deletions modules/processing_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,16 +116,18 @@ def __init__(self,
# overrides
override_settings: Dict[str, Any] = {},
override_settings_restore_afterwards: bool = True,
task_args: Dict[str, Any] = {},
ops: List[str] = [],
# metadata
extra_generation_params: Dict[Any, Any] = {},
):

# extra args set by processing loop
self.task_args = {}
self.task_args = task_args

# state items
self.state: str = ''
self.ops = []
self.ops = ops
self.skip = []
self.color_corrections = []
self.is_control = False
Expand Down Expand Up @@ -266,7 +268,6 @@ def __init__(self,
self.s_max = shared.opts.s_max
self.s_tmin = shared.opts.s_tmin
self.s_tmax = float('inf') # not representable as a standard ui option
self.task_args = {}

# ip adapter
self.ip_adapter_names = []
Expand Down Expand Up @@ -299,6 +300,9 @@ def __init__(self,
self.negative_embeds = []
self.negative_pooleds = []

def __str__(self):
return f'{self.__class__.__name__}: {self.__dict__}'

@property
def sd_model(self):
return shared.sd_model
Expand Down Expand Up @@ -339,9 +343,6 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs
def close(self):
self.sampler = None # pylint: disable=attribute-defined-outside-init

def __str__(self):
return f'{self.__class__.__name__}: {self.__dict__}'


class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def __init__(self, **kwargs):
Expand Down
48 changes: 18 additions & 30 deletions modules/processing_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

def sharpen_tensor(tensor, ratio=0):
if ratio == 0:
debug("Sharpen: Early exit")
# debug("Sharpen: Early exit")
return tensor
kernel = torch.ones((3, 3), dtype=tensor.dtype, device=tensor.device)
kernel[1, 1] = 5.0
Expand All @@ -42,18 +42,18 @@ def soft_clamp_tensor(tensor, threshold=0.8, boundary=4):
min_replace = ((tensor + threshold) / (min_vals + threshold)) * (-boundary + threshold) - threshold
under_mask = tensor < -threshold
tensor = torch.where(over_mask, max_replace, torch.where(under_mask, min_replace, tensor))
debug(f'HDR soft clamp: threshold={threshold} boundary={boundary} shape={tensor.shape}')
# debug(f'HDR soft clamp: threshold={threshold} boundary={boundary} shape={tensor.shape}')
return tensor


def center_tensor(tensor, channel_shift=0.0, full_shift=0.0, offset=0.0):
if channel_shift == 0 and full_shift == 0 and offset == 0:
return tensor
debug(f'HDR center: Before Adjustment: Full mean={tensor.mean().item()} Channel means={tensor.mean(dim=(-1, -2)).float().cpu().numpy()}')
# debug(f'HDR center: Before Adjustment: Full mean={tensor.mean().item()} Channel means={tensor.mean(dim=(-1, -2)).float().cpu().numpy()}')
tensor -= tensor.mean(dim=(-1, -2), keepdim=True) * channel_shift
tensor -= tensor.mean() * full_shift - offset
debug(f'HDR center: channel-shift={channel_shift} full-shift={full_shift}')
debug(f'HDR center: After Adjustment: Full mean={tensor.mean().item()} Channel means={tensor.mean(dim=(-1, -2)).float().cpu().numpy()}')
# debug(f'HDR center: channel-shift={channel_shift} full-shift={full_shift}')
# debug(f'HDR center: After Adjustment: Full mean={tensor.mean().item()} Channel means={tensor.mean(dim=(-1, -2)).float().cpu().numpy()}')
return tensor


Expand All @@ -65,7 +65,7 @@ def maximize_tensor(tensor, boundary=1.0):
max_val = tensor.max()
normalization_factor = boundary / max(abs(min_val), abs(max_val))
tensor *= normalization_factor
debug(f'HDR maximize: boundary={boundary} min={min_val} max={max_val} factor={normalization_factor}')
# debug(f'HDR maximize: boundary={boundary} min={min_val} max={max_val} factor={normalization_factor}')
return tensor


Expand All @@ -78,43 +78,34 @@ def get_color(colorstr):

def color_adjust(tensor, colorstr, ratio):
color = get_color(colorstr)
debug(f'HDR tint: str={colorstr} color={color} ratio={ratio}')
# debug(f'HDR tint: str={colorstr} color={color} ratio={ratio}')
for i in range(3):
tensor[i] = center_tensor(tensor[i], full_shift=1, offset=color[i]*(ratio/2))
return tensor


def correction(p, timestep, latent):
if timestep > 950 and p.hdr_clamp:
p.extra_generation_params["HDR clamp"] = f'{p.hdr_threshold}/{p.hdr_boundary}'
latent = soft_clamp_tensor(latent, threshold=p.hdr_threshold, boundary=p.hdr_boundary)
if 600 < timestep < 900 and (p.hdr_color != 0 or p.hdr_tint_ratio != 0):
if p.hdr_brightness != 0:
latent[0:1] = center_tensor(latent[0:1], full_shift=float(p.hdr_mode), offset=2*p.hdr_brightness) # Brightness
p.extra_generation_params["HDR brightness"] = f'{p.hdr_brightness}'
p.hdr_brightness = 0
if p.hdr_color != 0:
latent[1:] = center_tensor(latent[1:], channel_shift=p.hdr_color, full_shift=float(p.hdr_mode)) # Color
p.extra_generation_params["HDR color"] = f'{p.hdr_color}'
p.hdr_color = 0
if p.hdr_tint_ratio != 0:
latent = color_adjust(latent, p.hdr_color_picker, p.hdr_tint_ratio)
p.hdr_tint_ratio = 0
p.extra_generation_params["HDR clamp"] = f'{p.hdr_threshold}/{p.hdr_boundary}'
if 600 < timestep < 900 and p.hdr_color != 0:
latent[1:] = center_tensor(latent[1:], channel_shift=p.hdr_color, full_shift=float(p.hdr_mode)) # Color
p.extra_generation_params["HDR color"] = f'{p.hdr_color}'
if 600 < timestep < 900 and p.hdr_tint_ratio != 0:
latent = color_adjust(latent, p.hdr_color_picker, p.hdr_tint_ratio)
p.extra_generation_params["HDR tint"] = f'{p.hdr_tint_ratio}'
if timestep < 200 and (p.hdr_brightness != 0): # do it late so it doesn't change the composition
if p.hdr_brightness != 0:
latent[0:1] = center_tensor(latent[0:1], full_shift=float(p.hdr_mode), offset=2*p.hdr_brightness) # Brightness
p.extra_generation_params["HDR brightness"] = f'{p.hdr_brightness}'
p.hdr_brightness = 0
latent[0:1] = center_tensor(latent[0:1], full_shift=float(p.hdr_mode), offset=p.hdr_brightness) # Brightness
p.extra_generation_params["HDR brightness"] = f'{p.hdr_brightness}'
if timestep < 350 and p.hdr_sharpen != 0:
p.extra_generation_params["HDR sharpen"] = f'{p.hdr_sharpen}'
per_step_ratio = 2 ** (timestep / 250) * p.hdr_sharpen / 16
if abs(per_step_ratio) > 0.01:
debug(f"HDR Sharpen: timestep={timestep} ratio={p.hdr_sharpen} val={per_step_ratio}")
latent = sharpen_tensor(latent, ratio=per_step_ratio)
p.extra_generation_params["HDR sharpen"] = f'{p.hdr_sharpen}'
if 1 < timestep < 100 and p.hdr_maximize:
p.extra_generation_params["HDR max"] = f'{p.hdr_max_center}/{p.hdr_max_boundry}'
latent = center_tensor(latent, channel_shift=p.hdr_max_center, full_shift=1.0)
latent = maximize_tensor(latent, boundary=p.hdr_max_boundry)
p.extra_generation_params["HDR max"] = f'{p.hdr_max_center}/{p.hdr_max_boundry}'
return latent


Expand All @@ -129,9 +120,6 @@ def correction_callback(p, timestep, kwargs, initial: bool = False):
elif skip_correction:
return kwargs
latents = kwargs["latents"]
if debug_enabled:
debug('')
debug(f' Timestep: {timestep}')
# debug(f'HDR correction: latents={latents.shape}')
if len(latents.shape) == 4: # standard batched latent
for i in range(latents.shape[0]):
Expand Down
8 changes: 4 additions & 4 deletions modules/processing_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,10 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
hires_args = set_pipeline_args(
p=p,
model=shared.sd_model,
prompts=[p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts,
negative_prompts=[p.refiner_negative] if len(p.refiner_negative) > 0 else p.negative_prompts,
prompts_2=[p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts,
negative_prompts_2=[p.refiner_negative] if len(p.refiner_negative) > 0 else p.negative_prompts,
prompts=len(output.images)* [p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts,
negative_prompts=len(output.images) * [p.refiner_negative] if len(p.refiner_negative) > 0 else p.negative_prompts,
prompts_2=len(output.images) * [p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts,
negative_prompts_2=len(output.images) * [p.refiner_negative] if len(p.refiner_negative) > 0 else p.negative_prompts,
num_inference_steps=calculate_hires_steps(p),
eta=shared.opts.scheduler_eta,
guidance_scale=p.image_cfg_scale if p.image_cfg_scale is not None else p.cfg_scale,
Expand Down
12 changes: 6 additions & 6 deletions scripts/freescale.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,16 @@ def ui(self, _is_img2img): # ui elements
s1_restart = gr.Slider(minimum=0, maximum=1.0, value=0.75, label='Restart step')
with gr.Row():
s2_enable = gr.Checkbox(value=True, label='2nd Stage')
s2_scale = gr.Slider(minimum=1, maximum=8.0, value=2.0, label='Scale')
s2_restart = gr.Slider(minimum=0, maximum=1.0, value=0.75, label='Restart step')
s2_scale = gr.Slider(minimum=1, maximum=8.0, value=2.0, label='2nd Scale')
s2_restart = gr.Slider(minimum=0, maximum=1.0, value=0.75, label='2nd Restart step')
with gr.Row():
s3_enable = gr.Checkbox(value=False, label='3rd Stage')
s3_scale = gr.Slider(minimum=1, maximum=8.0, value=3.0, label='Scale')
s3_restart = gr.Slider(minimum=0, maximum=1.0, value=0.75, label='Restart step')
s3_scale = gr.Slider(minimum=1, maximum=8.0, value=3.0, label='3rd Scale')
s3_restart = gr.Slider(minimum=0, maximum=1.0, value=0.75, label='3rd Restart step')
with gr.Row():
s4_enable = gr.Checkbox(value=False, label='4th Stage')
s4_scale = gr.Slider(minimum=1, maximum=8.0, value=4.0, label='Scale')
s4_restart = gr.Slider(minimum=0, maximum=1.0, value=0.75, label='Restart step')
s4_scale = gr.Slider(minimum=1, maximum=8.0, value=4.0, label='4th Scale')
s4_restart = gr.Slider(minimum=0, maximum=1.0, value=0.75, label='4th Restart step')
return [cosine_scale, override_sampler, cosine_scale_bg, dilate_tau, s1_enable, s1_scale, s1_restart, s2_enable, s2_scale, s2_restart, s3_enable, s3_scale, s3_restart, s4_enable, s4_scale, s4_restart]

def run(self, p: processing.StableDiffusionProcessing, cosine_scale, override_sampler, cosine_scale_bg, dilate_tau, s1_enable, s1_scale, s1_restart, s2_enable, s2_scale, s2_restart, s3_enable, s3_scale, s3_restart, s4_enable, s4_scale, s4_restart): # pylint: disable=arguments-differ
Expand Down
2 changes: 1 addition & 1 deletion webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def start_ui():
shared.log.info(f'API ReDocs: {local_url[:-1]}/redocs') # pylint: disable=unsubscriptable-object
if share_url is not None:
shared.log.info(f'Share URL: {share_url}')
shared.log.debug(f'Gradio functions: registered={len(shared.demo.fns)}')
# shared.log.debug(f'Gradio functions: registered={len(shared.demo.fns)}')
shared.demo.server.wants_restart = False
setup_middleware(app, cmd_opts)

Expand Down

0 comments on commit 05d5ac0

Please sign in to comment.