Skip to content

Commit

Permalink
Fix compatibility with main
Browse files Browse the repository at this point in the history
  • Loading branch information
carson-katri committed Jul 7, 2023
1 parent 3450579 commit 74185e0
Show file tree
Hide file tree
Showing 9 changed files with 42 additions and 41 deletions.
7 changes: 7 additions & 0 deletions api/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,13 @@ def draw_memory_optimizations(self, layout, context):
def draw_extra(self, layout, context):
"""Draw additional UI in the panel"""
...

def get_batch_size(self, context) -> int:
"""Return the selected batch size for the backend (if applicable).
A default implementation is provided that returns `1`.
"""
return 1

def generate(
self,
Expand Down
2 changes: 1 addition & 1 deletion api/models/generation_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class GenerationResult:
"""

@staticmethod
def tile_images(results: list['GenerationResult']):
def tile_images(results: list['GenerationResult']) -> NDArray:
images = [result.image for result in results]
if len(images) == 0:
return None
Expand Down
3 changes: 3 additions & 0 deletions diffusers_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ def model_case(model, i):
def list_schedulers(self, context) -> List[str]:
return [scheduler.value for scheduler in Scheduler]

def get_batch_size(self, context) -> int:
return self.batch_size

def optimizations(self) -> Optimizations:
optimizations = Optimizations()
for prop in dir(self):
Expand Down
4 changes: 2 additions & 2 deletions engine/nodes/pipeline_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,9 @@ def execute(self, context, prompt, negative_prompt, width, height, steps, seed,
self.prompt.steps = steps
self.prompt.seed = str(seed)
self.prompt.cfg_scale = cfg_scale
args = self.prompt.generate_args()
args = self.prompt.generate_args(context)

shared_args = context.depsgraph.scene.dream_textures_engine_prompt.generate_args()
shared_args = context.depsgraph.scene.dream_textures_engine_prompt.generate_args(context)

# the source image is a default color, ignore it.
if np.array(source_image).shape == (4,):
Expand Down
22 changes: 8 additions & 14 deletions operators/dream_texture.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,22 +81,12 @@ def execute(self, context):
screen = context.screen
scene = context.scene

generated_args = scene.dream_textures_prompt.generate_args()
generated_args = scene.dream_textures_prompt.generate_args(context)
context.scene.seamless_result.update_args(generated_args)
context.scene.seamless_result.update_args(history_template, as_id=True)

init_image = None
if generated_args['use_init_img']:
init_image = get_source_image(context, generated_args['init_img_src'])
if init_image is not None:
init_image = np.flipud(
(np.array(init_image.pixels) * 255)
.astype(np.uint8)
.reshape((init_image.size[1], init_image.size[0], init_image.channels))
)

# Setup the progress indicator
bpy.types.Scene.dream_textures_progress = bpy.props.IntProperty(name="", default=0, min=0, max=generated_args['steps'])
bpy.types.Scene.dream_textures_progress = bpy.props.IntProperty(name="", default=0, min=0, max=generated_args.steps)
scene.dream_textures_info = "Starting..."

last_data_block = None
Expand All @@ -109,13 +99,14 @@ def step_callback(progress: List[api.GenerationResult]):
for region in area.regions:
if region.type == "UI":
region.tag_redraw()
image = api.GenerationResult.tile_images(progress)
last_data_block = bpy_image(f"Step {progress[-1].progress}/{progress[-1].total}", image.shape[1], image.shape[0], image.ravel(), last_data_block)
for area in screen.areas:
if area.type == 'IMAGE_EDITOR' and not area.spaces.active.use_image_pin:
area.spaces.active.image = last_data_block

iteration = 0
iteration_limit = len(file_batch_lines) if is_file_batch else generated_args['iterations']
iteration_limit = len(file_batch_lines) if is_file_batch else generated_args.iterations
iteration_square = math.ceil(math.sqrt(iteration_limit))
node_pad = np.array((20, 20))
node_size = np.array((240, 277)) + node_pad
Expand Down Expand Up @@ -177,9 +168,12 @@ def callback(results: List[api.GenerationResult] | Exception):
iteration += 1
if iteration < iteration_limit:
generate_next()
else:
scene.dream_textures_info = ""
scene.dream_textures_progress = 0

def generate_next():
backend.generate(**prompt.generate_args(context, iteration=iteration), step_callback=step_callback, callback=callback)
backend.generate(prompt.generate_args(context, iteration=iteration), step_callback=step_callback, callback=callback)

generate_next()

Expand Down
4 changes: 2 additions & 2 deletions operators/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def on_exception(_, exception):

context.scene.dream_textures_info = "Starting..."
if context.scene.dream_textures_project_use_control_net:
generated_args = context.scene.dream_textures_project_prompt.generate_args()
generated_args = context.scene.dream_textures_project_prompt.generate_args(context)
del generated_args['control']
future = gen.control_net(
control=[np.flipud(depth)], # the depth control needs to be flipped.
Expand All @@ -430,7 +430,7 @@ def on_exception(_, exception):
future = gen.depth_to_image(
depth=depth,
image=init_img_path,
**context.scene.dream_textures_project_prompt.generate_args()
**context.scene.dream_textures_project_prompt.generate_args(context)
)
gen._active_generation_future = future
future.call_done_on_exception = False
Expand Down
2 changes: 1 addition & 1 deletion operators/upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def step_progress_update(self, context):
.reshape((input_image.size[1], input_image.size[0], input_image.channels))
)

generated_args = context.scene.dream_textures_upscale_prompt.generate_args()
generated_args = context.scene.dream_textures_upscale_prompt.generate_args(context)
context.scene.dream_textures_upscale_seamless_result.update_args(generated_args)

# Setup the progress indicator
Expand Down
18 changes: 4 additions & 14 deletions property_groups/dream_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,27 +206,18 @@ def get_seed(self):
h = ~h
return (h & 0xFFFFFFFF) ^ (h >> 32) # 64 bit hash down to 32 bits

def get_optimizations(self: DreamPrompt):
optimizations = Optimizations()
for prop in dir(self):
split_name = prop.replace('optimizations_', '')
if prop.startswith('optimizations_') and hasattr(optimizations, split_name):
setattr(optimizations, split_name, getattr(self, prop))
if self.optimizations_attention_slice_size_src == 'auto':
optimizations.attention_slice_size = 'auto'
return optimizations

def generate_args(self, context, iteration=0) -> api.GenerationArguments:
is_file_batch = self.prompt_structure == file_batch_structure.id
file_batch_lines = []
file_batch_lines_negative = []
if is_file_batch:
file_batch_lines = [line.body for line in context.scene.dream_textures_prompt_file.lines if len(line.body.strip()) > 0]
file_batch_lines_negative = [""] * len(file_batch_lines)

optim: Optimizations = self.get_optimizations()

backend: api.Backend = self.get_backend()
batch_size = backend.get_batch_size(context)
iteration_limit = len(file_batch_lines) if is_file_batch else self.iterations
batch_size = min(optim.batch_size, iteration_limit-iteration)
batch_size = min(batch_size, iteration_limit-iteration)

task: api.Task = api.PromptToImage()
if self.use_init_img:
Expand Down Expand Up @@ -320,7 +311,6 @@ def get_backend(self) -> api.Backend:
DreamPrompt.generate_prompt = generate_prompt
DreamPrompt.get_prompt_subject = get_prompt_subject
DreamPrompt.get_seed = get_seed
DreamPrompt.get_optimizations = get_optimizations
DreamPrompt.generate_args = generate_args
DreamPrompt.validate = validate
DreamPrompt.get_backend = get_backend
21 changes: 14 additions & 7 deletions property_groups/seamless_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ..generator_process.actions.detect_seamless import SeamlessAxes
from ..generator_process import Generator
from ..preferences import StableDiffusionPreferences

from ..api.models import GenerationArguments

def update(self, context):
if hasattr(context.area, "regions"):
Expand Down Expand Up @@ -52,9 +52,16 @@ def result(future):
self.result = future.result().text
Generator.shared().detect_seamless(pixels).add_done_callback(result)

def update_args(self, args: dict[str, any], as_id=False):
if args['seamless_axes'] == SeamlessAxes.AUTO and self.result != 'Processing':
if as_id:
args['seamless_axes'] = SeamlessAxes(self.result).id
else:
args['seamless_axes'] = SeamlessAxes(self.result)
def update_args(self, args, as_id=False):
if isinstance(args, GenerationArguments):
if args.seamless_axes == SeamlessAxes.AUTO and self.result != 'Processing':
if as_id:
args.seamless_axes = SeamlessAxes(self.result).id
else:
args.seamless_axes = SeamlessAxes(self.result)
else:
if args['seamless_axes'] == SeamlessAxes.AUTO and self.result != 'Processing':
if as_id:
args['seamless_axes'] = SeamlessAxes(self.result).id
else:
args['seamless_axes'] = SeamlessAxes(self.result)

0 comments on commit 74185e0

Please sign in to comment.