diff --git a/ai_diffusion/comfy_workflow.py b/ai_diffusion/comfy_workflow.py index c514c9d55..bdb703c30 100644 --- a/ai_diffusion/comfy_workflow.py +++ b/ai_diffusion/comfy_workflow.py @@ -8,6 +8,7 @@ from .image import Bounds, Extent, Image from .resources import Arch, ControlMode +from .util import base_type_match class ComfyRunMode(Enum): @@ -40,21 +41,13 @@ def input(self, key: str, default: None = None) -> Input | None: ... def input(self, key: str, default: T | None = None) -> T | Input | None: result = self.inputs.get(key, default) - assert ( - default is None - or type(result) == type(default) - or (isnumber(result) and isnumber(default)) - ) + assert default is None or base_type_match(result, default) return result def output(self, index=0) -> Output: return Output(int(self.id), index) -def isnumber(x): - return isinstance(x, (int, float)) - - class ComfyWorkflow: """Builder for workflows which can be sent to the ComfyUI prompt API.""" diff --git a/ai_diffusion/custom_workflow.py b/ai_diffusion/custom_workflow.py index edb70b127..dd31d7895 100644 --- a/ai_diffusion/custom_workflow.py +++ b/ai_diffusion/custom_workflow.py @@ -16,7 +16,7 @@ from .jobs import Job, JobParams, JobQueue, JobKind from .properties import Property, ObservableProperties from .style import Styles -from .util import user_data_dir, client_logger as log +from .util import base_type_match, user_data_dir, client_logger as log from .ui import theme from . import eventloop @@ -492,7 +492,9 @@ def live_result(self): def _coerce(params: dict[str, Any], types: list[CustomParam]): def use(value, default): - if value is None or not type(value) == type(default): + if default is None: + return value + if value is None or not base_type_match(value, default): return default return value diff --git a/ai_diffusion/ui/custom_workflow.py b/ai_diffusion/ui/custom_workflow.py index dcbaa764e..f36bc5223 100644 --- a/ai_diffusion/ui/custom_workflow.py +++ b/ai_diffusion/ui/custom_workflow.py @@ -17,7 +17,7 @@ from ..root import root from ..settings import settings from ..localization import translate as _ -from ..util import ensure, clamp +from ..util import ensure, clamp, base_type_match from .generation import GenerateButton, ProgressBar, QueueButton, HistoryWidget, create_error_label from .live import LivePreviewArea from .switch import SwitchWidget @@ -51,8 +51,11 @@ def _update(self): assert False, f"Unknown filter: {self.filter}" for l in layers: - if self.findData(l.id) == -1: + index = self.findData(l.id) + if index == -1: self.addItem(l.name, l.id) + elif self.itemText(index) != l.name: + self.setItemText(index, l.name) i = 0 while i < self.count(): if self.itemData(i) not in (l.id for l in layers): @@ -116,8 +119,8 @@ def value(self): return self._widget.value() @value.setter - def value(self, value: int): - self._widget.setValue(value) + def value(self, value: int | float): + self._widget.setValue(int(value)) class FloatParamWidget(QWidget): @@ -164,11 +167,11 @@ def value(self): return self._widget.value() @value.setter - def value(self, value: float): + def value(self, value: float | int): if isinstance(self._widget, QSlider): self._widget.setValue(round(value * 100)) else: - self._widget.setValue(value) + self._widget.setValue(float(value)) class BoolParamWidget(QWidget): @@ -321,25 +324,27 @@ def value(self, value: str): def _create_param_widget(param: CustomParam, parent: QWidget) -> CustomParamWidget: - if param.kind is ParamKind.image_layer: - return LayerSelect("image", parent) - if param.kind is ParamKind.mask_layer: - return LayerSelect("mask", parent) - if param.kind is ParamKind.number_int: - return IntParamWidget(param, parent) - if param.kind is ParamKind.number_float: - return FloatParamWidget(param, parent) - if param.kind is ParamKind.toggle: - return BoolParamWidget(param, parent) - if param.kind is ParamKind.text: - return TextParamWidget(param, parent) - if param.kind in [ParamKind.prompt_positive, ParamKind.prompt_negative]: - return PromptParamWidget(param, parent) - if param.kind is ParamKind.choice: - return ChoiceParamWidget(param, parent) - if param.kind is ParamKind.style: - return StyleParamWidget(parent) - assert False, f"Unknown param kind: {param.kind}" + match param.kind: + case ParamKind.image_layer: + return LayerSelect("image", parent) + case ParamKind.mask_layer: + return LayerSelect("mask", parent) + case ParamKind.number_int: + return IntParamWidget(param, parent) + case ParamKind.number_float: + return FloatParamWidget(param, parent) + case ParamKind.toggle: + return BoolParamWidget(param, parent) + case ParamKind.text: + return TextParamWidget(param, parent) + case ParamKind.prompt_positive | ParamKind.prompt_negative: + return PromptParamWidget(param, parent) + case ParamKind.choice: + return ChoiceParamWidget(param, parent) + case ParamKind.style: + return StyleParamWidget(parent) + case _: + assert False, f"Unknown param kind: {param.kind}" class WorkflowParamsWidget(QWidget): @@ -375,7 +380,7 @@ def value(self): def value(self, values: dict[str, Any]): for name, value in values.items(): if widget := self._widgets.get(name): - if type(widget.value) == type(value): + if base_type_match(widget.value, value): widget.value = value diff --git a/ai_diffusion/util.py b/ai_diffusion/util.py index 5fcf91c26..0f4cd47b2 100644 --- a/ai_diffusion/util.py +++ b/ai_diffusion/util.py @@ -120,6 +120,14 @@ def median_or_zero(values: Iterable[float]) -> float: return 0 +def isnumber(x): + return isinstance(x, (int, float)) + + +def base_type_match(a, b): + return type(a) == type(b) or (isnumber(a) and isnumber(b)) + + def unique(seq: Sequence[T], key) -> list[T]: seen = set() return [x for x in seq if (k := key(x)) not in seen and not seen.add(k)]