Skip to content

Commit

Permalink
use wraps
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz committed Dec 16, 2024
1 parent b7ffe78 commit 7e97961
Showing 1 changed file with 29 additions and 54 deletions.
83 changes: 29 additions & 54 deletions packages/atproto_client/models/string_formats.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
"""AT Proto string format validation.
*Note*: These formats are a working empirical understanding of the following resources:
- https://atproto.com/specs/lexicon
- https://github.com/bluesky-social/atproto/tree/main/interop-test-files/syntax
"""
"""AT Proto string format validation."""

import re
from datetime import datetime
from functools import wraps
from typing import Callable, Mapping, Set, Union, cast
from urllib.parse import urlparse

Expand Down Expand Up @@ -60,35 +54,21 @@
)


class _MaybeStrictValidator:
def __init__(self, validate_fn: Callable[..., str]) -> None:
self.validate_fn = validate_fn
self.__name__ = validate_fn.__name__
self.__doc__ = validate_fn.__doc__
def only_validate_if_strict(validate_fn: Callable[..., str]) -> Callable[..., str]:
"""Skip pydantic validation if not opting into strict validation via context."""

def __call__(self, v: str, info: ValidationInfo) -> str:
@wraps(validate_fn)
def wrapper(v: str, info: ValidationInfo) -> str:
"""Could likely be generalized to support arbitrary signatures."""
if info and isinstance(info.context, Mapping) and info.context.get(_OPT_IN_KEY, False):
return cast(core_schema.NoInfoValidatorFunction, self.validate_fn)(v)
return cast(core_schema.WithInfoValidatorFunction, validate_fn)(v, info)
return v

def __repr__(self) -> str:
return f'<validator {self.validate_fn.__name__}>'


def only_validate_if_strict(validate_fn: Callable[..., str]) -> Callable[..., str]:
"""Skip pydantic validation if not opting into strict validation via context.
Args:
validate_fn: The validation function to conditionally apply
Returns:
A wrapped validation function that only validates in strict mode
"""
return _MaybeStrictValidator(validate_fn)
return wrapper


@only_validate_if_strict
def validate_handle(v: str) -> str:
def validate_handle(v: str, _: ValidationInfo) -> str:
"""Validate an AT Protocol handle.
A handle must be a valid domain name with:
Expand Down Expand Up @@ -126,7 +106,7 @@ def validate_handle(v: str) -> str:


@only_validate_if_strict
def validate_did(v: str) -> str:
def validate_did(v: str, _: ValidationInfo) -> str:
"""Validate an AT Protocol DID.
A DID must follow the pattern:
Expand Down Expand Up @@ -171,7 +151,7 @@ def validate_did(v: str) -> str:


@only_validate_if_strict
def validate_nsid(v: str) -> str:
def validate_nsid(v: str, _: ValidationInfo) -> str:
"""Validate an AT Protocol NSID (Namespaced Identifier).
An NSID must have:
Expand Down Expand Up @@ -213,7 +193,7 @@ def validate_nsid(v: str) -> str:


@only_validate_if_strict
def validate_language(v: str) -> str:
def validate_language(v: str, _: ValidationInfo) -> str:
"""Validate an ISO language code.
Must match pattern:
Expand All @@ -237,7 +217,7 @@ def validate_language(v: str) -> str:


@only_validate_if_strict
def validate_record_key(v: str) -> str:
def validate_record_key(v: str, _: ValidationInfo) -> str:
"""Validate an AT Protocol record key.
A record key must:
Expand Down Expand Up @@ -266,7 +246,7 @@ def validate_record_key(v: str) -> str:


@only_validate_if_strict
def validate_cid(v: str) -> str:
def validate_cid(v: str, _: ValidationInfo) -> str:
"""Validate a Content Identifier (CID).
Must be:
Expand All @@ -289,7 +269,7 @@ def validate_cid(v: str) -> str:


@only_validate_if_strict
def validate_at_uri(v: str) -> str:
def validate_at_uri(v: str, _: ValidationInfo) -> str:
"""Validate an AT Protocol URI.
Must follow pattern:
Expand Down Expand Up @@ -322,7 +302,7 @@ def validate_at_uri(v: str) -> str:


@only_validate_if_strict
def validate_datetime(v: str) -> str:
def validate_datetime(v: str, _: ValidationInfo) -> str:
"""Validate an ISO 8601/RFC 3339 datetime string.
Requirements:
Expand Down Expand Up @@ -380,7 +360,7 @@ def validate_datetime(v: str) -> str:


@only_validate_if_strict
def validate_tid(v: str) -> str:
def validate_tid(v: str, _: ValidationInfo) -> str:
"""Validate an AT Protocol TID (Temporal ID).
Must be:
Expand All @@ -406,7 +386,7 @@ def validate_tid(v: str) -> str:


@only_validate_if_strict
def validate_uri(v: str) -> str:
def validate_uri(v: str, _: ValidationInfo) -> str:
"""Validate a standard URI.
Requirements:
Expand Down Expand Up @@ -442,21 +422,16 @@ def validate_uri(v: str) -> str:
return v


class _ReprBeforeValidator(BeforeValidator):
def __repr__(self) -> str:
return f'<{self.func.__name__} function>'


Handle = Annotated[str, _ReprBeforeValidator(validate_handle)]
Did = Annotated[str, _ReprBeforeValidator(validate_did)]
Nsid = Annotated[str, _ReprBeforeValidator(validate_nsid)]
Language = Annotated[str, _ReprBeforeValidator(validate_language)]
RecordKey = Annotated[str, _ReprBeforeValidator(validate_record_key)]
Cid = Annotated[str, _ReprBeforeValidator(validate_cid)]
AtUri = Annotated[str, _ReprBeforeValidator(validate_at_uri)]
DateTime = Annotated[str, _ReprBeforeValidator(validate_datetime)] # see pydantic-extra-types #239
Tid = Annotated[str, _ReprBeforeValidator(validate_tid)]
Uri = Annotated[str, _ReprBeforeValidator(validate_uri)]
Handle = Annotated[str, BeforeValidator(validate_handle)]
Did = Annotated[str, BeforeValidator(validate_did)]
Nsid = Annotated[str, BeforeValidator(validate_nsid)]
Language = Annotated[str, BeforeValidator(validate_language)]
RecordKey = Annotated[str, BeforeValidator(validate_record_key)]
Cid = Annotated[str, BeforeValidator(validate_cid)]
AtUri = Annotated[str, BeforeValidator(validate_at_uri)]
DateTime = Annotated[str, BeforeValidator(validate_datetime)] # see pydantic-extra-types #239
Tid = Annotated[str, BeforeValidator(validate_tid)]
Uri = Annotated[str, BeforeValidator(validate_uri)]

# Any valid ATProto string format
AtProtoString = Annotated[
Expand Down

0 comments on commit 7e97961

Please sign in to comment.