Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check user belongs to organization before update #455

Merged
merged 2 commits into from
Sep 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions server/api/datasets/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
DeleteDataset,
UpdateDataset,
)
from server.application.datasets.exceptions import CannotCreateDataset
from server.application.datasets.exceptions import (
CannotCreateDataset,
CannotUpdateDataset,
)
from server.application.datasets.queries import GetAllDatasets, GetDatasetByID
from server.application.datasets.views import DatasetView
from server.config.di import resolve
Expand Down Expand Up @@ -107,15 +110,20 @@ async def create_dataset(data: DatasetCreate, request: "APIRequest") -> DatasetV
response_model=DatasetView,
responses={404: {}},
)
async def update_dataset(id: ID, data: DatasetUpdate) -> DatasetView:
async def update_dataset(
id: ID, data: DatasetUpdate, request: "APIRequest"
) -> DatasetView:
bus = resolve(MessageBus)

command = UpdateDataset(id=id, **data.dict())
command = UpdateDataset(account=request.user.account, id=id, **data.dict())

try:
await bus.execute(command)
except DatasetDoesNotExist:
raise HTTPException(404)
except CannotUpdateDataset as exc:
logger.exception(exc)
raise HTTPException(403, detail="Permission denied")

query = GetDatasetByID(id=id)
return await bus.execute(query)
Expand Down
2 changes: 2 additions & 0 deletions server/application/datasets/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class CreateDataset(CreateDatasetValidationMixin, Command[ID]):


class UpdateDataset(UpdateDatasetValidationMixin, Command[None]):
account: Union[Account, Skip]

id: ID
title: str
description: str
Expand Down
4 changes: 4 additions & 0 deletions server/application/datasets/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
class CannotCreateDataset(Exception):
pass


class CannotUpdateDataset(Exception):
pass
11 changes: 8 additions & 3 deletions server/application/datasets/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from server.seedwork.application.messages import MessageBus

from .commands import CreateDataset, DeleteDataset, UpdateDataset
from .exceptions import CannotCreateDataset
from .exceptions import CannotCreateDataset, CannotUpdateDataset
from .queries import GetAllDatasets, GetDatasetByID, GetDatasetFilters
from .specifications import can_create_dataset
from .specifications import can_create_dataset, can_update_dataset
from .views import DatasetFiltersView, DatasetView


Expand Down Expand Up @@ -72,9 +72,14 @@ async def update_dataset(command: UpdateDataset) -> None:
if dataset is None:
raise DatasetDoesNotExist(pk)

if not isinstance(command.account, Skip) and not can_update_dataset(
dataset, command.account
):
raise CannotUpdateDataset(f"{command.account=}, {dataset=}")

tags = await tag_repository.get_all(ids=command.tag_ids)
dataset.update(
**command.dict(exclude={"id", "tag_ids", "extra_field_values"}),
**command.dict(exclude={"account", "id", "tag_ids", "extra_field_values"}),
tags=tags,
extra_field_values=command.extra_field_values,
)
Expand Down
5 changes: 5 additions & 0 deletions server/application/datasets/specifications.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from server.domain.auth.entities import Account
from server.domain.catalogs.entities import Catalog
from server.domain.datasets.entities import Dataset


def can_create_dataset(catalog: Catalog, account: Account) -> bool:
return catalog.organization.siret == account.organization_siret


def can_update_dataset(dataset: Dataset, account: Account) -> bool:
return dataset.catalog_record.organization.siret == account.organization_siret
95 changes: 67 additions & 28 deletions tests/api/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from server.application.tags.queries import GetTagByID
from server.config.di import resolve
from server.domain.catalogs.entities import ExtraFieldValue, TextExtraField
from server.domain.common.types import ID, id_factory
from server.domain.common.types import ID, Skip, id_factory
from server.domain.datasets.entities import DataFormat, UpdateFrequency
from server.domain.datasets.exceptions import DatasetDoesNotExist
from server.domain.organizations.entities import LEGACY_ORGANIZATION
Expand All @@ -26,7 +26,7 @@
CreateDatasetPayloadFactory,
CreateOrganizationFactory,
CreatePasswordUserFactory,
UpdateDatasetFactory,
UpdateDatasetPayloadFactory,
fake,
)
from ..helpers import TestPasswordUser, create_test_password_user, to_payload
Expand Down Expand Up @@ -257,6 +257,51 @@ async def test_update_not_authenticated(self, client: httpx.AsyncClient) -> None
response = await client.put(f"/datasets/{pk}/", json={})
assert response.status_code == 401

async def test_update_in_other_org_denied(
self, client: httpx.AsyncClient, temp_user: TestPasswordUser
) -> None:
bus = resolve(MessageBus)

other_org_siret = await bus.execute(CreateOrganizationFactory.build())
await bus.execute(CreateCatalog(organization_siret=other_org_siret))

command = CreateDatasetFactory.build(
organization_siret=other_org_siret, account=Skip()
)
dataset_id = await bus.execute(command)

