Skip to content

Commit

Permalink
squashme: fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
olevski committed Feb 4, 2025
1 parent db72c3a commit c4608a3
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 15 deletions.
6 changes: 3 additions & 3 deletions components/renku_data_services/data_connectors/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from cryptography.hazmat.primitives.asymmetric import rsa
from sqlalchemy import Select, delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import joinedload, selectinload
from ulid import ULID

from renku_data_services import base_models, errors
Expand Down Expand Up @@ -57,7 +57,7 @@ async def get_data_connectors(
.options(
joinedload(schemas.DataConnectorORM.slug)
.joinedload(ns_schemas.EntitySlugORM.project)
.joinedload(ProjectORM.slug)
.selectinload(ProjectORM.slug)
)
)
if namespace:
Expand Down Expand Up @@ -95,7 +95,7 @@ async def get_data_connector(
.options(
joinedload(schemas.DataConnectorORM.slug)
.joinedload(ns_schemas.EntitySlugORM.project)
.joinedload(ProjectORM.slug)
.selectinload(ProjectORM.slug)
)
)
data_connector = result.one_or_none()
Expand Down
4 changes: 2 additions & 2 deletions components/renku_data_services/namespace/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@ class EntitySlugORM(BaseORM):
"""Entity slugs.
Note that valid combinations here are:
- namespace_id
- namespace_id + project_id
- namespace_id + project_id + data_connector_id
- namespace_id + data_connector_id
Expand All @@ -234,14 +233,15 @@ class EntitySlugORM(BaseORM):
unique=True,
postgresql_nulls_not_distinct=True,
),
# TODO: Add the constraint that at least 1 of project and data_connector has to be set
)

id: Mapped[int] = mapped_column(Integer, Identity(always=True), primary_key=True, init=False)
slug: Mapped[str] = mapped_column(String(99), index=True, nullable=False)
project_id: Mapped[ULID | None] = mapped_column(
ForeignKey(ProjectORM.id, ondelete="CASCADE", name="entity_slugs_project_id_fk"), index=True, nullable=True
)
project: Mapped[ProjectORM | None] = relationship(init=False, repr=False, back_populates="slug", lazy="joined")
project: Mapped[ProjectORM | None] = relationship(init=False, repr=False, back_populates="slug", lazy="selectin")
data_connector_id: Mapped[ULID | None] = mapped_column(
ForeignKey(DataConnectorORM.id, ondelete="CASCADE", name="entity_slugs_data_connector_id_fk"),
index=True,
Expand Down
20 changes: 11 additions & 9 deletions components/renku_data_services/project/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ async def get_projects(
stmt = select(schemas.ProjectORM)
stmt = stmt.where(schemas.ProjectORM.id.in_(project_ids))
if namespace:
stmt = _filter_by_namespace_slug(stmt, namespace)
stmt = _filter_projects_by_namespace_slug(stmt, namespace)

stmt = stmt.order_by(coalesce(schemas.ProjectORM.updated_at, schemas.ProjectORM.creation_date).desc())

Expand All @@ -87,7 +87,7 @@ async def get_projects(
select(func.count()).select_from(schemas.ProjectORM).where(schemas.ProjectORM.id.in_(project_ids))
)
if namespace:
stmt_count = _filter_by_namespace_slug(stmt_count, namespace)
stmt_count = _filter_projects_by_namespace_slug(stmt_count, namespace)
results = await session.scalars(stmt), await session.scalar(stmt_count)
projects_orm = results[0].all()
total_elements = results[1] or 0
Expand Down Expand Up @@ -153,7 +153,7 @@ async def get_project_by_namespace_slug(
"""Get one project from the database."""
async with self.session_maker() as session:
stmt = select(schemas.ProjectORM)
stmt = _filter_by_namespace_slug(stmt, namespace)
stmt = _filter_projects_by_namespace_slug(stmt, namespace)
stmt = stmt.where(schemas.ProjectORM.slug.has(ns_schemas.EntitySlugORM.slug == slug.value))
if with_documentation:
stmt = stmt.options(undefer(schemas.ProjectORM.documentation))
Expand Down Expand Up @@ -257,6 +257,7 @@ async def insert_project(
.where(ns_schemas.EntitySlugORM.namespace_id == ns.id)
.where(ns_schemas.EntitySlugORM.slug == slug)
.where(ns_schemas.EntitySlugORM.data_connector_id.is_(None))
.where(ns_schemas.EntitySlugORM.project_id.is_not(None)),
)
if existing_slug is not None:
raise errors.ConflictError(message=f"An entity with the slug '{ns.slug}/{slug}' already exists.")
Expand Down Expand Up @@ -284,7 +285,6 @@ async def insert_project(
session.add(project_slug)
await session.flush()
await session.refresh(project_orm)

return project_orm.dump()

@with_db_transaction
Expand Down Expand Up @@ -458,12 +458,14 @@ async def get_project_permissions(self, user: base_models.APIUser, project_id: U
_T = TypeVar("_T")


def _filter_by_namespace_slug(statement: Select[tuple[_T]], namespace: str) -> Select[tuple[_T]]:
def _filter_projects_by_namespace_slug(statement: Select[tuple[_T]], namespace: str) -> Select[tuple[_T]]:
"""Filters a select query on projects to a given namespace."""
return (
statement.where(ns_schemas.NamespaceORM.slug == namespace.lower())
.where(ns_schemas.EntitySlugORM.namespace_id == ns_schemas.NamespaceORM.id)
.where(schemas.ProjectORM.id == ns_schemas.EntitySlugORM.project_id)
return statement.where(
schemas.ProjectORM.slug.has(
ns_schemas.EntitySlugORM.namespace.has(
ns_schemas.NamespaceORM.slug == namespace.lower(),
)
)
)


Expand Down
7 changes: 6 additions & 1 deletion components/renku_data_services/project/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@ class ProjectORM(BaseORM):
# NOTE: The project slugs table has a foreign key from the projects table, but there is a stored procedure
# triggered by the deletion of slugs to remove the project used by the slug. See migration 89aa4573cfa9.
slug: Mapped["EntitySlugORM"] = relationship(
lazy="joined", init=False, repr=False, viewonly=True, back_populates="project"
# NOTE: joined lazy relationship can cause duplicates to be returned
lazy="selectin",
init=False,
repr=False,
viewonly=True,
back_populates="project",
)
repositories: Mapped[list["ProjectRepositoryORM"]] = relationship(
back_populates="project",
Expand Down
71 changes: 71 additions & 0 deletions test/bases/renku_data_services/data_api/test_namespaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,74 @@ async def test_entity_slug_uniqueness(sanic_client, user_headers) -> None:
}
_, response = await sanic_client.post("/api/data/data_connectors", headers=user_headers, json=payload)
assert response.status_code == 201, response.text


@pytest.mark.asyncio
async def test_creating_dc_in_project(sanic_client, user_headers) -> None:
# Create a group i.e. /test1
payload = {
"name": "test1",
"slug": "test1",
"description": "Group 1 Description",
}
_, response = await sanic_client.post("/api/data/groups", headers=user_headers, json=payload)
assert response.status_code == 201, response.text

# Create a project in the group /test1/prj1
payload = {
"name": "prj1",
"namespace": "test1",
"slug": "prj1",
}
_, response = await sanic_client.post("/api/data/projects", headers=user_headers, json=payload)
assert response.status_code == 201, response.text
project_id = response.json["id"]

# Ensure there is only one project
_, response = await sanic_client.get("/api/data/projects", headers=user_headers)
assert response.status_code == 200, response.text
assert len(response.json) == 1

# Create a data connector in the project /test1/proj1/dc1
payload = {
"name": "dc1",
"namespace": "test1/prj1",
"slug": "dc1",
"storage": {
"configuration": {"type": "s3", "endpoint": "http://s3.aws.com"},
"source_path": "giab",
"target_path": "giab",
},
}
_, response = await sanic_client.post("/api/data/data_connectors", headers=user_headers, json=payload)
assert response.status_code == 201, response.text
dc_id = response.json["id"]

# Ensure there is only one project
_, response = await sanic_client.get("/api/data/projects", headers=user_headers)
assert response.status_code == 200, response.text
assert len(response.json) == 1

# Ensure that you can list the data connector
_, response = await sanic_client.get(f"/api/data/data_connectors/{dc_id}", headers=user_headers)
assert response.status_code == 200, response.text

# Link the data connector to the project
payload = {"project_id": project_id}
_, response = await sanic_client.post(
f"/api/data/data_connectors/{dc_id}/project_links", headers=user_headers, json=payload
)
assert response.status_code == 201, response.text

# Ensure that you can see the data connector link
_, response = await sanic_client.get(f"/api/data/data_connectors/{dc_id}/project_links", headers=user_headers)
assert response.status_code == 200, response.text
assert len(response.json) == 1
dc_link = response.json[0]
assert dc_link["project_id"] == project_id
assert dc_link["data_connector_id"] == dc_id

# Ensure that you can list data connectors
_, response = await sanic_client.get("/api/data/data_connectors", headers=user_headers)
assert response.status_code == 200, response.text
assert len(response.json) == 1

0 comments on commit c4608a3

Please sign in to comment.