Skip to content

Commit

Permalink
Streamline tests for CI (#179):
Browse files Browse the repository at this point in the history
- reuse server install for client tests
- add option to skip diffusion tests
  • Loading branch information
Acly committed Dec 5, 2023
1 parent e1e3f0b commit 3522e1b
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 30 deletions.
4 changes: 3 additions & 1 deletion ai_diffusion/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def info(message: str):
f"Error during model migration: {str(e)}\nSome models remain in {upgrade_comfy_dir}"
)

async def start(self):
async def start(self, port: int | None = None):
assert self.state in [ServerState.stopped, ServerState.missing_resources]
assert self._python_cmd

Expand All @@ -374,6 +374,8 @@ async def start(self):
args.append("--force-fp16")
if settings.server_arguments:
args += settings.server_arguments.split(" ")
if port is not None:
args += ["--port", str(port)]
self._process = await asyncio.create_subprocess_exec(
self._python_cmd,
*args,
Expand Down
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tests for the Krita Generative AI plugin."""
14 changes: 14 additions & 0 deletions tests/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from pathlib import Path
from ai_diffusion.resources import SDVersion, default_checkpoints


test_dir = Path(__file__).parent
server_dir = test_dir / ".server"
image_dir = test_dir / "images"
result_dir = test_dir / ".results"
reference_dir = test_dir / "references"

default_checkpoint = {
SDVersion.sd15: default_checkpoints[0].filename,
SDVersion.sdxl: default_checkpoints[1].filename,
}
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

def pytest_addoption(parser):
parser.addoption("--test-install", action="store_true")
parser.addoption("--ci", action="store_true")


class QtTestApp:
Expand Down
34 changes: 23 additions & 11 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,25 @@
from ai_diffusion.image import Image, Extent
from ai_diffusion.client import Client, ClientEvent, parse_url, resolve_sd_version, websocket_url
from ai_diffusion.style import SDVersion, Style
from ai_diffusion.server import Server, ServerState, ServerBackend
from .config import server_dir, default_checkpoint

default_checkpoint = "realisticVisionV51_v51VAE.safetensors"

@pytest.fixture(scope="session")
def comfy_server(qtapp):
server = Server(str(server_dir))
server.backend = ServerBackend.cpu
assert server.state is ServerState.stopped, (
f"Expected server installation at {server_dir}. To create the default installation run"
" `pytest tests/test_server.py --test-install`"
)
yield qtapp.run(server.start(port=8189))
qtapp.run(server.stop())


def make_default_workflow(steps=20):
w = ComfyWorkflow()
model, clip, vae = w.load_checkpoint(default_checkpoint)
model, clip, vae = w.load_checkpoint(default_checkpoint[SDVersion.sd15])
positive = w.clip_text_encode(clip, "a photo of a cat")
negative = w.clip_text_encode(clip, "a photo of a dog")
latent_image = w.empty_latent_image(512, 512)
Expand All @@ -32,7 +44,7 @@ def make_trivial_workflow():
return w


def test_connect_bad_url(qtapp):
def test_connect_bad_url(qtapp, comfy_server):
async def main():
with pytest.raises(NetworkError):
await Client.connect("bad_url")
Expand All @@ -41,9 +53,9 @@ async def main():


@pytest.mark.parametrize("cancel_point", ["after_enqueue", "after_start", "after_sampling"])
def test_cancel(qtapp, cancel_point):
def test_cancel(qtapp, comfy_server, cancel_point):
async def main():
client = await Client.connect()
client = await Client.connect(comfy_server)
job_id = None
interrupted = False
stage = 0
Expand Down Expand Up @@ -88,13 +100,13 @@ async def main():
qtapp.run(main())


def test_disconnect(qtapp):
async def listen(client):
def test_disconnect(qtapp, comfy_server):
async def listen(client: Client):
async for msg in client.listen():
assert msg.event is ClientEvent.connected

async def main():
client = await Client.connect()
client = await Client.connect(comfy_server)
task = eventloop._loop.create_task(listen(client))
task.cancel()
with pytest.raises(asyncio.CancelledError):
Expand Down Expand Up @@ -146,13 +158,13 @@ def check_resolve_sd_version(client: Client, sd_version: SDVersion):
assert resolve_sd_version(style, None) == sd_version


def test_info(qtapp):
def test_info(pytestconfig, qtapp, comfy_server):
async def main():
client = await Client.connect()
client = await Client.connect(comfy_server)
check_client_info(client)
await client.refresh()
check_client_info(client)
check_resolve_sd_version(client, SDVersion.sd15)
check_resolve_sd_version(client, SDVersion.sdxl)
# check_resolve_sd_version(client, SDVersion.sdxl) # no SDXL checkpoint in default installation

qtapp.run(main())
17 changes: 9 additions & 8 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from ai_diffusion import network, server, resources
from ai_diffusion.style import SDVersion
from ai_diffusion.server import Server, ServerState, ServerBackend, InstallationProgress
from .config import server_dir

test_dir = Path(__file__).parent / ".server"
workload_sd15 = [p.name for p in resources.required_models if p.sd_version is SDVersion.sd15]
workload_sd15 += [resources.default_checkpoints[0].name]


@pytest.mark.parametrize("mode", ["from_scratch", "resume"])
Expand Down Expand Up @@ -38,9 +39,9 @@ async def main():


def clear_test_server():
if test_dir.exists():
shutil.rmtree(test_dir, ignore_errors=True)
test_dir.mkdir(exist_ok=True)
if server_dir.exists():
shutil.rmtree(server_dir, ignore_errors=True)
server_dir.mkdir(exist_ok=True)


def test_install_and_run(qtapp, pytestconfig, local_download_server):
Expand All @@ -57,7 +58,7 @@ def test_install_and_run(qtapp, pytestconfig, local_download_server):

clear_test_server()

server = Server(str(test_dir))
server = Server(str(server_dir))
server.backend = ServerBackend.cpu
assert server.state in [ServerState.not_installed, ServerState.missing_resources]

Expand Down Expand Up @@ -87,7 +88,7 @@ async def main():
await server.stop()
assert server.state is ServerState.stopped

version_file = test_dir / ".version"
version_file = server_dir / ".version"
assert version_file.exists()
with version_file.open("w") as f:
f.write("1.0.42")
Expand All @@ -102,10 +103,10 @@ async def main():
def test_run_external(qtapp, pytestconfig):
if not pytestconfig.getoption("--test-install"):
pytest.skip("Only runs with --test-install")
if not (test_dir / "ComfyUI").exists():
if not (server_dir / "ComfyUI").exists():
pytest.skip("ComfyUI installation not found")

server = Server(str(test_dir))
server = Server(str(server_dir))
server.backend = ServerBackend.cpu
assert server.has_python
assert server.state in [ServerState.stopped, ServerState.missing_resources]
Expand Down
14 changes: 4 additions & 10 deletions tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,7 @@
from ai_diffusion.pose import Pose
from ai_diffusion.workflow import LiveParams, Conditioning, Control
from pathlib import Path

test_dir = Path(__file__).parent
image_dir = test_dir / "images"
result_dir = test_dir / ".results"
reference_dir = test_dir / "references"
default_checkpoint = {
SDVersion.sd15: "realisticVisionV51_v51VAE.safetensors",
SDVersion.sdxl: "sdXL_v10VAEFix.safetensors",
}
from .config import image_dir, result_dir, reference_dir, default_checkpoint


@pytest.fixture(scope="session", autouse=True)
Expand All @@ -28,7 +20,9 @@ def clear_results():


@pytest.fixture()
def comfy(qtapp):
def comfy(pytestconfig, qtapp):
if pytestconfig.getoption("--ci"):
pytest.skip("Diffusion is disabled on CI")
return qtapp.run(Client.connect())


Expand Down

0 comments on commit 3522e1b

Please sign in to comment.