payload = to_payload(
UpdateDatasetPayloadFactory.build_from_create_command(command)
)
response = await client.put(
f"/datasets/{dataset_id}/", json=payload, auth=temp_user.auth
)

assert response.status_code == 403

async def test_update_in_other_org_admin_denied(
self, client: httpx.AsyncClient, admin_user: TestPasswordUser
) -> None:
bus = resolve(MessageBus)

other_org_siret = await bus.execute(CreateOrganizationFactory.build())
assert admin_user.account.organization_siret != other_org_siret
await bus.execute(CreateCatalog(organization_siret=other_org_siret))

command = CreateDatasetFactory.build(
organization_siret=other_org_siret, account=Skip()
)
dataset_id = await bus.execute(command)

payload = to_payload(
UpdateDatasetPayloadFactory.build_from_create_command(command)
)
response = await client.put(
f"/datasets/{dataset_id}/", json=payload, auth=admin_user.auth
)

assert response.status_code == 403

async def test_delete_not_authenticated(self, client: httpx.AsyncClient) -> None:
pk = id_factory()
response = await client.delete(f"/datasets/{pk}/")
Expand Down Expand Up @@ -425,7 +470,7 @@ async def test_not_found(
pk = id_factory()
response = await client.put(
f"/datasets/{pk}/",
json=to_payload(UpdateDatasetFactory.build(id=pk)),
json=to_payload(UpdateDatasetPayloadFactory.build(id=pk)),
auth=temp_user.auth,
)
assert response.status_code == 404
Expand Down Expand Up @@ -480,13 +525,13 @@ async def test_fields_empty_invalid(
response = await client.put(
f"/datasets/{dataset_id}/",
json=to_payload(
UpdateDatasetFactory.build(
UpdateDatasetPayloadFactory.build_from_create_command(
command.copy(exclude={"title", "description", "service", "url"}),
factory_use_construct=True, # Skip validation
title="",
description="",
service="",
url="",
**command.dict(exclude={"title", "description", "service", "url"}),
)
),
auth=temp_user.auth,
Expand Down Expand Up @@ -523,7 +568,7 @@ async def test_update(
other_last_updated_at = fake.date_time_tz()

payload = to_payload(
UpdateDatasetFactory.build(
UpdateDatasetPayloadFactory.build(
title="Other title",
description="Other description",
service="Other service",
Expand Down Expand Up @@ -603,9 +648,9 @@ async def test_formats_add(
response = await client.put(
f"/datasets/{dataset_id}/",
json=to_payload(
UpdateDatasetFactory.build(
UpdateDatasetPayloadFactory.build_from_create_command(
command.copy(exclude={"formats"}),
formats=[DataFormat.WEBSITE, DataFormat.API, DataFormat.FILE_GIS],
**command.dict(exclude={"formats"}),
)
),
auth=temp_user.auth,
Expand All @@ -627,9 +672,9 @@ async def test_formats_remove(
response = await client.put(
f"/datasets/{dataset_id}/",
json=to_payload(
UpdateDatasetFactory.build(
UpdateDatasetPayloadFactory.build_from_create_command(
command.copy(exclude={"formats"}),
formats=[DataFormat.WEBSITE],
**command.dict(exclude={"formats"}),
)
),
auth=temp_user.auth,
Expand All @@ -654,9 +699,9 @@ async def test_tags_add(
response = await client.put(
f"/datasets/{dataset_id}/",
json=to_payload(
UpdateDatasetFactory.build(
UpdateDatasetPayloadFactory.build_from_create_command(
command.copy(exclude={"tag_ids"}),
tag_ids=[str(tag_architecture_id)],
**command.dict(exclude={"tag_ids"}),
)
),
auth=temp_user.auth,
Expand Down Expand Up @@ -684,9 +729,9 @@ async def test_tags_remove(
response = await client.put(
f"/datasets/{dataset_id}/",
json=to_payload(
UpdateDatasetFactory.build(
UpdateDatasetPayloadFactory.build_from_create_command(
command.copy(exclude={"tag_ids"}),
tag_ids=[],
**command.dict(exclude={"tag_ids"}),
)
),
auth=temp_user.auth,
Expand Down Expand Up @@ -752,9 +797,7 @@ async def test_create_dataset_with_extra_field_values(
}
]

async def test_add_extra_field_value(
self, client: httpx.AsyncClient, temp_user: TestPasswordUser
) -> None:
async def test_add_extra_field_value(self, client: httpx.AsyncClient) -> None:
bus = resolve(MessageBus)
siret, user, extra_field_id = await self._setup()

Expand All @@ -766,19 +809,18 @@ async def test_add_extra_field_value(
assert not dataset.extra_field_values

payload = to_payload(
UpdateDatasetFactory.build(
id=dataset_id,
UpdateDatasetPayloadFactory.build_from_create_command(
command.copy(exclude={"extra_field_values"}),
extra_field_values=[
ExtraFieldValue(
extra_field_id=extra_field_id,
value="Environ 10 To",
)
],
**command.dict(exclude={"extra_field_values"}),
)
)
response = await client.put(
f"/datasets/{dataset_id}/", json=payload, auth=temp_user.auth
f"/datasets/{dataset_id}/", json=payload, auth=user.auth
)
assert response.status_code == 200
data = response.json()
Expand All @@ -789,9 +831,7 @@ async def test_add_extra_field_value(
}
]

async def test_remove_extra_field_value(
self, client: httpx.AsyncClient, temp_user: TestPasswordUser
) -> None:
async def test_remove_extra_field_value(self, client: httpx.AsyncClient) -> None:
bus = resolve(MessageBus)
siret, user, extra_field_id = await self._setup()

Expand All @@ -810,14 +850,13 @@ async def test_remove_extra_field_value(
assert len(dataset.extra_field_values) == 1

payload = to_payload(
UpdateDatasetFactory.build(
id=dataset_id,
UpdateDatasetPayloadFactory.build_from_create_command(
command.copy(exclude={"extra_field_values"}),
extra_field_values=[],
**command.dict(exclude={"extra_field_values"}),
)
)
response = await client.put(
f"/datasets/{dataset_id}/", json=payload, auth=temp_user.auth
f"/datasets/{dataset_id}/", json=payload, auth=user.auth
)
assert response.status_code == 200
data = response.json()
Expand Down
7 changes: 5 additions & 2 deletions tests/api/test_datasets_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,10 @@ async def test_search_results_change_when_data_changes(

# Update dataset title
update_command = UpdateDatasetFactory.build(
id=pk, title="Modifié", **command.dict(exclude={"title"})
account=temp_user.account,
id=pk,
title="Modifié",
**command.dict(exclude={"title", "account", "organization_siret"}),
)

await bus.execute(update_command)
Expand All @@ -200,7 +203,7 @@ async def test_search_results_change_when_data_changes(
# Same on description
update_command = UpdateDatasetFactory.build(
description="Jeu de données spécial",
**update_command.dict(exclude={"description"})
**update_command.dict(exclude={"description"}),
)
await bus.execute(update_command)
response = await client.get(
Expand Down
24 changes: 20 additions & 4 deletions tests/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pydantic import BaseModel
from pydantic_factories import ModelFactory, Require, Use

from server.api.datasets.schemas import DatasetCreate
from server.api.datasets.schemas import DatasetCreate, DatasetUpdate
from server.application.auth.commands import CreateDataPassUser, CreatePasswordUser
from server.application.datasets.commands import CreateDataset, UpdateDataset
from server.application.organizations.commands import CreateOrganization
Expand Down Expand Up @@ -93,13 +93,29 @@ class CreateDatasetPayloadFactory(_BaseCreateDatasetFactory, Factory[DatasetCrea
__model__ = DatasetCreate


class UpdateDatasetFactory(Factory[UpdateDataset]):
__model__ = UpdateDataset

class _BaseUpdateDatasetFactory:
tag_ids = Use(lambda: [])
extra_field_values = Use(lambda: [])


class UpdateDatasetFactory(_BaseCreateDatasetFactory, Factory[UpdateDataset]):
__model__ = UpdateDataset

account = Require()


class UpdateDatasetPayloadFactory(_BaseUpdateDatasetFactory, Factory[DatasetUpdate]):
__model__ = DatasetUpdate

@classmethod
def build_from_create_command(
cls, command: CreateDataset, **kwargs: Any
) -> DatasetUpdate:
return cls.build(
**command.dict(exclude={"account", "organization_siret"}), **kwargs
)


class CreateOrganizationFactory(Factory[CreateOrganization]):
__model__ = CreateOrganization

Expand Down
3 changes: 2 additions & 1 deletion tests/tools/test_initdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from server.application.datasets.commands import UpdateDataset
from server.application.datasets.queries import GetAllDatasets, GetDatasetByID
from server.config.di import resolve
from server.domain.common.types import ID
from server.domain.common.types import ID, Skip
from server.seedwork.application.messages import MessageBus
from tools import initdata

Expand Down Expand Up @@ -146,6 +146,7 @@ async def test_repo_initdata(

# Make a change.
command = UpdateDataset(
account=Skip(),
**dataset.dict(exclude={"title"}),
tag_ids=[tag.id for tag in dataset.tags],
title="Changed",
Expand Down
2 changes: 1 addition & 1 deletion tools/initdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def _get_dataset_attr(dataset: Dataset, attr: str) -> Any:
existing_dataset = await repository.get_by_id(id_)

if existing_dataset is not None:
update_command = UpdateDataset(id=id_, **item["params"])
update_command = UpdateDataset(account=Skip(), id=id_, **item["params"])

changed = any(
getattr(update_command, k) != _get_dataset_attr(existing_dataset, k)
Expand Down