From b9a1d8af9e14e645e43e95504240374f46f89c60 Mon Sep 17 00:00:00 2001 From: Acly Date: Fri, 25 Oct 2024 13:14:27 +0200 Subject: [PATCH] Send text from custom workflows to Krita #1285 --- ai_diffusion/client.py | 23 +++++-- ai_diffusion/comfy_client.py | 37 +++++++++-- ai_diffusion/custom_workflow.py | 17 +++++ ai_diffusion/model.py | 6 +- ai_diffusion/ui/custom_workflow.py | 103 ++++++++++++++++++++++++++++- tests/test_custom_workflow.py | 48 +++++++++++++- 6 files changed, 217 insertions(+), 17 deletions(-) diff --git a/ai_diffusion/client.py b/ai_diffusion/client.py index 395ceee2d..3ce36b82a 100644 --- a/ai_diffusion/client.py +++ b/ai_diffusion/client.py @@ -26,6 +26,22 @@ class ClientEvent(Enum): queued = 6 upload = 7 published = 8 + output = 9 + + +class TextOutput(NamedTuple): + key: str + name: str + text: str + mime: str + + +class SharedWorkflow(NamedTuple): + publisher: str + workflow: dict + + +ClientOutput = dict | SharedWorkflow | TextOutput class ClientMessage(NamedTuple): @@ -33,7 +49,7 @@ class ClientMessage(NamedTuple): job_id: str = "" progress: float = 0 images: ImageCollection | None = None - result: dict | SharedWorkflow | None = None + result: ClientOutput | None = None error: str | None = None @@ -69,11 +85,6 @@ def parse(data: dict): return DeviceInfo("cpu", "unknown", 0) -class SharedWorkflow(NamedTuple): - publisher: str - workflow: dict - - class CheckpointInfo(NamedTuple): filename: str arch: Arch diff --git a/ai_diffusion/comfy_client.py b/ai_diffusion/comfy_client.py index d8927207a..0516618dc 100644 --- a/ai_diffusion/comfy_client.py +++ b/ai_diffusion/comfy_client.py @@ -11,7 +11,7 @@ from .api import WorkflowInput from .client import Client, CheckpointInfo, ClientMessage, ClientEvent, DeviceInfo, ClientModels -from .client import SharedWorkflow, TranslationPackage, ClientFeatures +from .client import SharedWorkflow, TranslationPackage, ClientFeatures, TextOutput from .client import filter_supported_styles, loras_to_upload from .files import FileFormat from .image import Image, ImageCollection @@ -308,10 +308,13 @@ async def _listen_websocket(self, websocket: websockets_client.WebSocketClientPr log.error(f"Received message {msg} but there is no active job") if msg["type"] == "executed": - job = self._get_active_job(msg["data"]["prompt_id"]) - pose_json = _extract_pose_json(msg) - if job and pose_json: - result = pose_json + if job := self._get_active_job(msg["data"]["prompt_id"]): + text_output = _extract_text_output(job.local_id, msg) + if text_output is not None: + await self._messages.put(text_output) + pose_json = _extract_pose_json(msg) + if pose_json is not None: + result = pose_json if msg["type"] == "execution_error": job = self._get_active_job(msg["data"]["prompt_id"]) @@ -733,3 +736,27 @@ def _extract_pose_json(msg: dict): except Exception as e: log.warning(f"Error processing message, error={str(e)}, msg={msg}") return None + + +def _extract_text_output(job_id: str, msg: dict): + try: + output = msg["data"]["output"] + if output is not None and "text" in output: + key = msg["data"].get("node") + payload = output["text"] + name, text, mime = (None, None, "text/plain") + if isinstance(payload, list) and len(payload) >= 1: + payload = payload[0] + if isinstance(payload, dict): + text = payload.get("text") + name = payload.get("name") + mime = payload.get("content-type", mime) + elif isinstance(payload, str): + text = payload + name = f"Node {key}" + if text is not None and name is not None: + result = TextOutput(key, name, text, mime) + return ClientMessage(ClientEvent.output, job_id, result=result) + except Exception as e: + log.warning(f"Error processing message, error={str(e)}, msg={msg}") + return None diff --git a/ai_diffusion/custom_workflow.py b/ai_diffusion/custom_workflow.py index dd31d7895..7b9f5a50f 100644 --- a/ai_diffusion/custom_workflow.py +++ b/ai_diffusion/custom_workflow.py @@ -10,6 +10,7 @@ from PyQt5.QtCore import pyqtSignal from .api import WorkflowInput +from .client import TextOutput, ClientOutput from .comfy_workflow import ComfyWorkflow, ComfyNode from .connection import Connection, ConnectionState from .image import Bounds, Image @@ -323,6 +324,7 @@ class CustomWorkspace(QObject, ObservableProperties): mode = Property(CustomGenerationMode.regular, setter="_set_mode") is_live = Property(False, setter="toggle_live") has_result = Property(False) + outputs = Property({}) workflow_id_changed = pyqtSignal(str) graph_changed = pyqtSignal() @@ -331,6 +333,7 @@ class CustomWorkspace(QObject, ObservableProperties): is_live_changed = pyqtSignal(bool) result_available = pyqtSignal(Image) has_result_changed = pyqtSignal(bool) + outputs_changed = pyqtSignal(dict) modified = pyqtSignal(QObject, str) _live_poll_rate = 0.1 @@ -345,6 +348,7 @@ def __init__(self, workflows: WorkflowCollection, generator: ImageGenerator, job self._last_input: WorkflowInput | None = None self._last_result: Image | None = None self._last_job: JobParams | None = None + self._new_outputs: list[str] = [] jobs.job_finished.connect(self._handle_job_finished) workflows.dataChanged.connect(self._update_workflow) @@ -463,7 +467,20 @@ def collect_parameters(self, layers: "LayerManager", bounds: Bounds): return params + def show_output(self, output: ClientOutput | None): + if isinstance(output, TextOutput): + self._new_outputs.append(output.key) + self.outputs[output.key] = output + self.outputs_changed.emit(self.outputs) + def _handle_job_finished(self, job: Job): + to_remove = [k for k in self.outputs.keys() if k not in self._new_outputs] + for key in to_remove: + del self.outputs[key] + if len(to_remove) > 0: + self.outputs_changed.emit(self.outputs) + self._new_outputs.clear() + if job.kind is JobKind.live_preview: if len(job.results) > 0: self._last_result = job.results[0] diff --git a/ai_diffusion/model.py b/ai_diffusion/model.py index ab6ed8a66..ae9c01f5b 100644 --- a/ai_diffusion/model.py +++ b/ai_diffusion/model.py @@ -17,7 +17,7 @@ from .settings import ApplyBehavior, settings from .network import NetworkError from .image import Extent, Image, Mask, Bounds, DummyImage -from .client import ClientMessage, ClientEvent, SharedWorkflow +from .client import ClientMessage, ClientEvent, ClientOutput from .client import filter_supported_styles, resolve_arch from .custom_workflow import CustomWorkspace, WorkflowCollection, CustomGenerationMode from .document import Document, KritaDocument @@ -461,6 +461,8 @@ def handle_message(self, message: ClientMessage): self.jobs.notify_started(job) self.progress_kind = ProgressKind.upload self.progress = message.progress + elif message.event is ClientEvent.output: + self.custom.show_output(message.result) elif message.event is ClientEvent.finished: if message.images: self.jobs.set_results(job, message.images) @@ -604,7 +606,7 @@ def apply_generated_result(self, job_id: str, index: int): self.jobs.selection = None self.jobs.notify_used(job_id, index) - def add_control_layer(self, job: Job, result: dict | SharedWorkflow | None): + def add_control_layer(self, job: Job, result: ClientOutput | None): assert job.kind is JobKind.control_layer and job.control if job.control.mode is ControlMode.pose and isinstance(result, (dict, list)): pose = Pose.from_open_pose_json(result) diff --git a/ai_diffusion/ui/custom_workflow.py b/ai_diffusion/ui/custom_workflow.py index f36bc5223..0dfe4b6a8 100644 --- a/ai_diffusion/ui/custom_workflow.py +++ b/ai_diffusion/ui/custom_workflow.py @@ -7,9 +7,11 @@ from PyQt5.QtWidgets import QComboBox, QFileDialog, QFrame, QGridLayout, QHBoxLayout, QMenu from PyQt5.QtWidgets import QLabel, QLineEdit, QListWidgetItem, QMessageBox, QSpinBox, QAction from PyQt5.QtWidgets import QToolButton, QVBoxLayout, QWidget, QSlider, QDoubleSpinBox +from PyQt5.QtWidgets import QScrollArea, QTextEdit, QSizePolicy from ..custom_workflow import CustomParam, ParamKind, SortedWorkflows, WorkflowSource from ..custom_workflow import CustomGenerationMode +from ..client import TextOutput from ..jobs import JobKind from ..model import Model, ApplyBehavior from ..properties import Binding, Bind, bind, bind_combo @@ -22,6 +24,7 @@ from .live import LivePreviewArea from .switch import SwitchWidget from .widget import TextPromptWidget, WorkspaceSelectWidget, StyleSelectWidget +from .settings_widgets import ExpanderButton from . import theme @@ -384,6 +387,93 @@ def value(self, values: dict[str, Any]): widget.value = value +class WorkflowOutputsWidget(QWidget): + def __init__(self, parent: QWidget): + super().__init__(parent) + self._value: dict[str, TextOutput] = {} + + self._scroll_area = QScrollArea(self) + self._scroll_area.setFrameShape(QFrame.Shape.NoFrame) + self._scroll_area.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) + + self.expander = ExpanderButton(_("Text Output"), self) + self.expander.setStyleSheet("QToolButton { border: none; }") + self.expander.setChecked(True) + self.expander.toggled.connect(self._scroll_area.setVisible) + + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self.expander) + layout.addWidget(self._scroll_area) + + @property + def value(self): + return self._value + + @value.setter + def value(self, value: dict[str, TextOutput]): + self._value = value + self._update() + + def _update(self): + if len(self._value) == 0: + self.expander.hide() + self._scroll_area.hide() + return + elif not self.expander.isVisible(): + self.expander.show() + self._scroll_area.show() + + widget = QWidget(self._scroll_area) + layout = QGridLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.setColumnMinimumWidth(1, 8) + layout.setColumnStretch(2, 1) + widget.setLayout(layout) + + line = 0 + text_areas: list[QTextEdit] = [] + for output in self._value.values(): + label = QLabel(output.name, widget) + if (not output.mime or output.mime == "text/plain") and len(output.text) < 40: + value = QLabel(output.text, widget) + value.setWordWrap(True) + value.setMinimumWidth(40) + layout.addWidget(label, line, 0) + layout.addWidget(value, line, 2) + line += 1 + else: + value = QTextEdit(widget) + value.setFrameShape(QFrame.Shape.StyledPanel) + value.setStyleSheet( + "QTextEdit { background: transparent; border-left: 1px solid %s; padding-left: 2px; }" + % theme.line + ) + value.setReadOnly(True) + match output.mime: + case "" | "text/plain": + value.setPlainText(output.text) + case "text/html": + value.setHtml(output.text) + case "text/markdown": + value.setMarkdown(output.text) + layout.addWidget(label, line, 0, 1, 3) + layout.addWidget(value, line + 1, 0, 1, 3) + text_areas.append(value) + line += 2 + + layout.setRowStretch(line, 1) + widget.setFixedWidth(self._scroll_area.width() - 8) + self._scroll_area.setWidget(widget) + if self.expander.isChecked(): + widget.show() + + for w in text_areas: + size = ensure(w.document()).size().toSize() + w.setFixedHeight(max(size.height() + 2, self.fontMetrics().height() + 6)) + widget.adjustSize() + + def popup_on_error(func): @wraps(func) def wrapper(self, *args, **kwargs): @@ -487,6 +577,9 @@ def __init__(self): self._progress_bar = ProgressBar(self) self._error_text = create_error_label(self) + self._outputs = WorkflowOutputsWidget(self) + self._outputs.expander.toggled.connect(self._update_layout) + self._history = HistoryWidget(self) self._history.item_activated.connect(self.apply_result) @@ -525,12 +618,17 @@ def __init__(self): self._layout.addLayout(actions_layout) self._layout.addWidget(self._progress_bar) self._layout.addWidget(self._error_text) - self._layout.addWidget(self._history) - self._layout.addWidget(self._live_preview) + self._layout.addWidget(self._outputs, stretch=1) + self._layout.addWidget(self._history, stretch=3) + self._layout.addWidget(self._live_preview, stretch=5) self.setLayout(self._layout) self._update_ui() + def _update_layout(self): + stretch = 1 if self._outputs.expander.isChecked() else 0 + self._layout.setStretchFactor(self._outputs, stretch) + @property def model(self): return self._model @@ -543,6 +641,7 @@ def model(self, model: Model): self._model_bindings = [ bind(model, "workspace", self._workspace_select, "value", Bind.one_way), bind_combo(model.custom, "workflow_id", self._workflow_select, Bind.one_way), + bind(model.custom, "outputs", self._outputs, "value", Bind.one_way), model.workspace_changed.connect(self._cancel_name), model.custom.graph_changed.connect(self._update_current_workflow), model.error_changed.connect(self._error_text.setText), diff --git a/tests/test_custom_workflow.py b/tests/test_custom_workflow.py index ebf1693d8..3a5166e38 100644 --- a/tests/test_custom_workflow.py +++ b/tests/test_custom_workflow.py @@ -1,19 +1,21 @@ import json import pytest +from copy import copy from pathlib import Path from PyQt5.QtCore import Qt from ai_diffusion.api import CustomWorkflowInput, ImageInput, WorkflowInput -from ai_diffusion.client import Client, ClientModels, CheckpointInfo +from ai_diffusion.client import Client, ClientModels, CheckpointInfo, TextOutput from ai_diffusion.connection import Connection, ConnectionState from ai_diffusion.comfy_workflow import ComfyNode, ComfyWorkflow, Output from ai_diffusion.custom_workflow import CustomWorkflow, WorkflowSource, WorkflowCollection from ai_diffusion.custom_workflow import SortedWorkflows, CustomWorkspace from ai_diffusion.custom_workflow import CustomParam, ParamKind, workflow_parameters from ai_diffusion.image import Image, Extent -from ai_diffusion.jobs import JobQueue +from ai_diffusion.jobs import JobQueue, Job, JobKind, JobParams from ai_diffusion.style import Style from ai_diffusion.resources import Arch +from ai_diffusion.image import Bounds from ai_diffusion import workflow from .config import test_dir @@ -289,6 +291,48 @@ def test_parameters(): ] +def test_text_output(): + connection_workflows = {"connection1": make_dummy_graph(42)} + connection = create_mock_connection(connection_workflows, {}) + workflows = WorkflowCollection(connection) + + output_events = [] + + def on_output(outputs: dict): + output_events.append(copy(outputs)) + + text_messages = [ + TextOutput("1", "Food", "Dumpling", "text/plain"), + TextOutput("2", "Drink", "Tea", "text/plain"), + TextOutput("3", "Time", "Moonrise", "text/plain"), + TextOutput("1", "Food", "Sweet Potato", "text/plain"), + ] + + jobs = JobQueue() + workspace = CustomWorkspace(workflows, dummy_generate, jobs) + workspace.outputs_changed.connect(on_output) + workspace.show_output(text_messages[0]) + workspace.show_output(text_messages[1]) + assert workspace.outputs == {"1": text_messages[0], "2": text_messages[1]} + + job_params = JobParams(Bounds(0, 0, 1, 1), "test") + jobs.job_finished.emit(Job("job1", JobKind.diffusion, job_params)) + + workspace.show_output(text_messages[3]) + workspace.show_output(text_messages[2]) + jobs.job_finished.emit(Job("job2", JobKind.diffusion, job_params)) + assert workspace.outputs == {"1": text_messages[3], "3": text_messages[2]} + + assert output_events == [ + {"1": text_messages[0]}, # show_output(0) + {"1": text_messages[0], "2": text_messages[1]}, # show_output(1) + # job_finished(job1) - no changes + {"1": text_messages[3], "2": text_messages[1]}, # show_output(3) + {"1": text_messages[3], "2": text_messages[1], "3": text_messages[2]}, # show_output(2) + {"1": text_messages[3], "3": text_messages[2]}, # job_finished(job2) + ] + + def test_expand(): ext = ComfyWorkflow() in_img, width, height, seed = ext.add("ETN_KritaCanvas", 4)