diff --git a/server/api/dataformats/routes.py b/server/api/dataformats/routes.py index 1cfb689f..d8a535e7 100644 --- a/server/api/dataformats/routes.py +++ b/server/api/dataformats/routes.py @@ -1,9 +1,12 @@ import logging from typing import List -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, HTTPException -from server.application.dataformats.queries import GetAllDataFormat +from server.api.dataformats.schemas import DataFormatCreate +from server.application.dataformats.commands import CreateDataFormat +from server.application.dataformats.exceptions import CannotCreateDataFormat +from server.application.dataformats.queries import GetAllDataFormat, GetDataFormatById from server.application.dataformats.views import DataFormatView from server.config.di import resolve from server.seedwork.application.messages import MessageBus @@ -24,3 +27,25 @@ async def list_dataformat() -> List[DataFormatView]: bus = resolve(MessageBus) return await bus.execute(GetAllDataFormat()) + + +@router.post( + "/", + dependencies=[Depends(IsAuthenticated())], + response_model=DataFormatView, + status_code=201, +) +async def create_dataformat(data: DataFormatCreate) -> DataFormatView: + bus = resolve(MessageBus) + print(data) + + command = CreateDataFormat(value=data.value) + + try: + id = await bus.execute(command) + query = GetDataFormatById(id=id) + return await bus.execute(query) + + except CannotCreateDataFormat as exc: + logger.exception(exc) + raise HTTPException(403, detail="Permission denied") diff --git a/server/api/dataformats/schemas.py b/server/api/dataformats/schemas.py new file mode 100644 index 00000000..2833b33f --- /dev/null +++ b/server/api/dataformats/schemas.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel + +from server.application.dataformats.validation import CreateDataFormatValidationMixin + + +class DataFormatCreate(CreateDataFormatValidationMixin, BaseModel): + value: str diff --git a/server/application/dataformats/commands.py b/server/application/dataformats/commands.py new file mode 100644 index 00000000..5034314c --- /dev/null +++ b/server/application/dataformats/commands.py @@ -0,0 +1,6 @@ +from server.application.dataformats.validation import CreateDataFormatValidationMixin +from server.seedwork.application.commands import Command + + +class CreateDataFormat(CreateDataFormatValidationMixin, Command[int]): + value: str diff --git a/server/application/dataformats/exceptions.py b/server/application/dataformats/exceptions.py new file mode 100644 index 00000000..9aa3c71f --- /dev/null +++ b/server/application/dataformats/exceptions.py @@ -0,0 +1,2 @@ +class CannotCreateDataFormat(Exception): + pass diff --git a/server/application/dataformats/handlers.py b/server/application/dataformats/handlers.py index fd27a4b0..875b569d 100644 --- a/server/application/dataformats/handlers.py +++ b/server/application/dataformats/handlers.py @@ -1,8 +1,10 @@ -from typing import List +from typing import List, Optional -from server.application.dataformats.queries import GetAllDataFormat +from server.application.dataformats.commands import CreateDataFormat +from server.application.dataformats.queries import GetAllDataFormat, GetDataFormatById from server.application.dataformats.views import DataFormatView from server.config.di import resolve +from server.domain.dataformats.entities import DataFormat from server.domain.dataformats.repositories import DataFormatRepository @@ -10,3 +12,17 @@ async def get_all_dataformats(query: GetAllDataFormat) -> List[DataFormatView]: repository = resolve(DataFormatRepository) dataformats = await repository.get_all() return [DataFormatView(**dataformat.dict()) for dataformat in dataformats] + + +async def create_dataformat(command: CreateDataFormat) -> Optional[int]: + repository = resolve(DataFormatRepository) + return await repository.insert(DataFormat(name=command.value)) + + +async def get_by_id(query: GetDataFormatById) -> Optional[DataFormatView]: + repository = resolve(DataFormatRepository) + dataformat = await repository.get_by_id(id=query.id) + + if dataformat is not None: + return DataFormatView(**dataformat.dict()) + return None diff --git a/server/application/dataformats/queries.py b/server/application/dataformats/queries.py index db7aeba6..08ca2c0d 100644 --- a/server/application/dataformats/queries.py +++ b/server/application/dataformats/queries.py @@ -7,3 +7,7 @@ class GetAllDataFormat(Query[List[DataFormatView]]): pass + + +class GetDataFormatById(Query[DataFormatView]): + id: int diff --git a/server/application/dataformats/validation.py b/server/application/dataformats/validation.py new file mode 100644 index 00000000..aa54c5ee --- /dev/null +++ b/server/application/dataformats/validation.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel, validator + + +class CreateDataFormatValidationMixin(BaseModel): + @validator("value", check_fields=False) + def check_value_at_least_one(cls, value: str) -> str: + if not value: + raise ValueError("dataformat must have a value") + return value diff --git a/server/domain/dataformats/repositories.py b/server/domain/dataformats/repositories.py index 1dd772ca..360d8be5 100644 --- a/server/domain/dataformats/repositories.py +++ b/server/domain/dataformats/repositories.py @@ -14,3 +14,6 @@ async def get_all(self, ids: List[Optional[int]] = None) -> List[DataFormat]: async def get_by_name(self, name: str) -> Optional[DataFormat]: raise NotImplementedError # pragma: no cover + + async def get_by_id(self, id: int) -> Optional[DataFormat]: + raise NotImplementedError # pragma: no cover diff --git a/server/infrastructure/dataformats/module.py b/server/infrastructure/dataformats/module.py index 141421c9..744150fe 100644 --- a/server/infrastructure/dataformats/module.py +++ b/server/infrastructure/dataformats/module.py @@ -1,7 +1,16 @@ -from server.application.dataformats.handlers import get_all_dataformats -from server.application.dataformats.queries import GetAllDataFormat +from server.application.dataformats.commands import CreateDataFormat +from server.application.dataformats.handlers import ( + create_dataformat, + get_all_dataformats, + get_by_id, +) +from server.application.dataformats.queries import GetAllDataFormat, GetDataFormatById from server.seedwork.application.modules import Module class DataFormatModule(Module): - query_handlers = {GetAllDataFormat: get_all_dataformats} + query_handlers = { + GetAllDataFormat: get_all_dataformats, + GetDataFormatById: get_by_id, + } + command_handlers = {CreateDataFormat: create_dataformat} diff --git a/server/infrastructure/dataformats/repositories.py b/server/infrastructure/dataformats/repositories.py index 8671399b..46e9c1e8 100644 --- a/server/infrastructure/dataformats/repositories.py +++ b/server/infrastructure/dataformats/repositories.py @@ -43,3 +43,12 @@ async def get_by_name(self, name: str) -> Optional[DataFormat]: if instance is None: return None return make_entity(instance) + + async def get_by_id(self, id: int) -> Optional[DataFormat]: + async with self._db.session() as session: + stmt = select(DataFormatModel).where(DataFormatModel.id == id) + result = await session.execute(stmt) + instance = result.unique().scalar_one_or_none() + if instance is None: + return None + return make_entity(instance) diff --git a/tests/api/test_dataformats.py b/tests/api/test_dataformats.py index e4a358d0..bf0abe2d 100644 --- a/tests/api/test_dataformats.py +++ b/tests/api/test_dataformats.py @@ -30,3 +30,17 @@ async def test_dataformat_list_not_authenticated( ) -> None: response = await client.get("/dataformats/") assert response.status_code == 401 + + +@pytest.mark.asyncio +async def test_create_dataformat( + client: httpx.AsyncClient, temp_user: TestPasswordUser +) -> None: + + payload = {"value": "toto"} + + response = await client.post("/dataformats/", json=payload, auth=temp_user.auth) + assert response.status_code == 201 + + data = response.json() + assert data["name"] == "toto"