Skip to content

Commit

Permalink
Add tests and fix cascades
Browse files Browse the repository at this point in the history
  • Loading branch information
florimondmanca committed Jul 19, 2022
1 parent 3c26e97 commit 484268d
Show file tree
Hide file tree
Showing 12 changed files with 216 additions and 83 deletions.
3 changes: 3 additions & 0 deletions server/application/datasets/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@

from server.domain.common.types import ID
from server.domain.datasets.entities import DataFormat, UpdateFrequency
from server.domain.organizations.entities import LEGACY_ORGANIZATION_SIRET
from server.domain.organizations.types import Siret
from server.seedwork.application.commands import Command

from .validation import CreateDatasetValidationMixin, UpdateDatasetValidationMixin


class CreateDataset(CreateDatasetValidationMixin, Command[ID]):
organization_siret: Siret = LEGACY_ORGANIZATION_SIRET
title: str
description: str
service: str
Expand Down
3 changes: 1 addition & 2 deletions server/application/datasets/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from server.domain.datasets.entities import DataFormat, Dataset
from server.domain.datasets.exceptions import DatasetDoesNotExist
from server.domain.datasets.repositories import DatasetRepository
from server.domain.organizations.entities import LEGACY_ORGANIZATION_SIRET
from server.domain.tags.repositories import TagRepository
from server.seedwork.application.messages import MessageBus

Expand All @@ -28,7 +27,7 @@ async def create_dataset(command: CreateDataset, *, id_: ID = None) -> ID:
catalog_record_id = await catalog_record_repository.insert(
CatalogRecord(
id=catalog_record_repository.make_id(),
organization_siret=LEGACY_ORGANIZATION_SIRET,
organization_siret=command.organization_siret,
)
)
catalog_record = await catalog_record_repository.get_by_id(catalog_record_id)
Expand Down
3 changes: 1 addition & 2 deletions server/infrastructure/auth/repositories.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,11 @@ class UserModel(Base):

id: uuid.UUID = Column(UUID(as_uuid=True), primary_key=True)
organization_siret: Siret = Column(
CHAR(14), ForeignKey("organization.siret"), nullable=False
CHAR(14), ForeignKey("organization.siret", ondelete="CASCADE"), nullable=False
)
organization: "OrganizationModel" = relationship(
"OrganizationModel",
back_populates="users",
cascade="delete",
)
email = Column(String, nullable=False, unique=True, index=True)
password_hash = Column(String, nullable=False)
Expand Down
17 changes: 11 additions & 6 deletions server/infrastructure/catalog_records/repositories.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,23 @@ class CatalogRecordModel(Base):

id: ID = Column(UUID(as_uuid=True), primary_key=True)
organization_siret: Siret = Column(
CHAR(14), ForeignKey("catalog.organization_siret"), nullable=False
CHAR(14),
ForeignKey("catalog.organization_siret", ondelete="CASCADE"),
nullable=False,
)
created_at: dt.datetime = Column(
DateTime(timezone=True), server_default=func.clock_timestamp(), nullable=False
)

catalog: "CatalogModel" = relationship(
"CatalogModel", back_populates="catalog_records"
"CatalogModel",
back_populates="catalog_records",
)
dataset_id: ID = Column(UUID(as_uuid=True), ForeignKey("dataset.id"))

dataset: "DatasetModel" = relationship(
"DatasetModel",
back_populates="catalog_record",
)
created_at: dt.datetime = Column(
DateTime(timezone=True), server_default=func.clock_timestamp(), nullable=False
cascade="delete",
)


Expand Down
17 changes: 9 additions & 8 deletions server/infrastructure/catalogs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
from typing import TYPE_CHECKING, List

from sqlalchemy import CHAR, Column, DateTime, ForeignKey, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship

from server.domain.common.types import ID
from server.domain.organizations.types import Siret

from ..database import Base
Expand All @@ -18,20 +16,23 @@
class CatalogModel(Base):
__tablename__ = "catalog"

id: ID = Column(UUID(as_uuid=True), primary_key=True)
created_at: dt.datetime = Column(
DateTime(timezone=True), server_default=func.clock_timestamp(), nullable=False
)
organization_siret: Siret = Column(
CHAR(14), ForeignKey("organization.siret"), nullable=False
CHAR(14),
ForeignKey("organization.siret", ondelete="CASCADE"),
primary_key=True,
)
organization: "OrganizationModel" = relationship(
"OrganizationModel",
back_populates="catalog",
cascade="delete",
)
created_at: dt.datetime = Column(
DateTime(timezone=True),
server_default=func.clock_timestamp(),
nullable=False,
)

catalog_records: List["CatalogRecordModel"] = relationship(
"CatalogRecordModel",
back_populates="catalog",
cascade="delete",
)
6 changes: 5 additions & 1 deletion server/infrastructure/datasets/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,15 @@ class DatasetModel(Base):

id: uuid.UUID = Column(UUID(as_uuid=True), primary_key=True)

