Skip to content

Commit

Permalink
Check user belongs to dataset organization before creation (#437)
Browse files Browse the repository at this point in the history
* Check user belongs to dataset organization before creation

* Prevent ADMIN from creating dataset outside own org, refactor tests

* Refine

* Check user belongs to organization before update (#455)

* Check user belongs to organization before update

* Test ADMIN cannot update in other org either
  • Loading branch information
florimondmanca authored Sep 30, 2022
1 parent 0425084 commit fdc0df4
Show file tree
Hide file tree
Showing 19 changed files with 384 additions and 137 deletions.
29 changes: 23 additions & 6 deletions server/api/datasets/routes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

from fastapi import APIRouter, Depends
from fastapi.exceptions import HTTPException
from starlette.responses import Response
Expand All @@ -7,21 +9,28 @@
DeleteDataset,
UpdateDataset,
)
from server.application.datasets.exceptions import (
CannotCreateDataset,
CannotUpdateDataset,
)
from server.application.datasets.queries import GetAllDatasets, GetDatasetByID
from server.application.datasets.views import DatasetView
from server.config.di import resolve
from server.domain.auth.entities import UserRole
from server.domain.catalogs.exceptions import CatalogDoesNotExist
from server.domain.common.pagination import Page, Pagination
from server.domain.common.types import ID
from server.domain.datasets.exceptions import DatasetDoesNotExist
from server.domain.datasets.specifications import DatasetSpec
from server.domain.organizations.exceptions import OrganizationDoesNotExist
from server.seedwork.application.messages import MessageBus

from ..auth.permissions import HasRole, IsAuthenticated
from ..types import APIRequest
from . import filters
from .schemas import DatasetCreate, DatasetListParams, DatasetUpdate

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/datasets", tags=["datasets"])

router.include_router(filters.router)
Expand Down Expand Up @@ -78,15 +87,18 @@ async def get_dataset_by_id(id: ID) -> DatasetView:
response_model=DatasetView,
status_code=201,
)
async def create_dataset(data: DatasetCreate) -> DatasetView:
async def create_dataset(data: DatasetCreate, request: "APIRequest") -> DatasetView:
bus = resolve(MessageBus)

command = CreateDataset(**data.dict())
command = CreateDataset(account=request.user.account, **data.dict())

try:
id = await bus.execute(command)
except OrganizationDoesNotExist as exc:
except CatalogDoesNotExist as exc:
raise HTTPException(400, detail=str(exc))
except CannotCreateDataset as exc:
logger.exception(exc)
raise HTTPException(403, detail="Permission denied")

query = GetDatasetByID(id=id)
return await bus.execute(query)
Expand All @@ -98,15 +110,20 @@ async def create_dataset(data: DatasetCreate) -> DatasetView:
response_model=DatasetView,
responses={404: {}},
)
async def update_dataset(id: ID, data: DatasetUpdate) -> DatasetView:
async def update_dataset(
id: ID, data: DatasetUpdate, request: "APIRequest"
) -> DatasetView:
bus = resolve(MessageBus)

command = UpdateDataset(id=id, **data.dict())
command = UpdateDataset(account=request.user.account, id=id, **data.dict())

try:
await bus.execute(command)
except DatasetDoesNotExist:
raise HTTPException(404)
except CannotUpdateDataset as exc:
logger.exception(exc)
raise HTTPException(403, detail="Permission denied")

query = GetDatasetByID(id=id)
return await bus.execute(query)
Expand Down
9 changes: 7 additions & 2 deletions server/application/datasets/commands.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import datetime as dt
from typing import List, Optional
from typing import List, Optional, Union

from pydantic import EmailStr, Field

from server.domain.auth.entities import Account
from server.domain.catalogs.entities import ExtraFieldValue
from server.domain.common.types import ID
from server.domain.common.types import ID, Skip
from server.domain.datasets.entities import DataFormat, UpdateFrequency
from server.domain.organizations.entities import LEGACY_ORGANIZATION
from server.domain.organizations.types import Siret
Expand All @@ -14,6 +15,8 @@


class CreateDataset(CreateDatasetValidationMixin, Command[ID]):
account: Union[Account, Skip]

organization_siret: Siret = LEGACY_ORGANIZATION.siret
title: str
description: str
Expand All @@ -32,6 +35,8 @@ class CreateDataset(CreateDatasetValidationMixin, Command[ID]):


class UpdateDataset(UpdateDatasetValidationMixin, Command[None]):
account: Union[Account, Skip]

id: ID
title: str
description: str
Expand Down
6 changes: 6 additions & 0 deletions server/application/datasets/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class CannotCreateDataset(Exception):
pass


