Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(BA-19): Broken session CLI commands due to invalid initialization of ComputeSession #3222

Merged
merged 14 commits into from
Jan 3, 2025
Merged
1 change: 1 addition & 0 deletions changes/3222.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix broken session CLI commands due to invalid initialization of `ComputeSession`.
39 changes: 18 additions & 21 deletions src/ai/backend/client/cli/session/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ai.backend.cli.params import CommaSeparatedListType, OptionalType
from ai.backend.cli.types import ExitCode, Undefined, undefined
from ai.backend.common.arch import DEFAULT_IMAGE_ARCH
from ai.backend.common.types import ClusterMode, SessionId
from ai.backend.common.types import ClusterMode

from ...compat import asyncio_run
from ...exceptions import BackendAPIError
Expand Down Expand Up @@ -934,20 +934,20 @@ def status_history(session_id: str) -> None:


@session.command()
@click.argument("session_id", metavar="SESSID", type=SessionId)
@click.argument("session_id_or_name", metavar="SESSION_ID_OR_NAME")
@click.argument("new_name", metavar="NEWNAME")
def rename(session_id: SessionId, new_name: str) -> None:
def rename(session_id_or_name: str, new_name: str) -> None:
"""
Renames session name of running session.

\b
SESSID: Session ID or its alias given when creating the session.
SESSION_ID_OR_NAME: Session ID or its alias given when creating the session.
NEWNAME: New Session name.
"""

async def cmd_main() -> None:
async with AsyncSession() as api_sess:
session = api_sess.ComputeSession.from_session_id(session_id)
session = api_sess.ComputeSession(session_id_or_name)
await session.rename(new_name)
# FIXME: allow the renaming operation by RBAC and ownership
# resp = await session.update(name=new_name)
Expand All @@ -961,20 +961,20 @@ async def cmd_main() -> None:


@session.command()
@click.argument("session_id", metavar="SESSID", type=SessionId)
@click.argument("session_id_or_name", metavar="SESSION_ID_OR_NAME")
@click.argument("priority", metavar="PRIORITY", type=int)
def set_priority(session_id: SessionId, priority: int) -> None:
def set_priority(session_id_or_name: str, priority: int) -> None:
"""
Sets the scheduling priority of the session.

\b
SESSID: Session ID or its alias given when creating the session.
SESSION_ID_OR_NAME: Session ID or its alias given when creating the session.
PRIORITY: New priority value (0 to 100, may be clamped in the server side due to resource policies).
"""

async def cmd_main() -> None:
async with AsyncSession() as api_sess:
session = api_sess.ComputeSession.from_session_id(session_id)
session = api_sess.ComputeSession(session_id_or_name)
resp = await session.update(priority=priority)
item = resp["item"]
print_done(f"Session {item["name"]!r} priority is changed to {item["priority"]}.")
Expand All @@ -987,20 +987,20 @@ async def cmd_main() -> None:


@session.command()
@click.argument("session_id", metavar="SESSID", type=SessionId)
def commit(session_id: SessionId) -> None:
@click.argument("session_id_or_name", metavar="SESSION_ID_OR_NAME")
def commit(session_id_or_name: str) -> None:
"""
Commits a running session to tar file.

\b
SESSID: Session ID or its alias given when creating the session.
SESSION_ID_OR_NAME: Session ID or its alias given when creating the session.
"""

async def cmd_main() -> None:
async with AsyncSession() as api_sess:
session = api_sess.ComputeSession.from_session_id(session_id)
session = api_sess.ComputeSession(session_id_or_name)
await session.commit()
print_info(f"Request to commit Session(name or id: {session_id})")
print_info(f"Request to commit Session(name or id: {session_id_or_name})")