catalog_record_id: uuid.UUID = Column(
UUID(as_uuid=True),
ForeignKey("catalog_record.id", ondelete="CASCADE"),
nullable=False,
)
catalog_record: "CatalogRecordModel" = relationship(
"CatalogRecordModel",
back_populates="dataset",
cascade="delete",
lazy="joined",
uselist=False,
)

Expand Down
4 changes: 3 additions & 1 deletion server/infrastructure/datasets/repositories.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import contains_eager, selectinload

from server.domain.common.pagination import Page
from server.domain.common.types import ID
Expand Down Expand Up @@ -48,8 +48,10 @@ async def _maybe_get_by_id(
) -> Optional[DatasetModel]:
stmt = (
select(DatasetModel)
.join(DatasetModel.catalog_record)
.where(DatasetModel.id == id)
.options(
contains_eager(DatasetModel.catalog_record),
selectinload(DatasetModel.formats),
selectinload(DatasetModel.tags),
)
Expand Down
1 change: 1 addition & 0 deletions server/infrastructure/organizations/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class OrganizationModel(Base):
catalog: "CatalogModel" = relationship(
"CatalogModel",
back_populates="organization",
cascade="delete",
uselist=False,
)
users: List["UserModel"] = relationship(
Expand Down
105 changes: 105 additions & 0 deletions server/migrations/versions/4e40358ad25c_add_catalog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""add-catalog
Revision ID: 4e40358ad25c
Revises: d9ea6ea6708f
Create Date: 2022-07-19 15:03:34.512545
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "4e40358ad25c"
down_revision = "d9ea6ea6708f"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"catalog",
sa.Column("organization_siret", sa.CHAR(length=14), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("clock_timestamp()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["organization_siret"], ["organization.siret"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("organization_siret"),
)

# Create initial organization.
op.execute(
"INSERT INTO organization (siret, name) "
"VALUES ('00000000000000', 'Organisation par défaut');"
)

# Create its catalog.
op.execute("INSERT INTO catalog (organization_siret) VALUES ('00000000000000');")

# Add all users to the initial organization.
op.add_column("user", sa.Column("organization_siret", sa.CHAR(14)))
op.execute("UPDATE \"user\" SET organization_siret = '00000000000000';")
op.alter_column("user", "organization_siret", nullable=False)
op.create_foreign_key(
None,
"user",
"organization",
["organization_siret"],
["siret"],
ondelete="CASCADE",
)

# Add all catalog records to the initial catalog.
op.add_column("catalog_record", sa.Column("organization_siret", sa.CHAR(14)))
op.execute("UPDATE catalog_record SET organization_siret = '00000000000000';")
op.alter_column("catalog_record", "organization_siret", nullable=False)
op.create_foreign_key(
None,
"catalog_record",
"catalog",
["organization_siret"],
["organization_siret"],
ondelete="CASCADE",
)

# Swap the 1-1 relationship between CatalogRecord and Dataset, as the CASCADE
# must be on the Dataset side: when a catalog record is removed, the dataset must
# be dropped too.
op.drop_constraint(
"catalog_record_dataset_id_fkey", "catalog_record", type_="foreignkey"
)
op.add_column(
"dataset",
sa.Column(
"catalog_record_id",
postgresql.UUID(as_uuid=True),
),
)
op.execute(
"UPDATE dataset SET catalog_record_id = catalog_record.id "
"FROM dataset AS d JOIN catalog_record ON d.id = catalog_record.dataset_id;"
)
op.alter_column("dataset", "catalog_record_id", nullable=False)
op.drop_column("catalog_record", "dataset_id")
op.create_foreign_key(
None,
"dataset",
"catalog_record",
["catalog_record_id"],
["id"],
ondelete="CASCADE",
)


def downgrade():
op.drop_constraint(None, "user", type_="foreignkey")
op.drop_column("user", "organization_siret")
op.drop_constraint(None, "catalog_record", type_="foreignkey")
op.drop_column("catalog_record", "organization_siret")
op.drop_table("catalog")
63 changes: 0 additions & 63 deletions server/migrations/versions/cfb6eef87415_add_catalog.py

This file was deleted.

2 changes: 2 additions & 0 deletions tests/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from server.domain.common import datetime as dtutil
from server.domain.datasets.entities import DataFormat
from server.domain.licenses.entities import BUILTIN_LICENSE_SUGGESTIONS
from server.domain.organizations.entities import LEGACY_ORGANIZATION_SIRET

T = TypeVar("T", bound=BaseModel)

Expand Down Expand Up @@ -48,6 +49,7 @@ class CreateTagFactory(Factory[CreateTag]):
class CreateDatasetFactory(Factory[CreateDataset]):
__model__ = CreateDataset

organization_siret = Use(lambda: LEGACY_ORGANIZATION_SIRET)
title = Use(fake.sentence)
description = Use(fake.text)
service = Use(fake.company)
Expand Down
Loading

0 comments on commit 484268d

Please sign in to comment.