Skip to content

Commit

Permalink
Fix issue 1137: Query ComfyUI queue state and use it to display count…
Browse files Browse the repository at this point in the history
… of jobs in the queue before ours.

Also use it to avoid clearing jobs that aren't ours.
  • Loading branch information
FeepingCreature committed Sep 6, 2024
1 parent 8d63f31 commit 29a40b2
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 3 deletions.
2 changes: 2 additions & 0 deletions ai_diffusion/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class ClientEvent(Enum):
disconnected = 5
queued = 6
upload = 7
foreign_jobs = 8


class ClientMessage(NamedTuple):
Expand All @@ -34,6 +35,7 @@ class ClientMessage(NamedTuple):
images: ImageCollection | None = None
result: dict | None = None
error: str | None = None
foreign_jobs: int | None = None


class User(QObject, ObservableProperties):
Expand Down
34 changes: 31 additions & 3 deletions ai_diffusion/comfy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from enum import Enum
from collections import deque
from itertools import chain, product
from operator import itemgetter
from typing import NamedTuple, Optional, Sequence

from .api import WorkflowInput
Expand Down Expand Up @@ -99,6 +100,7 @@ class ComfyClient(Client):
_websocket_listener: asyncio.Task
_supported_sd_versions: list[SDVersion]
_supported_languages: list[TranslationPackage]
_server_job_ids: list[str] = []

@staticmethod
async def connect(url=default_url, access_token=""):
Expand Down Expand Up @@ -264,6 +266,7 @@ async def _listen_websocket(self, websocket: websockets_client.WebSocketClientPr

if msg["type"] == "status":
await self._report(ClientEvent.connected, "")
await self.update_server_queue()

if msg["type"] == "execution_start":
id = msg["data"]["prompt_id"]
Expand Down Expand Up @@ -340,9 +343,34 @@ async def clear_queue(self):
except asyncio.QueueEmpty:
break

await self._post("queue", {"clear": True})
remote_ids = [await job.get_remote_id() for job in self._jobs]
await self._post("api/queue", {"delete": remote_ids})

self._jobs.clear()

async def update_server_queue(self):
queue = await self._get("api/queue")
server_jobs = queue["queue_running"] + queue["queue_pending"]
# why are they unsorted to start with...?
server_jobs = sorted(server_jobs, key=itemgetter(0))
self._server_job_ids = [entry[1] for entry in server_jobs]
if not (self._jobs or self._active):
await self._report(ClientEvent.foreign_jobs, "", foreign_jobs=len(self._server_job_ids))
return

if self._active:
first_remote_id = await self._active.get_remote_id()
else:
first_remote_id = await self._jobs[0].get_remote_id()
# if we got it from _jobs or _active, this field must have been set (in _run_job).
first_remote_id = util.ensure(first_remote_id)
try:
offset = self._server_job_ids.index(first_remote_id)
except ValueError:
# probably just haven't gotten the notification yet
return
await self._report(ClientEvent.foreign_jobs, "", foreign_jobs=offset)

async def disconnect(self):
if self._is_connected:
self._is_connected = False
Expand Down Expand Up @@ -443,7 +471,7 @@ def _get_active_job(self, remote_id: str) -> Optional[JobInfo]:
return self._active
elif self._active:
log.warning(f"Received message for job {remote_id}, but job {self._active} is active")
if len(self._jobs) == 0:
if not self._jobs:
log.warning(f"Received unknown job {remote_id}")
return None
active = next((j for j in self._jobs if j.remote_id == remote_id), None)
Expand All @@ -454,7 +482,7 @@ def _get_active_job(self, remote_id: str) -> Optional[JobInfo]:
async def _start_job(self, remote_id: str):
if self._active is not None:
log.warning(f"Started job {remote_id}, but {self._active} was never finished")
if len(self._jobs) == 0:
if not self._jobs:
log.warning(f"Received unknown job {remote_id}")
return None

Expand Down
5 changes: 5 additions & 0 deletions ai_diffusion/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ class Connection(QObject, ObservableProperties):
state = Property(ConnectionState.disconnected)
error = Property("")
missing_resource: MissingResource | None = None
foreign_jobs: int = 0

state_changed = pyqtSignal(ConnectionState)
error_changed = pyqtSignal(str)
models_changed = pyqtSignal()
message_received = pyqtSignal(ClientMessage)
foreign_jobs_changed = pyqtSignal(int)

_client: Client | None = None
_task: asyncio.Task | None = None
Expand Down Expand Up @@ -169,6 +171,9 @@ async def _handle_messages(self):
if temporary_disconnect:
temporary_disconnect = False
self.error = ""
elif msg.event is ClientEvent.foreign_jobs:
self.foreign_jobs = util.ensure(msg.foreign_jobs)
self.foreign_jobs_changed.emit(self.foreign_jobs)
else:
self.message_received.emit(msg)
except asyncio.CancelledError:
Expand Down
5 changes: 5 additions & 0 deletions ai_diffusion/ui/widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def _connect_model(self):
self._connections = [
self._model.jobs.count_changed.connect(self._update),
self._model.progress_kind_changed.connect(self._update),
root.connection.foreign_jobs_changed.connect(self._update),
]

def _update(self):
Expand All @@ -220,6 +221,10 @@ def _update(self):
self.setIcon(theme.icon("queue-upload"))
self.setToolTip(_("Uploading models.") + f" {queued_msg} {cancel_msg}")
count += 1
elif root.connection.foreign_jobs > 0:
self.setIcon(theme.icon("queue-inactive"))
self.setToolTip(_("Server is busy."))
count = f"+{root.connection.foreign_jobs}"
elif self._model.jobs.any_executing():
self.setIcon(theme.icon("queue-active"))
if count > 0:
Expand Down

0 comments on commit 29a40b2

Please sign in to comment.