try:
asyncio.run(cmd_main())
Expand Down Expand Up @@ -1241,7 +1241,7 @@ def scp(


def _events_cmd(docs: Optional[str] = None):
@click.argument("session_name_or_id", metavar="SESSION_ID_OR_NAME")
@click.argument("session_id_or_name", metavar="SESSION_ID_OR_NAME")
@click.option(
"-o",
"--owner",
Expand All @@ -1256,7 +1256,7 @@ def _events_cmd(docs: Optional[str] = None):
default="*",
help="Filter the events by kernel-specific ones or session-specific ones.",
)
def events(session_name_or_id, owner_access_key, scope):
def events(session_id_or_name, owner_access_key, scope):
"""
Monitor the lifecycle events of a compute session.

Expand All @@ -1265,11 +1265,8 @@ def events(session_name_or_id, owner_access_key, scope):

async def _run_events():
async with AsyncSession() as session:
try:
session_id = uuid.UUID(session_name_or_id)
compute_session = session.ComputeSession.from_session_id(session_id)
except ValueError:
compute_session = session.ComputeSession(session_name_or_id, owner_access_key)
Comment on lines -1268 to -1272
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The intention of the code appears to be to raise a ValueError if session_name_or_id is a str that is not a UUID, and to call from_session_id if it is a session ID. Otherwise, the default constructor is intended to be called.

However, when from_session_id is called, the name field is not populated, which causes the code to not function correctly.

Additionally, it seems to be a mistake that owner_access_key is not set when session_id is of UUID type.

compute_session = session.ComputeSession(session_id_or_name, owner_access_key)

async with compute_session.listen_events(scope=scope) as response:
async for ev in response:
click.echo(
Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/client/func/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,11 +625,11 @@ async def restart(self):
pass

@api_function
async def rename(self, new_id):
async def rename(self, new_name):
"""
Renames Session ID of running compute session.
"""
params = {"name": new_id}
params = {"name": new_name}
if self.owner_access_key:
params["owner_access_key"] = self.owner_access_key
prefix = get_naming(api_session.get().api_version, "path")
Expand Down
6 changes: 6 additions & 0 deletions tests/client/cli/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,9 @@ python_tests(
"src/ai/backend/client/cli:src",
],
)

python_test_utils(
sources=[
"conftest.py",
],
)
14 changes: 14 additions & 0 deletions tests/client/cli/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pytest
from click.testing import CliRunner

from ai.backend.cli.loader import load_entry_points


@pytest.fixture(scope="module")
def runner():
return CliRunner()


@pytest.fixture(scope="module")
def cli_entrypoint():
return load_entry_points(allowlist={"ai.backend.client.cli"})
1 change: 1 addition & 0 deletions tests/client/cli/session/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python_tests(name="tests")
51 changes: 51 additions & 0 deletions tests/client/cli/session/test_lifecycle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import pytest
from aioresponses import aioresponses

from ai.backend.cli.types import ExitCode
from ai.backend.client.config import set_config


@pytest.mark.parametrize(
"test_case",
[
{
"session_id_or_name": "00000000-0000-0000-0000-000000000000",
"new_session_name": "new-name",
"expected_exit_code": ExitCode.OK,
},
{
"session_id_or_name": "mock-session-name",
"new_session_name": "new-name",
"expected_exit_code": ExitCode.OK,
},
],
ids=["Use session command by uuid", "Use session command by session name"],
)
def test_session_command(
test_case, runner, cli_entrypoint, monkeypatch, example_keypair, unused_tcp_port_factory
):
"""
Test whether the Session CLI commands work correctly when either session_id or session_name is provided as argument.
"""

api_port = unused_tcp_port_factory()
api_url = "http://127.0.0.1:{}".format(api_port)

set_config(None)
monkeypatch.setenv("BACKEND_ACCESS_KEY", example_keypair[0])
monkeypatch.setenv("BACKEND_SECRET_KEY", example_keypair[1])
monkeypatch.setenv("BACKEND_ENDPOINT", api_url)
monkeypatch.setenv("BACKEND_ENDPOINT_TYPE", "api")

with aioresponses() as mocked:
session_id_or_name = test_case["session_id_or_name"]
new_session_name = test_case["new_session_name"]

mocked.post(
f"{api_url}/session/{session_id_or_name}/rename?name={new_session_name}", status=204
)

result = runner.invoke(
cli_entrypoint, args=["session", "rename", session_id_or_name, new_session_name]
)
assert result.exit_code == test_case["expected_exit_code"]
12 changes: 0 additions & 12 deletions tests/client/cli/test_cli_commands.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,11 @@
import re

import pytest
from click.testing import CliRunner

from ai.backend.cli.loader import load_entry_points
from ai.backend.cli.types import ExitCode
from ai.backend.client.config import get_config, set_config


@pytest.fixture(scope="module")
def runner():
return CliRunner()


@pytest.fixture(scope="module")
def cli_entrypoint():
return load_entry_points(allowlist={"ai.backend.client.cli"})


@pytest.mark.parametrize("help_arg", ["-h", "--help"])
def test_print_help(runner, cli_entrypoint, help_arg):
result = runner.invoke(cli_entrypoint, [help_arg])
Expand Down
Loading