class CannotUpdateDataset(Exception):
pass
34 changes: 23 additions & 11 deletions server/application/datasets/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,48 @@
from server.config.di import resolve
from server.domain.catalog_records.entities import CatalogRecord
from server.domain.catalog_records.repositories import CatalogRecordRepository
from server.domain.catalogs.exceptions import CatalogDoesNotExist
from server.domain.catalogs.repositories import CatalogRepository
from server.domain.common.pagination import Pagination
from server.domain.common.types import ID
from server.domain.common.types import ID, Skip
from server.domain.datasets.entities import DataFormat, Dataset
from server.domain.datasets.exceptions import DatasetDoesNotExist
from server.domain.datasets.repositories import DatasetRepository
from server.domain.organizations.exceptions import OrganizationDoesNotExist
from server.domain.organizations.repositories import OrganizationRepository
from server.domain.tags.repositories import TagRepository
from server.seedwork.application.messages import MessageBus

from .commands import CreateDataset, DeleteDataset, UpdateDataset
from .exceptions import CannotCreateDataset, CannotUpdateDataset
from .queries import GetAllDatasets, GetDatasetByID, GetDatasetFilters
from .specifications import can_create_dataset, can_update_dataset
from .views import DatasetFiltersView, DatasetView


async def create_dataset(command: CreateDataset, *, id_: ID = None) -> ID:
repository = resolve(DatasetRepository)
organization_repository = resolve(OrganizationRepository)
catalog_repository = resolve(CatalogRepository)
catalog_record_repository = resolve(CatalogRecordRepository)
tag_repository = resolve(TagRepository)

if id_ is None:
id_ = repository.make_id()

organization = await organization_repository.get_by_siret(
siret=command.organization_siret
)
catalog = await catalog_repository.get_by_siret(siret=command.organization_siret)

if catalog is None:
raise CatalogDoesNotExist(command.organization_siret)

if organization is None:
raise OrganizationDoesNotExist(command.organization_siret)
if not isinstance(command.account, Skip) and not can_create_dataset(
catalog, command.account
):
raise CannotCreateDataset(
f"{command.account.organization_siret=}, {catalog.organization.siret=}"
)

catalog_record_id = await catalog_record_repository.insert(
CatalogRecord(
id=catalog_record_repository.make_id(),
organization=organization,
organization=catalog.organization,
)
)
catalog_record = await catalog_record_repository.get_by_id(catalog_record_id)
Expand All @@ -65,9 +72,14 @@ async def update_dataset(command: UpdateDataset) -> None:
if dataset is None:
raise DatasetDoesNotExist(pk)

if not isinstance(command.account, Skip) and not can_update_dataset(
dataset, command.account
):
raise CannotUpdateDataset(f"{command.account=}, {dataset=}")

tags = await tag_repository.get_all(ids=command.tag_ids)
dataset.update(
**command.dict(exclude={"id", "tag_ids", "extra_field_values"}),
**command.dict(exclude={"account", "id", "tag_ids", "extra_field_values"}),
tags=tags,
extra_field_values=command.extra_field_values,
)
Expand Down
11 changes: 11 additions & 0 deletions server/application/datasets/specifications.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from server.domain.auth.entities import Account
from server.domain.catalogs.entities import Catalog
from server.domain.datasets.entities import Dataset


def can_create_dataset(catalog: Catalog, account: Account) -> bool:
return catalog.organization.siret == account.organization_siret


def can_update_dataset(dataset: Dataset, account: Account) -> bool:
return dataset.catalog_record.organization.siret == account.organization_siret
8 changes: 8 additions & 0 deletions server/domain/common/types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import uuid
from typing import NewType

from pydantic import BaseModel

ID = NewType("ID", uuid.UUID)


def id_factory() -> ID:
return ID(uuid.uuid4())


class Skip(BaseModel):
"""
A marker class for when an operation should be skipped.
"""
16 changes: 13 additions & 3 deletions tests/api/test_catalogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,22 @@
from server.domain.organizations.types import Siret
from server.seedwork.application.messages import MessageBus

from ..factories import CreateDatasetFactory, CreateOrganizationFactory, fake
from ..helpers import TestPasswordUser, api_key_auth
from ..factories import (
CreateDatasetFactory,
CreateOrganizationFactory,
CreatePasswordUserFactory,
fake,
)
from ..helpers import TestPasswordUser, api_key_auth, create_test_password_user


@pytest.mark.asyncio
async def test_catalog_create(client: httpx.AsyncClient) -> None:
bus = resolve(MessageBus)
siret = await bus.execute(CreateOrganizationFactory.build(name="Org 1"))
user = await create_test_password_user(
CreatePasswordUserFactory.build(organization_siret=siret)
)

response = await client.post(
"/catalogs/", json={"organization_siret": str(siret)}, auth=api_key_auth
Expand All @@ -34,7 +42,9 @@ async def test_catalog_create(client: httpx.AsyncClient) -> None:
catalog = await bus.execute(GetCatalogBySiret(siret=siret))
assert catalog.organization.siret == siret

dataset_id = await bus.execute(CreateDatasetFactory.build(organization_siret=siret))
dataset_id = await bus.execute(
CreateDatasetFactory.build(account=user.account, organization_siret=siret)
)
dataset = await bus.execute(GetDatasetByID(id=dataset_id))
assert dataset.catalog_record.organization.siret == siret

Expand Down
Loading

0 comments on commit fdc0df4

Please sign in to comment.