Skip to content

Commit

Permalink
Replace UuidModelUnion with ModelUuidProtocol
Browse files Browse the repository at this point in the history
  • Loading branch information
jace committed Dec 30, 2023
1 parent 760103e commit 59f225f
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 31 deletions.
25 changes: 13 additions & 12 deletions funnel/models/notification.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@
)
from .account import Account, AccountEmail, AccountPhone
from .phone_number import PhoneNumber, PhoneNumberMixin
from .typing import UuidModelUnion
from .typing import ModelUuidProtocol

__all__ = [
'SMS_STATUS',
Expand All @@ -151,9 +151,9 @@
# --- Typing ---------------------------------------------------------------------------

# Document generic type
_D = TypeVar('_D', bound=UuidModelUnion)
_D = TypeVar('_D', bound=ModelUuidProtocol)
# Fragment generic type
_F = TypeVar('_F', bound=UuidModelUnion | None)
_F = TypeVar('_F', bound=ModelUuidProtocol | None)
# Type of None (required to detect Optional)
NoneType = type(None)

Expand Down Expand Up @@ -334,12 +334,12 @@ class Notification(NoIdMixin, Model, Generic[_D, _F]):
pref_type: ClassVar[str] = ''

#: Document model, must be specified in subclasses
document_model: ClassVar[type[UuidModelUnion]]
document_model: ClassVar[type[ModelUuidProtocol]]
#: SQL table name for document type, auto-populated from the document model
document_type: ClassVar[str]

#: Fragment model, optional for subclasses
fragment_model: ClassVar[type[UuidModelUnion] | None] = None
fragment_model: ClassVar[type[ModelUuidProtocol] | None] = None
#: SQL table name for fragment type, auto-populated from the fragment model
fragment_type: ClassVar[str | None]

Expand Down Expand Up @@ -557,7 +557,7 @@ def __init__(
# pylint: disable=isinstance-second-argument-not-valid-type
if not isinstance(fragment, self.fragment_model):
raise TypeError(f"{fragment!r} is not of type {self.fragment_model!r}")
kwargs['fragment_uuid'] = fragment.uuid # type: ignore[union-attr]
kwargs['fragment_uuid'] = fragment.uuid
super().__init__(**kwargs)

@property
Expand Down Expand Up @@ -638,9 +638,9 @@ def allow_transport(cls, transport: str) -> bool:
return getattr(cls, 'allow_' + transport)

@property
def role_provider_obj(self) -> _F | _D:
def role_provider_obj(self) -> ModelUuidProtocol:
"""Return fragment if exists, document otherwise, indicating role provider."""
return self.fragment or self.document
return self.fragment or self.document # type: ignore[return-value] # FIXME

def dispatch(self) -> Generator[NotificationRecipient, None, None]:
"""
Expand All @@ -655,6 +655,7 @@ def dispatch(self) -> Generator[NotificationRecipient, None, None]:
Subclasses wanting more control over how their notifications are dispatched
should override this method.
"""

for account, role in self.role_provider_obj.actors_with(
self.roles, with_role=True
):
Expand Down Expand Up @@ -718,8 +719,8 @@ class PreviewNotification(NotificationType):
def __init__( # pylint: disable=super-init-not-called
self,
cls: type[Notification],
document: UuidModelUnion,
fragment: UuidModelUnion | None = None,
document: ModelUuidProtocol,
fragment: ModelUuidProtocol | None = None,
user: Account | None = None,
) -> None:
self.eventid = uuid4()
Expand Down Expand Up @@ -764,14 +765,14 @@ def notification_pref_type(self) -> str:
with_roles(notification_pref_type, read={'owner'})

@cached_property
def document(self) -> UuidModelUnion | None:
def document(self) -> ModelUuidProtocol | None:
"""Document that this notification is for."""
return self.notification.document

with_roles(document, read={'owner'})

@cached_property
def fragment(self) -> UuidModelUnion | None:
def fragment(self) -> ModelUuidProtocol | None:
"""Fragment within this document that this notification is for."""
return self.notification.fragment

Expand Down
18 changes: 16 additions & 2 deletions funnel/models/rsvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,26 @@

from __future__ import annotations

from typing import Literal, Self, overload
from typing import TYPE_CHECKING, Any, Literal, Self, overload

from flask import current_app

from baseframe import __
from coaster.sqlalchemy import StateManager, with_roles
from coaster.utils import LabeledEnum

from . import Mapped, Model, NoIdMixin, UuidMixin, db, relationship, sa, sa_orm, types
from . import (
Mapped,
Model,
NoIdMixin,
UuidMixin,
db,
declared_attr,
relationship,
sa,
sa_orm,
types,
)
from .account import Account, AccountEmail, AccountEmailClaim, AccountPhone
from .project import Project
from .project_membership import project_child_role_map
Expand Down Expand Up @@ -66,6 +77,9 @@ class Rsvp(UuidMixin, NoIdMixin, Model):
call={'owner', 'project_promoter'},
)

if TYPE_CHECKING:
id_: declared_attr[Any] # Fake entry for compatibility with ModelUuidProtocol

__roles__ = {
'owner': {'read': {'created_at', 'updated_at'}},
'project_promoter': {'read': {'created_at', 'updated_at'}},
Expand Down
65 changes: 58 additions & 7 deletions funnel/models/typing.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
"""Union types for models with shared functionality."""

from collections.abc import Iterable, Iterator, Sequence
from datetime import datetime
from typing import Any, ClassVar, Protocol, TypeAlias, Union
from typing import (
Any,
ClassVar,
Literal,
Protocol,
TypeAlias,
Union,
overload,
runtime_checkable,
)
from uuid import UUID

from sqlalchemy import Table
from sqlalchemy.orm import Mapped, declared_attr

from coaster.sqlalchemy import QueryProperty
from coaster.sqlalchemy import LazyRoleSet, QueryProperty
from coaster.utils import InspectableSet

from .account import Account, AccountOldId, Team
from .auth_client import AuthClient
Expand Down Expand Up @@ -65,20 +76,60 @@ class ModelProtocol(Protocol):
query: ClassVar[QueryProperty]


class ModelIdProtocol(ModelProtocol, Protocol):
id_: declared_attr[Any]


class ModelTimestampProtocol(ModelProtocol, Protocol):
created_at: declared_attr[datetime]
updated_at: declared_attr[datetime]


class ModelUrlProtocol(Protocol):
@property
def absolute_url(self) -> str | None:
...

def url_for(self, action: str = 'view', **kwargs) -> str:
...


class ModelRoleProtocol(Protocol):
def roles_for(
self, actor: Account | None = None, anchors: Sequence[Any] = ()
) -> LazyRoleSet:
...

@property
def current_roles(self) -> InspectableSet[LazyRoleSet]:
...

@overload
def actors_with(
self, roles: Iterable[str], with_role: Literal[False] = False
) -> Iterator[Account]:
...

@overload
def actors_with(
self, roles: Iterable[str], with_role: Literal[True]
) -> Iterator[tuple[Account, str]]:
...

def actors_with(
self, roles: Iterable[str], with_role: bool = False
) -> Iterator[Account | tuple[Account, str]]:
...


class ModelIdProtocol(
ModelTimestampProtocol, ModelUrlProtocol, ModelRoleProtocol, Protocol
):
id_: declared_attr[Any]


@runtime_checkable # FIXME: This is never used, but needed to make type checkers happy
class ModelUuidProtocol(ModelIdProtocol, Protocol):
uuid: declared_attr[UUID]


class ModelSearchProtocol(ModelUuidProtocol, ModelTimestampProtocol, Protocol):
class ModelSearchProtocol(ModelUuidProtocol, Protocol):
search_vector: Mapped[str]

@property
Expand Down
8 changes: 4 additions & 4 deletions funnel/views/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
from ..models import (
Account,
Draft,
ModelUuidProtocol,
Project,
ProjectRedirect,
TicketEvent,
UuidModelUnion,
db,
)
from ..typing import ReturnView
Expand Down Expand Up @@ -133,7 +133,7 @@ class DraftViewProtoMixin:
model: Any
obj: Any

def get_draft(self, obj: UuidModelUnion | None = None) -> Draft | None:
def get_draft(self, obj: ModelUuidProtocol | None = None) -> Draft | None:
"""
Return the draft object for `obj`. Defaults to `self.obj`.
Expand All @@ -151,7 +151,7 @@ def delete_draft(self, obj=None):
raise ValueError(_("There is no draft for the given object"))

def get_draft_data(
self, obj: UuidModelUnion | None = None
self, obj: ModelUuidProtocol | None = None
) -> tuple[None, None] | tuple[UUID | None, dict]:
"""
Return a tuple of draft data.
Expand All @@ -163,7 +163,7 @@ def get_draft_data(
return draft.revision, draft.formdata
return None, None

def autosave_post(self, obj: UuidModelUnion | None = None) -> ReturnView:
def autosave_post(self, obj: ModelUuidProtocol | None = None) -> ReturnView:
"""Handle autosave POST requests."""
obj = obj if obj is not None else self.obj
if 'form.revision' not in request.form:
Expand Down
4 changes: 2 additions & 2 deletions funnel/views/notification.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
Account,
AccountEmail,
AccountPhone,
ModelUuidProtocol,
Notification,
NotificationFor,
NotificationRecipient,
UuidModelUnion,
db,
sa,
)
Expand Down Expand Up @@ -324,7 +324,7 @@ def fragments_query_options(self) -> list: # TODO: full spec
return []

@cached_property
def fragments(self) -> list[RoleAccessProxy[UuidModelUnion]]:
def fragments(self) -> list[RoleAccessProxy[ModelUuidProtocol]]:
query = self.notification_recipient.rolledup_fragments()
if query is None:
return []
Expand Down
8 changes: 4 additions & 4 deletions funnel/views/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class SearchProvider(Generic[_ST]):

@property
def regconfig(self) -> str:
"""Return PostgreSQL regconfig language, defaulting to English."""
"""Return PostgreSQL `regconfig` language, defaulting to English."""
return self.model.search_vector.type.options.get('regconfig', 'english')

@property
Expand All @@ -84,7 +84,7 @@ def title_column(self) -> sa.ColumnElement[str]:
# That makes this return value incorrect, but here we ignore the error as
# class:`CommentSearch` explicitly overrides :meth:`hltitle_column`, and that is
# the only place this property is accessed
return self.model.title # type: ignore[return-value]
return self.model.title # type: ignore[return-value] # FIXME

@property
def hltext(self) -> sa.ColumnElement[str]:
Expand Down Expand Up @@ -149,7 +149,7 @@ def all_count(self, tsquery: sa.Function) -> int:
return self.all_query(tsquery).options(sa_orm.load_only(self.model.id_)).count()


class SearchInAccountProvider(SearchProvider):
class SearchInAccountProvider(SearchProvider[_ST]):
"""Base class for search providers that support searching in an account."""

def account_query(self, tsquery: sa.Function, account: Account) -> Query:
Expand All @@ -165,7 +165,7 @@ def account_count(self, tsquery: sa.Function, account: Account) -> int:
)


class SearchInProjectProvider(SearchInAccountProvider):
class SearchInProjectProvider(SearchInAccountProvider[_ST]):
"""Base class for search providers that support searching in a project."""

def project_query(self, tsquery: sa.Function, project: Project) -> Query:
Expand Down

0 comments on commit 59f225f

Please sign in to comment.