diff --git a/server/api/auth/backends/token.py b/server/api/auth/backends/token.py index d48a70c42..51652b2a4 100644 --- a/server/api/auth/backends/token.py +++ b/server/api/auth/backends/token.py @@ -8,9 +8,9 @@ ) from starlette.requests import HTTPConnection -from server.application.auth.queries import GetUserByAPIToken +from server.application.auth.queries import GetAccountByAPIToken from server.config.di import resolve -from server.domain.auth.exceptions import UserDoesNotExist +from server.domain.auth.exceptions import AccountDoesNotExist from server.seedwork.application.messages import MessageBus from ..models import ApiUser @@ -44,11 +44,11 @@ async def authenticate( bus = resolve(MessageBus) - query = GetUserByAPIToken(api_token=api_token) + query = GetAccountByAPIToken(api_token=api_token) try: - user = await bus.execute(query) - except UserDoesNotExist: + account = await bus.execute(query) + except AccountDoesNotExist: raise AuthenticationError() - return AuthCredentials(scopes=["authenticated"]), ApiUser(user) + return AuthCredentials(scopes=["authenticated"]), ApiUser(account) diff --git a/server/api/auth/models.py b/server/api/auth/models.py index fd95d5662..85381f405 100644 --- a/server/api/auth/models.py +++ b/server/api/auth/models.py @@ -2,29 +2,29 @@ from starlette.authentication import BaseUser -from server.application.auth.views import UserView +from server.application.auth.views import AccountView class ApiUser(BaseUser): - def __init__(self, user: Optional[UserView]) -> None: - self._user = user + def __init__(self, account: Optional[AccountView]) -> None: + self._account = account @property - def obj(self) -> UserView: - if self._user is None: + def account(self) -> AccountView: + if self._account is None: raise RuntimeError( - "Cannot access .obj, as the user is anonymous. " + "Cannot access .account, as the user is anonymous. " "Hint: did you forget to check for .is_authenticated?" ) - return self._user + return self._account # Implement the 'BaseUser' interface. @property def is_authenticated(self) -> bool: - return self._user is not None + return self._account is not None @property def display_name(self) -> str: - return self._user.email if self._user is not None else "" + return self._account.email if self._account is not None else "" diff --git a/server/api/auth/permissions.py b/server/api/auth/permissions.py index fb23527fb..e97497cc2 100644 --- a/server/api/auth/permissions.py +++ b/server/api/auth/permissions.py @@ -130,7 +130,7 @@ def has_permission(self, request: APIRequest) -> bool: "Hint: use IsAuthenticated() & HasRole(...)" ) - return request.user.obj.role in self._roles + return request.user.account.role in self._roles def _patch_openapi_security_params(*permissions: BasePermission) -> Callable: diff --git a/server/api/auth/routes.py b/server/api/auth/routes.py index 5da0ef105..1dea8f45b 100644 --- a/server/api/auth/routes.py +++ b/server/api/auth/routes.py @@ -1,8 +1,8 @@ from fastapi import APIRouter, Depends, HTTPException -from server.application.auth.commands import CreateUser, DeleteUser -from server.application.auth.queries import GetUserByEmail, Login -from server.application.auth.views import AuthenticatedUserView, UserView +from server.application.auth.commands import CreatePasswordUser, DeletePasswordUser +from server.application.auth.queries import GetAccountByEmail, LoginPasswordUser +from server.application.auth.views import AccountView, AuthenticatedAccountView from server.config.di import resolve from server.domain.auth.entities import UserRole from server.domain.auth.exceptions import EmailAlreadyExists, LoginFailed @@ -10,7 +10,7 @@ from server.seedwork.application.messages import MessageBus from .permissions import HasRole, IsAuthenticated -from .schemas import CheckAuthResponse, UserCreate, UserLogin +from .schemas import CheckAuthResponse, PasswordUserCreate, PasswordUserLogin router = APIRouter(prefix="/auth", tags=["auth"]) @@ -18,28 +18,28 @@ @router.post( "/users/", dependencies=[Depends(IsAuthenticated() & HasRole(UserRole.ADMIN))], - response_model=UserView, + response_model=AccountView, status_code=201, ) -async def create_user(data: UserCreate) -> UserView: +async def create_password_user(data: PasswordUserCreate) -> AccountView: bus = resolve(MessageBus) - command = CreateUser(email=data.email, password=data.password) + command = CreatePasswordUser(email=data.email, password=data.password) try: await bus.execute(command) except EmailAlreadyExists as exc: raise HTTPException(400, detail=str(exc)) - query = GetUserByEmail(email=data.email) + query = GetAccountByEmail(email=data.email) return await bus.execute(query) -@router.post("/login/", response_model=AuthenticatedUserView) -async def login(data: UserLogin) -> AuthenticatedUserView: +@router.post("/login/", response_model=AuthenticatedAccountView) +async def login_password_user(data: PasswordUserLogin) -> AuthenticatedAccountView: bus = resolve(MessageBus) - query = Login(email=data.email, password=data.password) + query = LoginPasswordUser(email=data.email, password=data.password) try: return await bus.execute(query) @@ -52,10 +52,10 @@ async def login(data: UserLogin) -> AuthenticatedUserView: dependencies=[Depends(IsAuthenticated() & HasRole(UserRole.ADMIN))], status_code=204, ) -async def delete_user(id: ID) -> None: +async def delete_password_user(id: ID) -> None: bus = resolve(MessageBus) - command = DeleteUser(id=id) + command = DeletePasswordUser(account_id=id) await bus.execute(command) diff --git a/server/api/auth/schemas.py b/server/api/auth/schemas.py index 545ba8515..e12d2c195 100644 --- a/server/api/auth/schemas.py +++ b/server/api/auth/schemas.py @@ -3,12 +3,12 @@ from pydantic import BaseModel, EmailStr, SecretStr -class UserCreate(BaseModel): +class PasswordUserCreate(BaseModel): email: EmailStr password: SecretStr -class UserLogin(BaseModel): +class PasswordUserLogin(BaseModel): email: EmailStr password: SecretStr diff --git a/server/application/auth/commands.py b/server/application/auth/commands.py index 9e730817a..8099f75af 100644 --- a/server/application/auth/commands.py +++ b/server/application/auth/commands.py @@ -6,14 +6,14 @@ from server.seedwork.application.commands import Command -class CreateUser(Command[ID]): +class CreatePasswordUser(Command[ID]): organization_siret: Siret = LEGACY_ORGANIZATION_SIRET email: EmailStr password: SecretStr -class DeleteUser(Command[None]): - id: ID +class DeletePasswordUser(Command[None]): + account_id: ID class ChangePassword(Command[None]): diff --git a/server/application/auth/handlers.py b/server/application/auth/handlers.py index 2a8cafbf1..24ffff5a5 100644 --- a/server/application/auth/handlers.py +++ b/server/application/auth/handlers.py @@ -1,23 +1,23 @@ -from server.application.auth.views import AuthenticatedUserView, UserView +from server.application.auth.views import AccountView, AuthenticatedAccountView from server.config.di import resolve -from server.domain.auth.entities import User, UserRole +from server.domain.auth.entities import Account, PasswordUser, UserRole from server.domain.auth.exceptions import ( + AccountDoesNotExist, EmailAlreadyExists, LoginFailed, - UserDoesNotExist, ) -from server.domain.auth.repositories import UserRepository +from server.domain.auth.repositories import AccountRepository, PasswordUserRepository from server.domain.common.types import ID -from .commands import ChangePassword, CreateUser, DeleteUser +from .commands import ChangePassword, CreatePasswordUser, DeletePasswordUser from .passwords import PasswordEncoder, generate_api_token -from .queries import GetUserByAPIToken, GetUserByEmail, Login +from .queries import GetAccountByAPIToken, GetAccountByEmail, LoginPasswordUser -async def create_user( - command: CreateUser, *, id_: ID = None, role: UserRole = UserRole.USER +async def create_password_user( + command: CreatePasswordUser, *, id_: ID = None, role: UserRole = UserRole.USER ) -> ID: - repository = resolve(UserRepository) + repository = resolve(PasswordUserRepository) password_encoder = resolve(PasswordEncoder) if id_ is None: @@ -25,82 +25,89 @@ async def create_user( email = command.email - user = await repository.get_by_email(email) + password_user = await repository.get_by_email(email) - if user is not None: + if password_user is not None: raise EmailAlreadyExists(email) password_hash = password_encoder.hash(command.password) api_token = generate_api_token() - user = User( + account = Account( id=id_, organization_siret=command.organization_siret, email=email, - password_hash=password_hash, role=role, api_token=api_token, ) - return await repository.insert(user) + password_user = PasswordUser( + account_id=id_, + account=account, + password_hash=password_hash, + ) + + return await repository.insert(password_user) -async def delete_user(command: DeleteUser) -> None: - repository = resolve(UserRepository) - await repository.delete(command.id) +async def delete_password_user(command: DeletePasswordUser) -> None: + repository = resolve(PasswordUserRepository) + await repository.delete(command.account_id) -async def login(query: Login) -> AuthenticatedUserView: - repository = resolve(UserRepository) +async def login_password_user(query: LoginPasswordUser) -> AuthenticatedAccountView: + repository = resolve(PasswordUserRepository) password_encoder = resolve(PasswordEncoder) - user = await repository.get_by_email(query.email) + password_user = await repository.get_by_email(query.email) - if user is None: + if password_user is None: password_encoder.hash(query.password) # Mitigate timing attacks. raise LoginFailed("Invalid credentials") - if not password_encoder.verify(password=query.password, hash=user.password_hash): + if not password_encoder.verify( + password=query.password, hash=password_user.password_hash + ): raise LoginFailed("Invalid credentials") - return AuthenticatedUserView(**user.dict()) + return AuthenticatedAccountView(**password_user.account.dict()) -async def get_user_by_email(query: GetUserByEmail) -> UserView: - repository = resolve(UserRepository) +async def get_account_by_email(query: GetAccountByEmail) -> AccountView: + repository = resolve(AccountRepository) email = query.email - user = await repository.get_by_email(email) + account = await repository.get_by_email(email) - if user is None: - raise UserDoesNotExist(email) + if account is None: + raise AccountDoesNotExist(email) - return UserView(**user.dict()) + return AccountView(**account.dict()) -async def get_user_by_api_token(query: GetUserByAPIToken) -> UserView: - repository = resolve(UserRepository) +async def get_account_by_api_token(query: GetAccountByAPIToken) -> AccountView: + repository = resolve(AccountRepository) - user = await repository.get_by_api_token(query.api_token) + account = await repository.get_by_api_token(query.api_token) - if user is None: - raise UserDoesNotExist("__token__") + if account is None: + raise AccountDoesNotExist("__token__") - return UserView(**user.dict()) + return AccountView(**account.dict()) async def change_password(command: ChangePassword) -> None: - repository = resolve(UserRepository) + repository = resolve(PasswordUserRepository) password_encoder = resolve(PasswordEncoder) email = command.email - user = await repository.get_by_email(email) + password_user = await repository.get_by_email(email) - if user is None: - raise UserDoesNotExist(email) + if password_user is None: + raise AccountDoesNotExist(email) - user.update_password(password_encoder.hash(command.password)) - user.update_api_token(generate_api_token()) # Require new login + password_user.update_password(password_encoder.hash(command.password)) + password_user.account.update_api_token(generate_api_token()) # Require new login - await repository.update(user) + await repository.update(password_user) diff --git a/server/application/auth/queries.py b/server/application/auth/queries.py index baf4fe324..9d67106f8 100644 --- a/server/application/auth/queries.py +++ b/server/application/auth/queries.py @@ -1,17 +1,17 @@ from pydantic import EmailStr, SecretStr -from server.application.auth.views import AuthenticatedUserView, UserView +from server.application.auth.views import AccountView, AuthenticatedAccountView from server.seedwork.application.queries import Query -class Login(Query[AuthenticatedUserView]): +class LoginPasswordUser(Query[AuthenticatedAccountView]): email: EmailStr password: SecretStr -class GetUserByEmail(Query[UserView]): +class GetAccountByEmail(Query[AccountView]): email: EmailStr -class GetUserByAPIToken(Query[UserView]): +class GetAccountByAPIToken(Query[AccountView]): api_token: str diff --git a/server/application/auth/views.py b/server/application/auth/views.py index 45971db84..129eb2019 100644 --- a/server/application/auth/views.py +++ b/server/application/auth/views.py @@ -5,14 +5,14 @@ from server.domain.organizations.types import Siret -class UserView(BaseModel): +class AccountView(BaseModel): id: ID organization_siret: Siret email: str role: UserRole -class AuthenticatedUserView(BaseModel): +class AuthenticatedAccountView(BaseModel): id: ID organization_siret: Siret email: str diff --git a/server/config/di.py b/server/config/di.py index 8916a9073..37fbbee45 100644 --- a/server/config/di.py +++ b/server/config/di.py @@ -60,14 +60,17 @@ async def create_todo(...): """ from server.application.auth.passwords import PasswordEncoder -from server.domain.auth.repositories import UserRepository +from server.domain.auth.repositories import AccountRepository, PasswordUserRepository from server.domain.catalog_records.repositories import CatalogRecordRepository from server.domain.datasets.repositories import DatasetRepository from server.domain.organizations.repositories import OrganizationRepository from server.domain.tags.repositories import TagRepository from server.infrastructure.adapters.messages import MessageBusAdapter from server.infrastructure.auth.passwords import Argon2PasswordEncoder -from server.infrastructure.auth.repositories import SqlUserRepository +from server.infrastructure.auth.repositories import ( + SqlAccountRepository, + SqlPasswordUserRepository, +) from server.infrastructure.catalog_records.repositories import ( SqlCatalogRecordRepository, ) @@ -133,7 +136,8 @@ def configure(container: "Container") -> None: # Repositories - container.register_instance(UserRepository, SqlUserRepository(db)) + container.register_instance(AccountRepository, SqlAccountRepository(db)) + container.register_instance(PasswordUserRepository, SqlPasswordUserRepository(db)) container.register_instance(CatalogRecordRepository, SqlCatalogRecordRepository(db)) container.register_instance(DatasetRepository, SqlDatasetRepository(db)) container.register_instance(TagRepository, SqlTagRepository(db)) diff --git a/server/domain/auth/entities.py b/server/domain/auth/entities.py index 1ea333d26..daf76a3f1 100644 --- a/server/domain/auth/entities.py +++ b/server/domain/auth/entities.py @@ -11,19 +11,24 @@ class UserRole(enum.Enum): ADMIN = "ADMIN" -class User(Entity): +class Account(Entity): id: ID organization_siret: Siret email: str - password_hash: str role: UserRole api_token: str - def update_password(self, password_hash: str) -> None: - self.password_hash = password_hash - def update_api_token(self, api_token: str) -> None: self.api_token = api_token class Config: orm_mode = True + + +class PasswordUser(Entity): + account_id: ID + account: Account + password_hash: str + + def update_password(self, password_hash: str) -> None: + self.password_hash = password_hash diff --git a/server/domain/auth/exceptions.py b/server/domain/auth/exceptions.py index f73ca226a..459a1ed6b 100644 --- a/server/domain/auth/exceptions.py +++ b/server/domain/auth/exceptions.py @@ -1,8 +1,8 @@ from ..common.exceptions import DoesNotExist -class UserDoesNotExist(DoesNotExist): - entity_name = "User" +class AccountDoesNotExist(DoesNotExist): + entity_name = "Account" class EmailAlreadyExists(Exception): diff --git a/server/domain/auth/repositories.py b/server/domain/auth/repositories.py index 210237be0..c58ba3424 100644 --- a/server/domain/auth/repositories.py +++ b/server/domain/auth/repositories.py @@ -3,23 +3,28 @@ from server.seedwork.domain.repositories import Repository from ..common.types import ID, id_factory -from .entities import User +from .entities import Account, PasswordUser -class UserRepository(Repository): - def make_id(self) -> ID: - return id_factory() +class AccountRepository(Repository): + async def get_by_email(self, email: str) -> Optional[Account]: + raise NotImplementedError # pragma: no cover - async def get_by_email(self, email: str) -> Optional[User]: + async def get_by_api_token(self, api_token: str) -> Optional[Account]: raise NotImplementedError # pragma: no cover - async def get_by_api_token(self, api_token: str) -> Optional[User]: + +class PasswordUserRepository(Repository): + def make_id(self) -> ID: + return id_factory() + + async def get_by_email(self, email: str) -> Optional[PasswordUser]: raise NotImplementedError # pragma: no cover - async def insert(self, entity: User) -> ID: + async def insert(self, entity: PasswordUser) -> ID: raise NotImplementedError # pragma: no cover - async def update(self, entity: User) -> None: + async def update(self, entity: PasswordUser) -> None: raise NotImplementedError # pragma: no cover async def delete(self, id: ID) -> None: diff --git a/server/infrastructure/auth/models.py b/server/infrastructure/auth/models.py index 2fba479f6..f9f5b40aa 100644 --- a/server/infrastructure/auth/models.py +++ b/server/infrastructure/auth/models.py @@ -1,5 +1,4 @@ -import uuid -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from sqlalchemy import CHAR, Column, Enum, ForeignKey, String from sqlalchemy.dialects.postgresql import UUID @@ -7,6 +6,7 @@ from server.application.auth.passwords import API_TOKEN_LENGTH from server.domain.auth.entities import UserRole +from server.domain.common.types import ID from server.domain.organizations.types import Siret from ..database import Base @@ -15,10 +15,14 @@ from ..organizations.models import OrganizationModel -class UserModel(Base): - __tablename__ = "user" +class AccountModel(Base): + """ + Store information common to all user accounts. + """ - id: uuid.UUID = Column(UUID(as_uuid=True), primary_key=True) + __tablename__ = "account" + + id: ID = Column(UUID(as_uuid=True), primary_key=True) organization_siret: Siret = Column( CHAR(14), ForeignKey("organization.siret"), @@ -26,9 +30,28 @@ class UserModel(Base): ) organization: "OrganizationModel" = relationship( "OrganizationModel", - back_populates="users", + back_populates="accounts", + ) + email: str = Column(String, nullable=False, unique=True, index=True) + role: UserRole = Column(Enum(UserRole, name="user_role_enum"), nullable=False) + api_token: str = Column(String(API_TOKEN_LENGTH), nullable=False) + + password_user: Optional["PasswordUserModel"] = relationship( + "PasswordUserModel", back_populates="account", uselist=False + ) + + +class PasswordUserModel(Base): + """ + Store information specific to users that authenticate with email/password. + """ + + __tablename__ = "password_user" + + account_id: ID = Column( + ForeignKey("account.id", ondelete="CASCADE"), primary_key=True + ) + account: "AccountModel" = relationship( + "AccountModel", back_populates="password_user", cascade="delete" ) - email = Column(String, nullable=False, unique=True, index=True) - password_hash = Column(String, nullable=False) - role = Column(Enum(UserRole, name="user_role_enum"), nullable=False) - api_token = Column(String(API_TOKEN_LENGTH), nullable=False) + password_hash: str = Column(String, nullable=False) diff --git a/server/infrastructure/auth/module.py b/server/infrastructure/auth/module.py index d6dee26c5..3675c6ecc 100644 --- a/server/infrastructure/auth/module.py +++ b/server/infrastructure/auth/module.py @@ -1,25 +1,33 @@ -from server.application.auth.commands import ChangePassword, CreateUser, DeleteUser +from server.application.auth.commands import ( + ChangePassword, + CreatePasswordUser, + DeletePasswordUser, +) from server.application.auth.handlers import ( change_password, - create_user, - delete_user, - get_user_by_api_token, - get_user_by_email, - login, + create_password_user, + delete_password_user, + get_account_by_api_token, + get_account_by_email, + login_password_user, +) +from server.application.auth.queries import ( + GetAccountByAPIToken, + GetAccountByEmail, + LoginPasswordUser, ) -from server.application.auth.queries import GetUserByAPIToken, GetUserByEmail, Login from server.seedwork.application.modules import Module class AuthModule(Module): command_handlers = { - CreateUser: create_user, - DeleteUser: delete_user, + CreatePasswordUser: create_password_user, + DeletePasswordUser: delete_password_user, ChangePassword: change_password, } query_handlers = { - Login: login, - GetUserByEmail: get_user_by_email, - GetUserByAPIToken: get_user_by_api_token, + LoginPasswordUser: login_password_user, + GetAccountByEmail: get_account_by_email, + GetAccountByAPIToken: get_account_by_api_token, } diff --git a/server/infrastructure/auth/repositories.py b/server/infrastructure/auth/repositories.py index 4683811e4..965b79795 100644 --- a/server/infrastructure/auth/repositories.py +++ b/server/infrastructure/auth/repositories.py @@ -1,85 +1,116 @@ from typing import Any, Optional -from sqlalchemy import delete, select +from sqlalchemy import select from sqlalchemy.exc import NoResultFound from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import contains_eager -from server.domain.auth.entities import User -from server.domain.auth.exceptions import UserDoesNotExist -from server.domain.auth.repositories import UserRepository +from server.domain.auth.entities import Account, PasswordUser +from server.domain.auth.exceptions import AccountDoesNotExist +from server.domain.auth.repositories import AccountRepository, PasswordUserRepository from server.domain.common.types import ID from ..database import Database -from .models import UserModel -from .transformers import update_instance +from .models import AccountModel, PasswordUserModel +from .transformers import ( + make_account_entity, + make_account_instance, + make_password_user_entity, + make_password_user_instance, + update_instance, +) -class SqlUserRepository(UserRepository): +class SqlAccountRepository(AccountRepository): def __init__(self, db: Database) -> None: self._db = db async def _maybe_get_by( - self, session: AsyncSession, **kwargs: Any - ) -> Optional[UserModel]: - whereclauses = ( - getattr(UserModel, column) == value for column, value in kwargs.items() - ) - stmt = select(UserModel).where(*whereclauses) + self, session: AsyncSession, *whereclauses: Any + ) -> Optional[AccountModel]: + stmt = select(AccountModel).where(*whereclauses) result = await session.execute(stmt) - try: - return result.scalar_one() - except NoResultFound: - return None + return result.scalar_one_or_none() - async def get_by_email(self, email: str) -> Optional[User]: + async def get_by_email(self, email: str) -> Optional[Account]: async with self._db.session() as session: - instance = await self._maybe_get_by(session, email=email) + instance = await self._maybe_get_by(session, AccountModel.email == email) + if instance is None: + return None + return make_account_entity(instance) + async def get_by_api_token(self, api_token: str) -> Optional[Account]: + async with self._db.session() as session: + instance = await self._maybe_get_by( + session, AccountModel.api_token == api_token + ) if instance is None: return None + return make_account_entity(instance) + + +class SqlPasswordUserRepository(PasswordUserRepository): + def __init__(self, db: Database) -> None: + self._db = db - return User.from_orm(instance) + async def _maybe_get_by( + self, session: AsyncSession, *whereclauses: Any + ) -> Optional[PasswordUserModel]: + stmt = ( + select(PasswordUserModel) + .join(AccountModel) + .options(contains_eager(PasswordUserModel.account)) + .where(*whereclauses) + ) + result = await session.execute(stmt) + try: + return result.scalar_one() + except NoResultFound: + return None - async def get_by_api_token(self, api_token: str) -> Optional[User]: + async def get_by_email(self, email: str) -> Optional[PasswordUser]: async with self._db.session() as session: - instance = await self._maybe_get_by(session, api_token=api_token) + instance = await self._maybe_get_by(session, AccountModel.email == email) if instance is None: return None - return User.from_orm(instance) + return make_password_user_entity(instance) - async def insert(self, entity: User) -> ID: + async def insert(self, entity: PasswordUser) -> ID: async with self._db.session() as session: - instance = UserModel( - id=entity.id, - organization_siret=entity.organization_siret, - email=entity.email, - password_hash=entity.password_hash, - role=entity.role, - api_token=entity.api_token, - ) + account_instance = make_account_instance(entity.account) + session.add(account_instance) + instance = make_password_user_instance(entity) session.add(instance) await session.commit() await session.refresh(instance) - return ID(instance.id) + return ID(instance.account_id) - async def update(self, entity: User) -> None: + async def update(self, entity: PasswordUser) -> None: async with self._db.session() as session: - instance = await self._maybe_get_by(session, id=entity.id) + instance = await self._maybe_get_by( + session, PasswordUserModel.account_id == entity.account_id + ) if instance is None: - raise UserDoesNotExist(entity.email) + raise AccountDoesNotExist(entity.account_id) update_instance(instance, entity) await session.commit() - async def delete(self, id: ID) -> None: + async def delete(self, account_id: ID) -> None: async with self._db.session() as session: - stmt = delete(UserModel).where(UserModel.id == id) - await session.execute(stmt) + instance = await self._maybe_get_by( + session, PasswordUserModel.account_id == account_id + ) + + if instance is None: + return + + await session.delete(instance) await session.commit() diff --git a/server/infrastructure/auth/transformers.py b/server/infrastructure/auth/transformers.py index ab57f2d04..6e431bbda 100644 --- a/server/infrastructure/auth/transformers.py +++ b/server/infrastructure/auth/transformers.py @@ -1,8 +1,46 @@ -from server.domain.auth.entities import User +from server.domain.auth.entities import Account, PasswordUser -from .models import UserModel +from .models import AccountModel, PasswordUserModel -def update_instance(instance: UserModel, entity: User) -> None: - for field in set(User.__fields__) - {"id"}: +def make_account_instance(entity: Account) -> AccountModel: + return AccountModel( + id=entity.id, + organization_siret=entity.organization_siret, + email=entity.email, + role=entity.role, + api_token=entity.api_token, + ) + + +def make_account_entity(instance: AccountModel) -> Account: + return Account( + id=instance.id, + organization_siret=instance.organization_siret, + email=instance.email, + role=instance.role, + api_token=instance.api_token, + ) + + +def make_password_user_instance(entity: PasswordUser) -> PasswordUserModel: + return PasswordUserModel( + account_id=entity.account_id, + password_hash=entity.password_hash, + ) + + +def make_password_user_entity(instance: PasswordUserModel) -> PasswordUser: + return PasswordUser( + account_id=instance.account_id, + account=make_account_entity(instance.account), + password_hash=instance.password_hash, + ) + + +def update_instance(instance: PasswordUserModel, entity: PasswordUser) -> None: + for field in set(PasswordUser.__fields__) - {"account_id", "account"}: setattr(instance, field, getattr(entity, field)) + + for field in set(Account.__fields__) - {"id"}: + setattr(instance, field, getattr(entity.account, field)) diff --git a/server/infrastructure/organizations/models.py b/server/infrastructure/organizations/models.py index b505c9a0d..97fadae03 100644 --- a/server/infrastructure/organizations/models.py +++ b/server/infrastructure/organizations/models.py @@ -8,7 +8,7 @@ from ..database import Base if TYPE_CHECKING: - from ..auth.repositories import UserModel + from ..auth.models import AccountModel from ..catalogs.models import CatalogModel @@ -24,7 +24,7 @@ class OrganizationModel(Base): uselist=False, ) - users: List["UserModel"] = relationship( - "UserModel", + accounts: List["AccountModel"] = relationship( + "AccountModel", back_populates="organization", ) diff --git a/server/migrations/versions/17a9b8d2f84e_split_account_password_user.py b/server/migrations/versions/17a9b8d2f84e_split_account_password_user.py new file mode 100644 index 000000000..e7eb86531 --- /dev/null +++ b/server/migrations/versions/17a9b8d2f84e_split_account_password_user.py @@ -0,0 +1,82 @@ +"""split-account-password-user + +Revision ID: 17a9b8d2f84e +Revises: f2ef4eef61e3 +Create Date: 2022-08-02 11:07:59.669511 +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "17a9b8d2f84e" +down_revision = "f2ef4eef61e3" +branch_labels = None +depends_on = None + + +def upgrade(): + # Users become accounts. + op.rename_table("user", "account") + op.execute("ALTER INDEX user_pkey RENAME TO account_pkey;") + op.execute("ALTER INDEX ix_user_email RENAME TO ix_account_email;") + op.execute( + """ + ALTER TABLE account + RENAME CONSTRAINT user_organization_siret_fkey + TO account_organization_siret_fkey; + """ + ) + + op.create_table( + "password_user", + sa.Column("account_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("password_hash", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["account_id"], + ["account.id"], + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("account_id"), + ) + + # Create password users from existing accounts. + op.execute( + """ + INSERT INTO password_user (account_id, password_hash) + SELECT account.id, account.password_hash FROM account; + """ + ) + + op.drop_column("account", "password_hash") + + +def downgrade(): + # Move password_hash back to accounts. + op.add_column( + "account", + sa.Column("password_hash", sa.String()), + ) + op.execute( + """ + UPDATE account + SET password_hash = pu.password_hash + FROM password_user AS pu + JOIN account AS acc ON pu.account_id = acc.id; + """ + ) + op.alter_column("account", "password_hash", nullable=False) + + op.drop_table("password_user") + + # Accounts become users. + op.execute("ALTER INDEX account_pkey RENAME TO user_pkey;") + op.execute("ALTER INDEX ix_account_email RENAME TO ix_user_email;") + op.execute( + """ + ALTER TABLE account + RENAME CONSTRAINT account_organization_siret_fkey + TO user_organization_siret_fkey; + """ + ) + op.rename_table("account", "user") diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index e18622024..204f5423d 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -4,14 +4,14 @@ import pytest from pydantic import EmailStr -from server.application.auth.queries import GetUserByEmail +from server.application.auth.queries import GetAccountByEmail from server.config.di import resolve -from server.domain.auth.exceptions import UserDoesNotExist +from server.domain.auth.exceptions import AccountDoesNotExist from server.domain.common.types import id_factory from server.domain.organizations.entities import LEGACY_ORGANIZATION_SIRET from server.seedwork.application.messages import MessageBus -from ..helpers import TestUser +from ..helpers import TestPasswordUser @pytest.mark.asyncio @@ -55,7 +55,7 @@ ) async def test_create_user_invalid( client: httpx.AsyncClient, - admin_user: TestUser, + admin_user: TestPasswordUser, payload: dict, expected_errors_attrs: List[dict], ) -> None: @@ -72,7 +72,7 @@ async def test_create_user_invalid( @pytest.mark.asyncio async def test_create_user( - client: httpx.AsyncClient, temp_user: TestUser, admin_user: TestUser + client: httpx.AsyncClient, temp_user: TestPasswordUser, admin_user: TestPasswordUser ) -> None: payload = {"email": "john@doe.com", "password": "s3kr3t"} @@ -97,25 +97,25 @@ async def test_create_user( @pytest.mark.asyncio async def test_create_user_already_exists( - client: httpx.AsyncClient, temp_user: TestUser, admin_user: TestUser + client: httpx.AsyncClient, temp_user: TestPasswordUser, admin_user: TestPasswordUser ) -> None: - payload = {"email": temp_user.email, "password": "somethingelse"} + payload = {"email": temp_user.account.email, "password": "somethingelse"} response = await client.post("/auth/users/", json=payload, auth=admin_user.auth) assert response.status_code == 400 @pytest.mark.asyncio -async def test_login(client: httpx.AsyncClient, temp_user: TestUser) -> None: - payload = {"email": temp_user.email, "password": temp_user.password} +async def test_login(client: httpx.AsyncClient, temp_user: TestPasswordUser) -> None: + payload = {"email": temp_user.account.email, "password": temp_user.password} response = await client.post("/auth/login/", json=payload) assert response.status_code == 200 user = response.json() assert user == { - "id": str(temp_user.id), + "id": str(temp_user.account_id), "organization_siret": str(LEGACY_ORGANIZATION_SIRET), - "email": temp_user.email, - "role": temp_user.role.value, - "api_token": temp_user.api_token, + "email": temp_user.account.email, + "role": temp_user.account.role.value, + "api_token": temp_user.account.api_token, } @@ -128,10 +128,10 @@ async def test_login(client: httpx.AsyncClient, temp_user: TestUser) -> None: ], ) async def test_login_failed( - client: httpx.AsyncClient, email: str, password: str, temp_user: TestUser + client: httpx.AsyncClient, email: str, password: str, temp_user: TestPasswordUser ) -> None: payload = { - "email": email.format(email=temp_user.email), + "email": email.format(email=temp_user.account.email), "password": password.format(password=temp_user.password), } response = await client.post("/auth/login/", json=payload) @@ -141,7 +141,7 @@ async def test_login_failed( @pytest.mark.asyncio -async def test_check(client: httpx.AsyncClient, temp_user: TestUser) -> None: +async def test_check(client: httpx.AsyncClient, temp_user: TestPasswordUser) -> None: response = await client.get("/auth/check/", auth=temp_user.auth) assert response.status_code == 200 @@ -158,11 +158,11 @@ async def test_check(client: httpx.AsyncClient, temp_user: TestUser) -> None: ], ) async def test_check_failed( - client: httpx.AsyncClient, temp_user: TestUser, headers: dict + client: httpx.AsyncClient, temp_user: TestPasswordUser, headers: dict ) -> None: if "Authorization" in headers: headers["Authorization"] = headers["Authorization"].format( - api_token=temp_user.api_token + api_token=temp_user.account.api_token ) response = await client.get("/auth/check/", headers=headers) @@ -173,27 +173,31 @@ async def test_check_failed( @pytest.mark.asyncio async def test_delete_user( - client: httpx.AsyncClient, temp_user: TestUser, admin_user: TestUser + client: httpx.AsyncClient, temp_user: TestPasswordUser, admin_user: TestPasswordUser ) -> None: bus = resolve(MessageBus) # Permissions - response = await client.delete(f"/auth/users/{temp_user.id}/") + response = await client.delete(f"/auth/users/{temp_user.account_id}/") assert response.status_code == 401 - response = await client.delete(f"/auth/users/{temp_user.id}/", auth=temp_user.auth) + response = await client.delete( + f"/auth/users/{temp_user.account_id}/", auth=temp_user.auth + ) assert response.status_code == 403 - response = await client.delete(f"/auth/users/{temp_user.id}/", auth=admin_user.auth) + response = await client.delete( + f"/auth/users/{temp_user.account_id}/", auth=admin_user.auth + ) assert response.status_code == 204 - query = GetUserByEmail(email=EmailStr(temp_user.email)) - with pytest.raises(UserDoesNotExist): + query = GetAccountByEmail(email=EmailStr(temp_user.account.email)) + with pytest.raises(AccountDoesNotExist): await bus.execute(query) @pytest.mark.asyncio async def test_delete_user_idempotent( - client: httpx.AsyncClient, admin_user: TestUser + client: httpx.AsyncClient, admin_user: TestPasswordUser ) -> None: # Represents a non-existing user, or a user previously deleted. # These should be handled the same way as existing users by diff --git a/tests/api/test_datasets.py b/tests/api/test_datasets.py index c7fae8281..00833419c 100644 --- a/tests/api/test_datasets.py +++ b/tests/api/test_datasets.py @@ -16,7 +16,7 @@ from tests.factories import CreateDatasetFactory from ..factories import UpdateDatasetFactory, fake -from ..helpers import TestUser, to_payload +from ..helpers import TestPasswordUser, to_payload @pytest.mark.asyncio @@ -76,7 +76,7 @@ ) async def test_create_dataset_invalid( client: httpx.AsyncClient, - temp_user: TestUser, + temp_user: TestPasswordUser, payload: dict, expected_errors_attrs: list, ) -> None: @@ -93,7 +93,7 @@ async def test_create_dataset_invalid( @pytest.mark.asyncio async def test_dataset_crud( - client: httpx.AsyncClient, temp_user: TestUser, admin_user: TestUser + client: httpx.AsyncClient, temp_user: TestPasswordUser, admin_user: TestPasswordUser ) -> None: last_updated_at = fake.date_time_tz() @@ -197,7 +197,7 @@ async def test_delete_not_authenticated(self, client: httpx.AsyncClient) -> None assert response.status_code == 401 async def test_delete_not_admin( - self, client: httpx.AsyncClient, temp_user: TestUser + self, client: httpx.AsyncClient, temp_user: TestPasswordUser ) -> None: pk = id_factory() response = await client.delete(f"/datasets/{pk}/", auth=temp_user.auth) @@ -257,7 +257,7 @@ async def add_dataset_pagination_corpus(n: int, tags: list) -> None: ) async def test_dataset_pagination( client: httpx.AsyncClient, - temp_user: TestUser, + temp_user: TestPasswordUser, tags: list, params: dict, expected_total_pages: int, @@ -280,7 +280,7 @@ async def test_dataset_pagination( @pytest.mark.asyncio async def test_dataset_get_all_uses_reverse_chronological_order( - client: httpx.AsyncClient, temp_user: TestUser + client: httpx.AsyncClient, temp_user: TestPasswordUser ) -> None: bus = resolve(MessageBus) await bus.execute(CreateDatasetFactory.build(title="Oldest")) @@ -306,7 +306,11 @@ class TestDatasetOptionalFields: ], ) async def test_optional_fields_missing_uses_defaults( - self, client: httpx.AsyncClient, temp_user: TestUser, field: str, default: Any + self, + client: httpx.AsyncClient, + temp_user: TestPasswordUser, + field: str, + default: Any, ) -> None: payload = to_payload(CreateDatasetFactory.build()) payload.pop(field) @@ -316,7 +320,7 @@ async def test_optional_fields_missing_uses_defaults( assert dataset[field] == default async def test_optional_fields_invalid( - self, client: httpx.AsyncClient, temp_user: TestUser + self, client: httpx.AsyncClient, temp_user: TestPasswordUser ) -> None: response = await client.post( "/datasets/", @@ -345,7 +349,7 @@ async def test_optional_fields_invalid( @pytest.mark.asyncio class TestDatasetUpdate: async def test_not_found( - self, client: httpx.AsyncClient, temp_user: TestUser + self, client: httpx.AsyncClient, temp_user: TestPasswordUser ) -> None: pk = id_factory() response = await client.put( @@ -356,7 +360,7 @@ async def test_not_found( assert response.status_code == 404 async def test_full_entity_expected( - self, client: httpx.AsyncClient, temp_user: TestUser + self, client: httpx.AsyncClient, temp_user: TestPasswordUser ) -> None: bus = resolve(MessageBus) dataset_id = await bus.execute(CreateDatasetFactory.build()) @@ -388,7 +392,7 @@ async def test_full_entity_expected( assert error["type"] == "value_error.missing", field async def test_fields_empty_invalid( - self, client: httpx.AsyncClient, temp_user: TestUser + self, client: httpx.AsyncClient, temp_user: TestPasswordUser ) -> None: bus = resolve(MessageBus) @@ -432,7 +436,9 @@ async def test_fields_empty_invalid( assert err_url["loc"] == ["body", "url"] assert "empty" in err_service["msg"] - async def test_update(self, client: httpx.AsyncClient, temp_user: TestUser) -> None: + async def test_update( + self, client: httpx.AsyncClient, temp_user: TestPasswordUser + ) -> None: bus = resolve(MessageBus) dataset_id = await bus.execute(CreateDatasetFactory.build()) @@ -505,7 +511,7 @@ async def test_update(self, client: httpx.AsyncClient, temp_user: TestUser) -> N @pytest.mark.asyncio class TestFormats: async def test_formats_add( - self, client: httpx.AsyncClient, temp_user: TestUser + self, client: httpx.AsyncClient, temp_user: TestPasswordUser ) -> None: bus = resolve(MessageBus) command = CreateDatasetFactory.build( @@ -528,7 +534,7 @@ async def test_formats_add( assert sorted(response.json()["formats"]) == ["api", "file_gis", "website"] async def test_formats_remove( - self, client: httpx.AsyncClient, temp_user: TestUser + self, client: httpx.AsyncClient, temp_user: TestPasswordUser ) -> None: bus = resolve(MessageBus) command = CreateDatasetFactory.build( @@ -554,7 +560,7 @@ async def test_formats_remove( @pytest.mark.asyncio class TestTags: async def test_tags_add( - self, client: httpx.AsyncClient, temp_user: TestUser + self, client: httpx.AsyncClient, temp_user: TestPasswordUser ) -> None: bus = resolve(MessageBus) @@ -582,7 +588,7 @@ async def test_tags_add( assert dataset.tags == [tag_architecture] async def test_tags_remove( - self, client: httpx.AsyncClient, temp_user: TestUser + self, client: httpx.AsyncClient, temp_user: TestPasswordUser ) -> None: bus = resolve(MessageBus) @@ -610,7 +616,7 @@ async def test_tags_remove( @pytest.mark.asyncio class TestDeleteDataset: async def test_delete( - self, client: httpx.AsyncClient, admin_user: TestUser + self, client: httpx.AsyncClient, admin_user: TestPasswordUser ) -> None: bus = resolve(MessageBus) @@ -623,7 +629,7 @@ async def test_delete( await bus.execute(GetDatasetByID(id=dataset_id)) async def test_idempotent( - self, client: httpx.AsyncClient, admin_user: TestUser + self, client: httpx.AsyncClient, admin_user: TestPasswordUser ) -> None: # Repeated calls on a deleted (or non-existing) resource should be fine. dataset_id = id_factory() diff --git a/tests/api/test_datasets_filters.py b/tests/api/test_datasets_filters.py index a78250017..2e4117b5c 100644 --- a/tests/api/test_datasets_filters.py +++ b/tests/api/test_datasets_filters.py @@ -10,12 +10,12 @@ from server.seedwork.application.messages import MessageBus from ..factories import CreateDatasetFactory, CreateTagFactory -from ..helpers import TestUser +from ..helpers import TestPasswordUser @pytest.mark.asyncio async def test_dataset_filters_info( - client: httpx.AsyncClient, temp_user: TestUser + client: httpx.AsyncClient, temp_user: TestPasswordUser ) -> None: bus = resolve(MessageBus) @@ -142,7 +142,7 @@ class _Env: ) async def test_dataset_filters_apply( client: httpx.AsyncClient, - temp_user: TestUser, + temp_user: TestPasswordUser, filtername: str, create_kwargs: Callable[[_Env], dict], positive_value: Callable[[_Env], list], @@ -171,7 +171,7 @@ async def test_dataset_filters_apply( @pytest.mark.asyncio async def test_dataset_filters_license_any( - client: httpx.AsyncClient, temp_user: TestUser + client: httpx.AsyncClient, temp_user: TestPasswordUser ) -> None: bus = resolve(MessageBus) diff --git a/tests/api/test_datasets_search.py b/tests/api/test_datasets_search.py index 9901c27a0..5e21ad017 100644 --- a/tests/api/test_datasets_search.py +++ b/tests/api/test_datasets_search.py @@ -10,7 +10,7 @@ from server.seedwork.application.messages import MessageBus from tests.factories import CreateDatasetFactory, UpdateDatasetFactory -from ..helpers import TestUser +from ..helpers import TestPasswordUser DEFAULT_CORPUS_ITEMS = [ ("Inventaire national forestier", "Ensemble des forĂȘts de France"), @@ -84,7 +84,10 @@ async def add_corpus(items: List[Tuple[str, str]] = None) -> None: ], ) async def test_search( - client: httpx.AsyncClient, temp_user: TestUser, q: str, expected_titles: List[str] + client: httpx.AsyncClient, + temp_user: TestPasswordUser, + q: str, + expected_titles: List[str], ) -> None: await add_corpus() @@ -116,7 +119,7 @@ async def test_search( ], ) async def test_search_robustness( - client: httpx.AsyncClient, temp_user: TestUser, q_ref: str, q_other: str + client: httpx.AsyncClient, temp_user: TestPasswordUser, q_ref: str, q_other: str ) -> None: await add_corpus() @@ -145,7 +148,7 @@ async def test_search_robustness( @pytest.mark.asyncio async def test_search_results_change_when_data_changes( client: httpx.AsyncClient, - temp_user: TestUser, + temp_user: TestPasswordUser, ) -> None: await add_corpus() @@ -219,7 +222,9 @@ async def test_search_results_change_when_data_changes( @pytest.mark.asyncio -async def test_search_ranking(client: httpx.AsyncClient, temp_user: TestUser) -> None: +async def test_search_ranking( + client: httpx.AsyncClient, temp_user: TestPasswordUser +) -> None: items = [ ("A", "..."), ("B", "ForĂȘt nouvelle"), @@ -271,7 +276,7 @@ async def test_search_ranking(client: httpx.AsyncClient, temp_user: TestUser) -> ) async def test_search_highlight( client: httpx.AsyncClient, - temp_user: TestUser, + temp_user: TestPasswordUser, corpus: list, q: str, expected_headlines: Optional[dict], diff --git a/tests/api/test_licenses.py b/tests/api/test_licenses.py index 7ead6a629..4679bb393 100644 --- a/tests/api/test_licenses.py +++ b/tests/api/test_licenses.py @@ -5,11 +5,13 @@ from server.seedwork.application.messages import MessageBus from ..factories import CreateDatasetFactory -from ..helpers import TestUser +from ..helpers import TestPasswordUser @pytest.mark.asyncio -async def test_license_list(client: httpx.AsyncClient, temp_user: TestUser) -> None: +async def test_license_list( + client: httpx.AsyncClient, temp_user: TestPasswordUser +) -> None: bus = resolve(MessageBus) response = await client.get("/licenses/", auth=temp_user.auth) diff --git a/tests/api/test_organizations.py b/tests/api/test_organizations.py index 3b5c92b7f..a441996f7 100644 --- a/tests/api/test_organizations.py +++ b/tests/api/test_organizations.py @@ -2,7 +2,7 @@ import pytest from ..factories import CreateOrganizationFactory -from ..helpers import TestUser, to_payload +from ..helpers import TestPasswordUser, to_payload def api_key_auth(request: httpx.Request) -> httpx.Request: @@ -39,7 +39,7 @@ def api_key_auth(request: httpx.Request) -> httpx.Request: ) async def test_create_organization_invalid( client: httpx.AsyncClient, - temp_user: TestUser, + temp_user: TestPasswordUser, payload: dict, expected_errors_attrs: list, ) -> None: @@ -104,7 +104,7 @@ async def test_create_anonymous(self, client: httpx.AsyncClient) -> None: assert response.status_code == 403 async def test_create_authenticated( - self, client: httpx.AsyncClient, temp_user: TestUser + self, client: httpx.AsyncClient, temp_user: TestPasswordUser ) -> None: response = await client.post( "/organizations/", diff --git a/tests/api/test_tags.py b/tests/api/test_tags.py index 99b5c99a3..97a4f930b 100644 --- a/tests/api/test_tags.py +++ b/tests/api/test_tags.py @@ -5,11 +5,13 @@ from server.config.di import resolve from server.seedwork.application.messages import MessageBus -from ..helpers import TestUser +from ..helpers import TestPasswordUser @pytest.mark.asyncio -async def test_tags_list(client: httpx.AsyncClient, temp_user: TestUser) -> None: +async def test_tags_list( + client: httpx.AsyncClient, temp_user: TestPasswordUser +) -> None: bus = resolve(MessageBus) response = await client.get("/tags/", auth=temp_user.auth) diff --git a/tests/application/test_auth.py b/tests/application/test_auth.py index 330f6f463..15c9605f1 100644 --- a/tests/application/test_auth.py +++ b/tests/application/test_auth.py @@ -1,8 +1,8 @@ import pytest from pydantic import EmailStr, SecretStr -from server.application.auth.commands import ChangePassword, CreateUser -from server.application.auth.queries import Login +from server.application.auth.commands import ChangePassword, CreatePasswordUser +from server.application.auth.queries import LoginPasswordUser from server.config.di import resolve from server.domain.auth.exceptions import LoginFailed from server.seedwork.application.messages import MessageBus @@ -13,11 +13,13 @@ async def test_changepassword() -> None: bus = resolve(MessageBus) email = EmailStr("changepassworduser@mydomain.org") - await bus.execute(CreateUser(email=email, password=SecretStr("initialpwd"))) + await bus.execute(CreatePasswordUser(email=email, password=SecretStr("initialpwd"))) await bus.execute(ChangePassword(email=email, password=SecretStr("newpwd"))) with pytest.raises(LoginFailed): - await bus.execute(Login(email=email, password=SecretStr("initialpwd"))) + await bus.execute( + LoginPasswordUser(email=email, password=SecretStr("initialpwd")) + ) - await bus.execute(Login(email=email, password=SecretStr("newpwd"))) + await bus.execute(LoginPasswordUser(email=email, password=SecretStr("newpwd"))) diff --git a/tests/conftest.py b/tests/conftest.py index 36141de1d..d8cd40942 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,7 +20,7 @@ from server.seedwork.application.messages import MessageBus from tests.factories import CreateTagFactory -from .helpers import TestUser, create_client, create_test_user +from .helpers import TestPasswordUser, create_client, create_test_password_user if TYPE_CHECKING: from server.api.app import App @@ -106,10 +106,10 @@ async def client(app: "App") -> AsyncIterator[httpx.AsyncClient]: @pytest_asyncio.fixture(name="temp_user") -async def fixture_temp_user() -> TestUser: - return await create_test_user(UserRole.USER) +async def fixture_temp_user() -> TestPasswordUser: + return await create_test_password_user(UserRole.USER) @pytest_asyncio.fixture -async def admin_user() -> TestUser: - return await create_test_user(UserRole.ADMIN) +async def admin_user() -> TestPasswordUser: + return await create_test_password_user(UserRole.ADMIN) diff --git a/tests/factories.py b/tests/factories.py index 27c9be432..b3ea929cd 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -7,7 +7,7 @@ from pydantic import BaseModel from pydantic_factories import ModelFactory, Use -from server.application.auth.commands import CreateUser +from server.application.auth.commands import CreatePasswordUser from server.application.datasets.commands import CreateDataset, UpdateDataset from server.application.organizations.commands import CreateOrganization from server.application.tags.commands import CreateTag @@ -38,8 +38,8 @@ def get_mock_value(cls, field_type: Any) -> Any: return super().get_mock_value(field_type) -class CreateUserFactory(Factory[CreateUser]): - __model__ = CreateUser +class CreatePasswordUserFactory(Factory[CreatePasswordUser]): + __model__ = CreatePasswordUser organization_siret = Use(lambda: LEGACY_ORGANIZATION_SIRET) diff --git a/tests/helpers.py b/tests/helpers.py index 33bb1ed9d..4a1c65a2f 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -5,11 +5,11 @@ from pydantic import BaseModel from server.config.di import resolve -from server.domain.auth.entities import User, UserRole -from server.domain.auth.repositories import UserRepository +from server.domain.auth.entities import PasswordUser, UserRole +from server.domain.auth.repositories import PasswordUserRepository from server.seedwork.application.messages import MessageBus -from .factories import CreateUserFactory +from .factories import CreatePasswordUserFactory def create_client(app: Callable) -> httpx.AsyncClient: @@ -23,7 +23,7 @@ def to_payload(obj: BaseModel) -> dict: return json.loads(obj.json()) -class TestUser(User): +class TestPasswordUser(PasswordUser): """ A user that exposes the plaintext password for testing purposes. """ @@ -38,18 +38,18 @@ def auth(self, request: httpx.Request) -> httpx.Request: Usage: response = client.post(..., auth=test_user.auth) """ - request.headers["Authorization"] = f"Bearer {self.api_token}" + request.headers["Authorization"] = f"Bearer {self.account.api_token}" return request -async def create_test_user(role: UserRole) -> TestUser: +async def create_test_password_user(role: UserRole) -> TestPasswordUser: bus = resolve(MessageBus) - user_repository = resolve(UserRepository) + password_user_repository = resolve(PasswordUserRepository) - command = CreateUserFactory.build() + command = CreatePasswordUserFactory.build() await bus.execute(command, role=role) - user = await user_repository.get_by_email(command.email) + user = await password_user_repository.get_by_email(command.email) assert user is not None - return TestUser(**user.dict(), password=command.password.get_secret_value()) + return TestPasswordUser(**user.dict(), password=command.password.get_secret_value()) diff --git a/tests/infrastructure/test_catalogs.py b/tests/infrastructure/test_catalogs.py index 0b59a6c66..b1e512c27 100644 --- a/tests/infrastructure/test_catalogs.py +++ b/tests/infrastructure/test_catalogs.py @@ -1,7 +1,7 @@ import pytest from pydantic import EmailStr -from server.application.auth.queries import GetUserByEmail +from server.application.auth.queries import GetAccountByEmail from server.application.datasets.queries import GetDatasetByID from server.config.di import resolve from server.domain.organizations.types import Siret @@ -10,7 +10,7 @@ from server.infrastructure.organizations.repositories import OrganizationModel from server.seedwork.application.messages import MessageBus -from ..factories import CreateDatasetFactory, CreateUserFactory +from ..factories import CreateDatasetFactory, CreatePasswordUserFactory @pytest.mark.asyncio @@ -32,10 +32,12 @@ async def test_catalog_creation_and_relationships() -> None: # Add a user to the organization... email = "test@mydomain.org" - await bus.execute(CreateUserFactory.build(organization_siret=siret, email=email)) + await bus.execute( + CreatePasswordUserFactory.build(organization_siret=siret, email=email) + ) - user = await bus.execute(GetUserByEmail(email=EmailStr("test@mydomain.org"))) - assert user.organization_siret == siret + account = await bus.execute(GetAccountByEmail(email=EmailStr("test@mydomain.org"))) + assert account.organization_siret == siret # Add a dataset to the catalog... dataset_id = await bus.execute(CreateDatasetFactory.build(organization_siret=siret)) diff --git a/tests/tools/test_initdata.py b/tests/tools/test_initdata.py index aa01505ab..11148af48 100644 --- a/tests/tools/test_initdata.py +++ b/tests/tools/test_initdata.py @@ -5,8 +5,8 @@ import pytest from pydantic import EmailStr, SecretStr -from server.application.auth.commands import DeleteUser -from server.application.auth.queries import Login +from server.application.auth.commands import DeletePasswordUser +from server.application.auth.queries import LoginPasswordUser from server.application.datasets.commands import UpdateDataset from server.application.datasets.queries import GetAllDatasets, GetDatasetByID from server.config.di import resolve @@ -85,12 +85,14 @@ async def test_initdata_env_password( monkeypatch.setenv("TOOLS_PASSWORDS", json.dumps({"test@admin.org": "testpwd"})) await initdata.main(path, no_input=True) - user = await bus.execute( - Login(email=EmailStr("test@admin.org"), password=SecretStr("testpwd")) + account = await bus.execute( + LoginPasswordUser( + email=EmailStr("test@admin.org"), password=SecretStr("testpwd") + ) ) # (Delete user to prevent email collision below.) - await bus.execute(DeleteUser(id=user.id)) + await bus.execute(DeletePasswordUser(account_id=account.id)) # If not set, it would be prompted in the terminal. monkeypatch.delenv("TOOLS_PASSWORDS") diff --git a/tests/unit/test_permissions.py b/tests/unit/test_permissions.py index 589ba30ee..3a61ccd76 100644 --- a/tests/unit/test_permissions.py +++ b/tests/unit/test_permissions.py @@ -17,11 +17,11 @@ from server.api.types import APIRequest from server.domain.auth.entities import UserRole -from ..helpers import TestUser +from ..helpers import TestPasswordUser @pytest.mark.asyncio -async def test_is_authenticated(temp_user: TestUser) -> None: +async def test_is_authenticated(temp_user: TestPasswordUser) -> None: app = FastAPI() app.add_middleware(AuthMiddleware, backend=TokenAuthBackend()) @@ -34,13 +34,15 @@ async def index() -> str: response = await client.get("/") assert response.status_code == 401 - headers = {"Authorization": f"Bearer {temp_user.api_token}"} + headers = {"Authorization": f"Bearer {temp_user.account.api_token}"} response = await client.get("/", headers=headers) assert response.status_code == 200 @pytest.mark.asyncio -async def test_has_role(temp_user: TestUser, admin_user: TestUser) -> None: +async def test_has_role( + temp_user: TestPasswordUser, admin_user: TestPasswordUser +) -> None: app = FastAPI() app.add_middleware(AuthMiddleware, backend=TokenAuthBackend()) @@ -53,11 +55,11 @@ async def index() -> str: response = await client.get("/") assert response.status_code == 401 - headers = {"Authorization": f"Bearer {temp_user.api_token}"} + headers = {"Authorization": f"Bearer {temp_user.account.api_token}"} response = await client.get("/", headers=headers) assert response.status_code == 403 - headers = {"Authorization": f"Bearer {admin_user.api_token}"} + headers = {"Authorization": f"Bearer {admin_user.account.api_token}"} response = await client.get("/", headers=headers) assert response.status_code == 200 diff --git a/tools/changepassword.py b/tools/changepassword.py index 2dfc8ab32..5e31ac9f1 100644 --- a/tools/changepassword.py +++ b/tools/changepassword.py @@ -6,23 +6,23 @@ from server.application.auth.commands import ChangePassword from server.config.di import bootstrap, resolve -from server.domain.auth.entities import User -from server.domain.auth.repositories import UserRepository +from server.domain.auth.entities import PasswordUser +from server.domain.auth.repositories import PasswordUserRepository from server.seedwork.application.messages import MessageBus -async def _prompt_user() -> User: - repository = resolve(UserRepository) +async def _prompt_password_user() -> PasswordUser: + repository = resolve(PasswordUserRepository) email = click.prompt("Email") - user = await repository.get_by_email(email) + password_user = await repository.get_by_email(email) - if user is None: - click.echo(click.style(f"User does not exist: {email}", fg="red")) + if password_user is None: + click.echo(click.style(f"PasswordUser does not exist: {email}", fg="red")) sys.exit(1) - return user + return password_user def _prompt_password() -> SecretStr: @@ -37,10 +37,12 @@ def _prompt_password() -> SecretStr: async def main() -> None: bus = resolve(MessageBus) - user = await _prompt_user() + user = await _prompt_password_user() password = _prompt_password() - await bus.execute(ChangePassword(email=EmailStr(user.email), password=password)) + await bus.execute( + ChangePassword(email=EmailStr(user.account.email), password=password) + ) if __name__ == "__main__": diff --git a/tools/initdata.py b/tools/initdata.py index f7c910a49..0dfeab660 100644 --- a/tools/initdata.py +++ b/tools/initdata.py @@ -11,12 +11,12 @@ from dotenv import load_dotenv from pydantic import BaseModel, ValidationError, parse_raw_as -from server.application.auth.commands import CreateUser +from server.application.auth.commands import CreatePasswordUser from server.application.datasets.commands import CreateDataset, UpdateDataset from server.application.tags.commands import CreateTag from server.config.di import bootstrap, resolve from server.domain.auth.entities import UserRole -from server.domain.auth.repositories import UserRepository +from server.domain.auth.repositories import PasswordUserRepository from server.domain.datasets.entities import Dataset from server.domain.datasets.repositories import DatasetRepository from server.domain.tags.repositories import TagRepository @@ -51,13 +51,13 @@ async def handle_user( item: dict, *, no_input: bool, env_passwords: Dict[str, str] ) -> None: bus = resolve(MessageBus) - repository = resolve(UserRepository) + repository = resolve(PasswordUserRepository) email = item["params"]["email"] existing_user = await repository.get_by_email(email) if existing_user is not None: - print(f"{info('ok')}: User(email={email!r}, ...)") + print(f"{info('ok')}: PasswordUser(email={email!r}, ...)") return extras = UserExtras(**item.get("extras", {})) @@ -74,7 +74,7 @@ async def handle_user( password = click.prompt(f"Password for {email}", hide_input=True) item["params"]["password"] = password - command = CreateUser(**item["params"]) + command = CreatePasswordUser(**item["params"]) await bus.execute(command, id_=item["id"], **extras.dict()) print(f"{success('created')}: {command!r}")