diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 000000000..e4379be69 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,52 @@ +name: Lint, Typecheck and Test +on: [push, pull_request] + +jobs: + check: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + submodules: true + - name: Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + cache: 'pip' + - name: Dependencies + run: pip install -r requirements.txt + - name: Typecheck + uses: jakebailey/pyright-action@v1 + with: + pylance-version: latest-release + - name: Lint + if: ${{ !cancelled() }} + uses: psf/black@stable + + test: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + submodules: true + - name: Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + cache: 'pip' + - name: Dependencies + run: pip install -r requirements.txt + - name: Cache models + uses: actions/cache@v3 + with: + path: scripts/docker/downloads + key: models-v1 + - name: Download models + run: python scripts/download_models.py --minimal scripts/docker/downloads + - name: Test installer + run: python -m pytest tests/test_server.py -vs --test-install + - name: Test + run: python -m pytest tests -vs --ci + diff --git a/.gitignore b/.gitignore index 8e8c7fe33..0a1f91b24 100644 --- a/.gitignore +++ b/.gitignore @@ -7,8 +7,7 @@ .server .vscode __pycache__ -scripts/docker/models -scripts/docker/custom_nodes/* +scripts/docker/downloads tests/.results settings.json ai_diffusion/styles/* diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 34788662a..857bcfa39 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -27,7 +27,7 @@ The easiest way to run a development version of the plugin is to use symlinks: ### Code formatting -The codebase uses [black](https://github.com/psf/black) for formatting. The project root contains a `pyproject.toml` to configure the line length, it should be picked up automatically. +The codebase uses [black](https://github.com/psf/black) for formatting. You can check locally by running `black` in the repository root, or use an IDE integration. ### Code style @@ -37,7 +37,9 @@ Code style follows the official Python recommendations. Only exception: no `ALL_ Type annotations should be used where types can't be inferred. Basic type checks are enabled for the project and should not report errors. -The `Krita` module is special in that it is usually only available when running inside Krita. To make type checking work, include `scripts/typeshed` in your `PYTHONPATH`. +The `Krita` module is special in that it is usually only available when running inside Krita. To make type checking work an interface file is located in `scripts/typeshed`. + +You can run `pyright` from the repository root to perform type checks on the entire codebase. This is also done by the CI. Configuration for VSCode with Pylance (.vscode/settings.json): ``` @@ -77,9 +79,9 @@ Everything else has tests. Mostly. If effort is reasonable, tests are expected. Testing changes to the installer is annoying because of the file sizes involved. There are some things that help. You can preload model files with the following script: ``` -python scripts/download_models.py --minimal scripts/docker +python scripts/download_models.py --minimal scripts/docker/downloads ``` -This will download the minimum required models and store them in `scripts/docker` (used as default location because that way the docker build script can use them too). +This will download the minimum required models and store them in `scripts/docker/downloads` (used as default location because that way the docker build script can use them too). The following command does some automated testing for installation and upgrade. It starts a local file server which pulls preloaded models, so it's reasonably fast and doesn't download the entire internet. ``` diff --git a/ai_diffusion/attention_edit.py b/ai_diffusion/attention_edit.py index 4422491de..038d76d85 100644 --- a/ai_diffusion/attention_edit.py +++ b/ai_diffusion/attention_edit.py @@ -5,7 +5,7 @@ def select_current_parenthesis_block( text: str, cursor_pos: int, open_bracket: str, close_bracket: str ) -> Tuple[int, int] | None: - """Select the current parenthesis block that the cursor points to. """ + """Select the current parenthesis block that the cursor points to.""" # Ensure cursor position is within valid range cursor_pos = max(0, min(cursor_pos, len(text))) @@ -50,9 +50,8 @@ def select_current_word(text: str, cursor_pos: int) -> Tuple[int, int]: def select_on_cursor_pos(text: str, cursor_pos: int) -> Tuple[int, int]: """Return a range in the text based on the cursor_position.""" - return ( - select_current_parenthesis_block(text, cursor_pos, "(", ")") - or select_current_word(text, cursor_pos) + return select_current_parenthesis_block(text, cursor_pos, "(", ")") or select_current_word( + text, cursor_pos ) @@ -143,4 +142,3 @@ def edit_attention(text: str, positive: bool) -> str: if weight == 1.0 else f"{open_bracket}{attention_string}:{weight:.1f}{close_bracket}" ) - diff --git a/ai_diffusion/client.py b/ai_diffusion/client.py index 769eb4585..e98981afe 100644 --- a/ai_diffusion/client.py +++ b/ai_diffusion/client.py @@ -128,7 +128,7 @@ class Client: lcm_model: dict[SDVersion, str | None] supported_sd_versions: list[SDVersion] device_info: DeviceInfo - nodes_required_inputs: dict[str, dict[str, list[str | list | dict]]] = {} + nodes_inputs: dict[str, dict[str, list[str | list | dict]]] = {} @staticmethod async def connect(url=default_url): @@ -164,7 +164,7 @@ async def connect(url=default_url): client.ip_adapter_model = { ver: _find_ip_adapter(ip, ver) for ver in [SDVersion.sd15, SDVersion.sdxl] } - client.nodes_required_inputs["IPAdapterApply"] = nodes["IPAdapterApply"]["input"]["required"] + client.nodes_inputs["IPAdapterApply"] = nodes["IPAdapterApply"]["input"]["required"] # Retrieve upscale models client.upscalers = nodes["UpscaleModelLoader"]["input"]["required"]["model_name"][0] @@ -395,7 +395,7 @@ def _check_workload(self, sdver: SDVersion) -> list[MissingResource]: def parse_url(url: str): url = url.strip("/") - url = url.replace('0.0.0.0', '127.0.0.1') + url = url.replace("0.0.0.0", "127.0.0.1") if not url.startswith("http"): url = f"http://{url}" return url diff --git a/ai_diffusion/image.py b/ai_diffusion/image.py index 701d4d76f..f715398f3 100644 --- a/ai_diffusion/image.py +++ b/ai_diffusion/image.py @@ -261,6 +261,7 @@ def make_opaque(self, background=Qt.GlobalColor.white): @property def data(self): ptr = self._qimage.bits() + assert ptr is not None, "Accessing data of invalid image" ptr.setsize(self._qimage.byteCount()) return QByteArray(ptr.asstring()) @@ -272,7 +273,9 @@ def to_array(self): import numpy as np w, h = self.extent - ptr = self._qimage.constBits().asarray(w * h * 4) + bits = self._qimage.constBits() + assert bits is not None, "Accessing data of invalid image" + ptr = bits.asarray(w * h * 4) array = np.frombuffer(ptr, np.uint8).reshape(w, h, 4) # type: ignore return array.astype(np.float32) / 255 diff --git a/ai_diffusion/network.py b/ai_diffusion/network.py index c2bad3060..4fd9a4261 100644 --- a/ai_diffusion/network.py +++ b/ai_diffusion/network.py @@ -164,6 +164,7 @@ async def _try_download(network: QNetworkAccessManager, url: str, path: Path): log.info(f"Found {path}.part, resuming download from {out_file.size()} bytes") request.setRawHeader(b"Range", f"bytes={out_file.size()}-".encode("utf-8")) reply = network.get(request) + assert reply is not None, f"Network request for {url} failed: reply is None" progress_future = asyncio.get_running_loop().create_future() finished_future = asyncio.get_running_loop().create_future() diff --git a/ai_diffusion/server.py b/ai_diffusion/server.py index bf08d0422..9a3f36c78 100644 --- a/ai_diffusion/server.py +++ b/ai_diffusion/server.py @@ -17,7 +17,7 @@ _exe = ".exe" if is_windows else "" -_process_flags = subprocess.CREATE_NO_WINDOW if is_windows else 0 +_process_flags = subprocess.CREATE_NO_WINDOW if is_windows else 0 # type: ignore class ServerState(Enum): diff --git a/ai_diffusion/ui/generation.py b/ai_diffusion/ui/generation.py index eb5098e5c..40ae8a701 100644 --- a/ai_diffusion/ui/generation.py +++ b/ai_diffusion/ui/generation.py @@ -20,6 +20,7 @@ from ..model import Model from ..root import root from ..settings import settings +from ..util import ensure from . import theme from .widget import ( WorkspaceSelectWidget, @@ -94,7 +95,7 @@ def add(self, job: Job): self.addItem(item) scrollbar = self.verticalScrollBar() - if scrollbar.isVisible() and scrollbar.value() >= scrollbar.maximum() - 4: + if scrollbar and scrollbar.isVisible() and scrollbar.value() >= scrollbar.maximum() - 4: self.scrollToBottom() def update_selection(self): @@ -118,7 +119,7 @@ def is_finished(self, job: Job): def prune(self, jobs: JobQueue): first_id = next((job.id for job in jobs if self.is_finished(job)), None) - while self.count() > 0 and self.item(0).data(Qt.ItemDataRole.UserRole) != first_id: + while self.count() > 0 and ensure(self.item(0)).data(Qt.ItemDataRole.UserRole) != first_id: self.takeItem(0) def rebuild(self): @@ -131,8 +132,9 @@ def item_info(self, item: QListWidgetItem) -> tuple[str, int]: # job id, image def handle_preview_click(self, item: QListWidgetItem): if item.text() != "" and item.text() != "": - prompt = item.data(Qt.ItemDataRole.ToolTipRole) - QGuiApplication.clipboard().setText(prompt) + if clipboard := QGuiApplication.clipboard(): + prompt = item.data(Qt.ItemDataRole.ToolTipRole) + clipboard.setText(prompt) def mousePressEvent(self, e: QMouseEvent) -> None: # make single click deselect current item (usually requires Ctrl+click) @@ -151,7 +153,7 @@ def mousePressEvent(self, e: QMouseEvent) -> None: return super().mousePressEvent(e) def _find(self, id: JobQueue.Item): - items = (self.item(i) for i in range(self.count())) + items = (ensure(self.item(i)) for i in range(self.count())) return next((item for item in items if self._item_data(item) == id), None) def _item_data(self, item: QListWidgetItem): diff --git a/ai_diffusion/ui/live.py b/ai_diffusion/ui/live.py index f058e3eb5..b1120a4bb 100644 --- a/ai_diffusion/ui/live.py +++ b/ai_diffusion/ui/live.py @@ -118,7 +118,9 @@ def __init__(self): self.preview_area = QLabel(self) self.preview_area.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) - self.preview_area.setAlignment(Qt.AlignmentFlag.AlignTop | Qt.AlignmentFlag.AlignLeft) + self.preview_area.setAlignment( + Qt.AlignmentFlag(Qt.AlignmentFlag.AlignTop | Qt.AlignmentFlag.AlignLeft) + ) layout.addWidget(self.preview_area) @property diff --git a/ai_diffusion/ui/settings.py b/ai_diffusion/ui/settings.py index 0dfc3b122..dbd4872e9 100644 --- a/ai_diffusion/ui/settings.py +++ b/ai_diffusion/ui/settings.py @@ -1055,10 +1055,9 @@ def __init__(self, server: Server): self.setWindowTitle("Configure Image Diffusion") self.setMinimumSize(QSize(840, 480)) - screen_size = QGuiApplication.primaryScreen().availableSize() - self.resize( - QSize(max(900, int(screen_size.width() * 0.6)), int(screen_size.height() * 0.8)) - ) + if screen := QGuiApplication.primaryScreen(): + size = screen.availableSize() + self.resize(QSize(max(900, int(size.width() * 0.6)), int(size.height() * 0.8))) layout = QHBoxLayout() self.setLayout(layout) diff --git a/ai_diffusion/ui/theme.py b/ai_diffusion/ui/theme.py index f4878a589..9a898d852 100644 --- a/ai_diffusion/ui/theme.py +++ b/ai_diffusion/ui/theme.py @@ -40,7 +40,9 @@ def sd_version_icon(version: SDVersion, client: Client | None = None): return icon("sd-version-15") elif version is SDVersion.sdxl: return icon("sd-version-xl") - return None + else: + util.client_logger.warning(f"Unresolved SD version {version}, cannot fetch icon") + return icon("warning") def logo(): diff --git a/ai_diffusion/ui/widget.py b/ai_diffusion/ui/widget.py index c93bb8acf..e9f412c1b 100644 --- a/ai_diffusion/ui/widget.py +++ b/ai_diffusion/ui/widget.py @@ -22,17 +22,17 @@ import krita from ..style import Style, Styles -from ..image import Bounds from ..resources import ControlMode from ..root import root from ..client import filter_supported_styles, resolve_sd_version from ..properties import Binding, bind, bind_combo -from ..jobs import Job, JobKind, JobState, JobQueue +from ..jobs import JobState, JobQueue from ..model import Model, Workspace, ControlLayer +from ..attention_edit import edit_attention, select_on_cursor_pos +from ..util import ensure from .settings import SettingsDialog from .theme import SignalBlocker from . import actions, theme -from ..attention_edit import edit_attention, select_on_cursor_pos class QueueWidget(QToolButton): @@ -400,7 +400,7 @@ def line_count(self): @line_count.setter def line_count(self, value: int): self._line_count = value - fm = QFontMetrics(self.document().defaultFont()) + fm = QFontMetrics(ensure(self.document()).defaultFont()) self.setFixedHeight(fm.lineSpacing() * value + 6) def hasSelectedText(self) -> bool: diff --git a/ai_diffusion/workflow.py b/ai_diffusion/workflow.py index 37801eeae..d5c5dce87 100644 --- a/ai_diffusion/workflow.py +++ b/ai_diffusion/workflow.py @@ -422,7 +422,7 @@ def generate( sampler_params = _sampler_params(style, live=live) batch = 1 if live.is_active else batch - w = ComfyWorkflow(comfy.nodes_required_inputs) + w = ComfyWorkflow(comfy.nodes_inputs) model, clip, vae = load_model_with_lora(w, comfy, style, is_live=live.is_active) latent = w.empty_latent_image(extent.initial.width, extent.initial.height, batch) model, positive, negative = apply_conditioning(cond, w, comfy, model, clip, style) @@ -443,7 +443,7 @@ def inpaint(comfy: Client, style: Style, image: Image, mask: Mask, cond: Conditi region_expanded = target_bounds.extent.at_least(64).multiple_of(8) expanded_bounds = Bounds(*mask.bounds.offset, *region_expanded) - w = ComfyWorkflow(comfy.nodes_required_inputs) + w = ComfyWorkflow(comfy.nodes_inputs) model, clip, vae = load_model_with_lora(w, comfy, style) in_image = w.load_image(scaled_image) in_mask = w.load_mask(scaled_mask) @@ -521,7 +521,7 @@ def refine( extent, image, batch = prepare_image(image, resolve_sd_version(style, comfy), downscale=False) sampler_params = _sampler_params(style, live=live, strength=strength) - w = ComfyWorkflow(comfy.nodes_required_inputs) + w = ComfyWorkflow(comfy.nodes_inputs) model, clip, vae = load_model_with_lora(w, comfy, style, is_live=live.is_active) in_image = w.load_image(image) if extent.is_incompatible: @@ -554,7 +554,7 @@ def refine_region( extent, image, mask_image, batch = prepare_masked(image, mask, sd_ver, downscale_if_needed) sampler_params = _sampler_params(style, strength=strength, live=live) - w = ComfyWorkflow(comfy.nodes_required_inputs) + w = ComfyWorkflow(comfy.nodes_inputs) model, clip, vae = load_model_with_lora(w, comfy, style, is_live=live.is_active) in_image = w.load_image(image) in_mask = w.load_mask(mask_image) @@ -647,7 +647,7 @@ def upscale_tiled( else: # SDXL tile_extent = Extent(1024, 1024) - w = ComfyWorkflow(comfy.nodes_required_inputs) + w = ComfyWorkflow(comfy.nodes_inputs) img = w.load_image(image) checkpoint, clip, vae = load_model_with_lora(w, comfy, style) upscale_model = w.load_upscale_model(model) diff --git a/pyproject.toml b/pyproject.toml index 49511bf54..ae25f94c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,24 @@ [tool.black] line-length=100 -preview=1 \ No newline at end of file +preview=1 +include='(ai_diffusion|scripts|tests)/.*\.pyi?$' +extend-exclude='websockets|krita\.pyi$' + +[tool.pyright] +include = [ + "ai_diffusion", + "scripts/*.py", + "tests" +] +exclude = [ + "**/__pycache__", + "**/.pytest_cache", + "**/.server", + "ai_diffusion/websockets", +] +ignore = [ + "ai_diffusion/websockets", + "krita.pyi" +] +extraPaths = ["scripts/typeshed"] +reportMissingModuleSource = false diff --git a/requirements.txt b/requirements.txt index d65b4cbe8..4d170d28c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,10 @@ # This file is for development and running tests. # The plugin itself will run inside Krita's embedded Python, and only has access to the Python standard library and Qt5. + +# Development +black + +# Testing aiohttp markdown numpy diff --git a/scripts/docker/Dockerfile b/scripts/docker/Dockerfile index 3626ca3e0..3bf4bb985 100644 --- a/scripts/docker/Dockerfile +++ b/scripts/docker/Dockerfile @@ -74,8 +74,8 @@ RUN git clone https://github.com/ltdrdata/ComfyUI-Manager.git /ComfyUI/custom_no # Copy models RUN mkdir -p models -COPY models/ /models/ -COPY custom_nodes/ComfyUI_IPAdapter_plus/models/* /ComfyUI/custom_nodes/ComfyUI_IPAdapter_plus/models/ +COPY downloads/models/ /models/ +COPY downloads/custom_nodes/ComfyUI_IPAdapter_plus/models/* /ComfyUI/custom_nodes/ComfyUI_IPAdapter_plus/models/ COPY extra_model_paths.yaml /ComfyUI/ # Install Jupyter diff --git a/scripts/download_models.py b/scripts/download_models.py index 4f00c30e2..c359aeec4 100644 --- a/scripts/download_models.py +++ b/scripts/download_models.py @@ -14,7 +14,7 @@ """ import asyncio -from itertools import chain +from itertools import chain, islice import aiohttp import sys from pathlib import Path @@ -38,6 +38,10 @@ def all_models(): ) +def required_models(): + return chain(resources.required_models, islice(resources.default_checkpoints, 1)) + + def _progress(name: str, size: int | None): return tqdm( total=size, @@ -89,24 +93,17 @@ async def main( ): print(f"Generative AI for Krita - Model download - v{ai_diffusion.__version__}") verbose = verbose or dry_run + models = required_models() if minimal else all_models() timeout = aiohttp.ClientTimeout(total=None, sock_connect=10, sock_read=60) async with aiohttp.ClientSession(timeout=timeout) as client: - for model in all_models(): + for model in models: if ( (no_sd15 and model.sd_version is SDVersion.sd15) or (no_sdxl and model.sd_version is SDVersion.sdxl) or (no_controlnet and model.kind is resources.ResourceKind.controlnet) - or ( - no_upscalers - and model.kind is resources.ResourceKind.upscaler - and (not minimal or model.name != "NMKD Superscale model") - ) - or ( - no_checkpoints - and model.kind is resources.ResourceKind.checkpoint - and (not minimal or model.name != "Realistic Vision") - ) + or (no_upscalers and model.kind is resources.ResourceKind.upscaler) + or (no_checkpoints and model.kind is resources.ResourceKind.checkpoint) ): continue if verbose: @@ -143,12 +140,7 @@ async def main( parser.add_argument("--no-controlnet", action="store_true", help="skip ControlNet models") parser.add_argument("-m", "--minimal", action="store_true", help="minimum viable set of models") args = parser.parse_args() - if args.minimal: - assert not args.no_sd15, "Minimal requires SD1.5 models" - args.no_sdxl = True - args.no_upscalers = True - args.no_checkpoints = True - args.no_controlnet = True + args.no_sdxl = args.no_sdxl or args.minimal asyncio.run( main( args.destination, diff --git a/scripts/file_server.py b/scripts/file_server.py index 94b54fcec..df0ed954d 100644 --- a/scripts/file_server.py +++ b/scripts/file_server.py @@ -1,5 +1,5 @@ """Simple HTTP server for testing the installation process. -1) Run the docker.py script to download all required models. +1) Run the download_models.py script to download all required models. 2) Run this script to serve the model files on localhost. 3) Set environment variable HOSTMAP=1 to replace all huggingface / civitai urls. """ @@ -12,7 +12,7 @@ sys.path.append(str(Path(__file__).parent.parent)) from ai_diffusion import resources -dir = Path(__file__).parent / "docker" +dir = Path(__file__).parent / "docker" / "downloads" def url_strip(url: str): diff --git a/tests/config.py b/tests/config.py index fe823c015..050362e89 100644 --- a/tests/config.py +++ b/tests/config.py @@ -4,6 +4,7 @@ test_dir = Path(__file__).parent server_dir = test_dir / ".server" +data_dir = test_dir / "data" image_dir = test_dir / "images" result_dir = test_dir / ".results" reference_dir = test_dir / "references" diff --git a/tests/conftest.py b/tests/conftest.py index a76e54a86..b2bf2e4c6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,9 +45,10 @@ def local_download_server(): port = 51222 with subprocess.Popen([sys.executable, str(script), str(port)]) as proc: + assert proc.poll() is None network.HOSTMAP = network.HOSTMAP_LOCAL yield f"http://localhost:{port}" network.HOSTMAP = {} - proc.kill() + proc.terminate() proc.wait() diff --git a/tests/data/outpaint_context.png b/tests/data/outpaint_context.png new file mode 100644 index 000000000..2b9aa17b4 Binary files /dev/null and b/tests/data/outpaint_context.png differ diff --git a/tests/images/outpaint_context.png b/tests/images/outpaint_context.png deleted file mode 100644 index d76b687b6..000000000 --- a/tests/images/outpaint_context.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:899c7078792c4d10ce22b9a22f2a220790d504fd4256d764a3595acc77a16133 -size 1734 diff --git a/tests/test_attention_edit.py b/tests/test_attention_edit.py index c74cd040f..323afeb0d 100644 --- a/tests/test_attention_edit.py +++ b/tests/test_attention_edit.py @@ -39,7 +39,7 @@ def test_invalid_weight(self): assert edit_attention("(foo:bar)", positive=True) == "((foo:bar):1.1)" def test_no_weight(self): - assert edit_attention("(foo)", positive=True) == "((foo):1.1)" + assert edit_attention("(foo)", positive=True) == "((foo):1.1)" class TestSelectOnCursorPos: @@ -52,4 +52,3 @@ def test_word_selection(self): def test_range_selection(self): assert select_on_cursor_pos("(foo:1.3), bar, baz", 1) == (0, 9) assert select_on_cursor_pos("foo, (bar:1.1), baz", 6) == (5, 14) - diff --git a/tests/test_server.py b/tests/test_server.py index ba54b1d82..3ebd03f0a 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -48,7 +48,7 @@ def test_install_and_run(qtapp, pytestconfig, local_download_server): """Test installing and running ComfyUI server from scratch. * Takes a while, only runs with --test-install * Starts and downloads from local file server instead of huggingface/civitai - * Required to run `scripts/download_models.py -m scripts/docker` to download models once + * Required to run `scripts/download_models.py -m scripts/docker/downloads` to download models once * Remove `local_download_server` fixture to download from original urls * Also tests upgrading server from "previous" version * In this case it's the same version, but it removes & re-installs anyway @@ -62,7 +62,10 @@ def test_install_and_run(qtapp, pytestconfig, local_download_server): server.backend = ServerBackend.cpu assert server.state in [ServerState.not_installed, ServerState.missing_resources] + last_stage = "" + def handle_progress(report: InstallationProgress): + nonlocal last_stage assert ( report.progress is None or report.progress.value == -1 @@ -70,8 +73,9 @@ def handle_progress(report: InstallationProgress): and report.progress.value <= 1 ) assert report.stage != "" - if report.progress is None: - print(report.stage, report.message) + if report.progress is None and report.stage != last_stage: + last_stage = report.stage + print(report.stage) async def main(): await server.install(handle_progress) @@ -81,9 +85,9 @@ async def main(): await server.download(workload_sd15, handle_progress) assert server.state is ServerState.stopped and server.version == resources.version - url = await server.start() + url = await server.start(port=8191) assert server.state is ServerState.running - assert url == "127.0.0.1:8188" + assert url == "127.0.0.1:8191" await server.stop() assert server.state is ServerState.stopped @@ -112,9 +116,9 @@ def test_run_external(qtapp, pytestconfig): assert server.state in [ServerState.stopped, ServerState.missing_resources] async def main(): - url = await server.start() + url = await server.start(port=8192) assert server.state is ServerState.running - assert url == "127.0.0.1:8188" + assert url == "127.0.0.1:8192" await server.stop() assert server.state is ServerState.stopped diff --git a/tests/test_settings.py b/tests/test_settings.py index f3a44a4d3..a5e13c502 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -52,13 +52,11 @@ def test_performance_preset(): def style_is_default(style): - return all( - [ - getattr(style, name) == s.default - for name, s in StyleSettings.__dict__.items() - if isinstance(s, Setting) and name != "name" - ] - ) + return all([ + getattr(style, name) == s.default + for name, s in StyleSettings.__dict__.items() + if isinstance(s, Setting) and name != "name" + ]) def test_styles(): diff --git a/tests/test_workflow.py b/tests/test_workflow.py index afb401817..784415f77 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -8,7 +8,7 @@ from ai_diffusion.pose import Pose from ai_diffusion.workflow import LiveParams, Conditioning, Control from pathlib import Path -from .config import image_dir, result_dir, reference_dir, default_checkpoint +from .config import data_dir, image_dir, result_dir, reference_dir, default_checkpoint @pytest.fixture(scope="session", autouse=True) @@ -86,7 +86,7 @@ def test_compute_batch_size(extent, min_size, max_batches, expected): ids=["left", "right", "top", "bottom", "full", "small", "offset"], ) def test_inpaint_context(area, expected_extent, expected_crop: tuple[int, int] | None): - image = Image.load(image_dir / "outpaint_context.png") + image = Image.load(data_dir / "outpaint_context.png") default = comfyworkflow.Output(0, 0) result = workflow.create_inpaint_context(image, area, default) if expected_crop: