Skip to content

Commit

Permalink
Fix issue 1137: Query ComfyUI queue state and use it to display count…
Browse files Browse the repository at this point in the history
… of jobs in the queue before ours.
  • Loading branch information
FeepingCreature committed Sep 10, 2024
1 parent f873f86 commit 57d9ba3
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 5 deletions.
2 changes: 2 additions & 0 deletions ai_diffusion/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class ClientMessage(NamedTuple):
images: ImageCollection | None = None
result: dict | None = None
error: str | None = None
# jobs queued before our next one
queue_length: int | None = None


class User(QObject, ObservableProperties):
Expand Down
28 changes: 26 additions & 2 deletions ai_diffusion/comfy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from enum import Enum
from collections import deque
from itertools import chain, product
from operator import itemgetter
from typing import NamedTuple, Optional, Sequence

from .api import WorkflowInput
Expand Down Expand Up @@ -264,6 +265,7 @@ async def _listen_websocket(self, websocket: websockets_client.WebSocketClientPr

if msg["type"] == "status":
await self._report(ClientEvent.connected, "")
await self._poll_server_queue()

if msg["type"] == "execution_start":
id = msg["data"]["prompt_id"]
Expand Down Expand Up @@ -345,6 +347,28 @@ async def clear_queue(self):

self._jobs.clear()

async def _poll_server_queue(self):
queue = await self._get("api/queue")
server_jobs = queue["queue_running"] + queue["queue_pending"]
# why are they unsorted to start with...?
server_jobs = sorted(server_jobs, key=itemgetter(0))
server_jobs = [entry[1] for entry in server_jobs]
if not (self._jobs or self._active):
return

if self._active:
first_job = self._active
else:
first_job = self._jobs[0]
# as we got the job from `_jobs` or `_active`, this field must have been set (in `_run_job`).
first_remote_id = util.ensure(await first_job.get_remote_id())
try:
offset = server_jobs.index(first_remote_id)
except ValueError:
# probably just haven't gotten the notification yet
return
await self._report(ClientEvent.queued, first_job.local_id, queue_length=offset)

async def disconnect(self):
if self._is_connected:
self._is_connected = False
Expand Down Expand Up @@ -445,7 +469,7 @@ def _get_active_job(self, remote_id: str) -> Optional[JobInfo]:
return self._active
elif self._active:
log.warning(f"Received message for job {remote_id}, but job {self._active} is active")
if len(self._jobs) == 0:
if not self._jobs:
log.warning(f"Received unknown job {remote_id}")
return None
active = next((j for j in self._jobs if j.remote_id == remote_id), None)
Expand All @@ -456,7 +480,7 @@ def _get_active_job(self, remote_id: str) -> Optional[JobInfo]:
async def _start_job(self, remote_id: str):
if self._active is not None:
log.warning(f"Started job {remote_id}, but {self._active} was never finished")
if len(self._jobs) == 0:
if not self._jobs:
log.warning(f"Received unknown job {remote_id}")
return None

Expand Down
12 changes: 9 additions & 3 deletions ai_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class Model(QObject, ObservableProperties):
progress = Property(0.0)
jobs: JobQueue
error = Property("")
queue_length: int = 0

workspace_changed = pyqtSignal(Workspace)
style_changed = pyqtSignal(Style)
Expand All @@ -87,6 +88,7 @@ class Model(QObject, ObservableProperties):
error_changed = pyqtSignal(str)
has_error_changed = pyqtSignal(bool)
modified = pyqtSignal(QObject, str)
queue_length_changed = pyqtSignal(int)

def __init__(self, document: Document, connection: Connection):
super().__init__()
Expand Down Expand Up @@ -408,9 +410,13 @@ def handle_message(self, message: ClientMessage):
return

if message.event is ClientEvent.queued:
self.jobs.notify_started(job)
self.progress = -1
self.progress_changed.emit(-1)
if message.queue_length is not None:
self.queue_length = message.queue_length
self.queue_length_changed.emit(message.queue_length)
if message.queue_length is None or message.queue_length == 0:
self.jobs.notify_started(job)
self.progress = -1
self.progress_changed.emit(-1)
elif message.event is ClientEvent.progress:
self.jobs.notify_started(job)
self.progress_kind = ProgressKind.generation
Expand Down
5 changes: 5 additions & 0 deletions ai_diffusion/ui/widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def _connect_model(self):
self._connections = [
self._model.jobs.count_changed.connect(self._update),
self._model.progress_kind_changed.connect(self._update),
self._model.queue_length_changed.connect(self._update),
]

def _update(self):
Expand All @@ -220,6 +221,10 @@ def _update(self):
self.setIcon(theme.icon("queue-upload"))
self.setToolTip(_("Uploading models.") + f" {queued_msg} {cancel_msg}")
count += 1
elif self._model.queue_length > 0:
self.setIcon(theme.icon("queue-inactive"))
self.setToolTip(_("Server is busy."))
count = f"+{self.model.queue_length}"
elif self._model.jobs.any_executing():
self.setIcon(theme.icon("queue-active"))
if count > 0:
Expand Down

0 comments on commit 57d9ba3

Please sign in to comment.