diff --git a/packages/atproto_client/models/string_formats.py b/packages/atproto_client/models/string_formats.py index b9a3699d..c8a8a45a 100644 --- a/packages/atproto_client/models/string_formats.py +++ b/packages/atproto_client/models/string_formats.py @@ -1,14 +1,8 @@ -"""AT Proto string format validation. - -*Note*: These formats are a working empirical understanding of the following resources: - -- https://atproto.com/specs/lexicon - -- https://github.com/bluesky-social/atproto/tree/main/interop-test-files/syntax -""" +"""AT Proto string format validation.""" import re from datetime import datetime +from functools import wraps from typing import Callable, Mapping, Set, Union, cast from urllib.parse import urlparse @@ -60,35 +54,21 @@ ) -class _MaybeStrictValidator: - def __init__(self, validate_fn: Callable[..., str]) -> None: - self.validate_fn = validate_fn - self.__name__ = validate_fn.__name__ - self.__doc__ = validate_fn.__doc__ +def only_validate_if_strict(validate_fn: Callable[..., str]) -> Callable[..., str]: + """Skip pydantic validation if not opting into strict validation via context.""" - def __call__(self, v: str, info: ValidationInfo) -> str: + @wraps(validate_fn) + def wrapper(v: str, info: ValidationInfo) -> str: + """Could likely be generalized to support arbitrary signatures.""" if info and isinstance(info.context, Mapping) and info.context.get(_OPT_IN_KEY, False): - return cast(core_schema.NoInfoValidatorFunction, self.validate_fn)(v) + return cast(core_schema.WithInfoValidatorFunction, validate_fn)(v, info) return v - def __repr__(self) -> str: - return f'' - - -def only_validate_if_strict(validate_fn: Callable[..., str]) -> Callable[..., str]: - """Skip pydantic validation if not opting into strict validation via context. - - Args: - validate_fn: The validation function to conditionally apply - - Returns: - A wrapped validation function that only validates in strict mode - """ - return _MaybeStrictValidator(validate_fn) + return wrapper @only_validate_if_strict -def validate_handle(v: str) -> str: +def validate_handle(v: str, _: ValidationInfo) -> str: """Validate an AT Protocol handle. A handle must be a valid domain name with: @@ -126,7 +106,7 @@ def validate_handle(v: str) -> str: @only_validate_if_strict -def validate_did(v: str) -> str: +def validate_did(v: str, _: ValidationInfo) -> str: """Validate an AT Protocol DID. A DID must follow the pattern: @@ -171,7 +151,7 @@ def validate_did(v: str) -> str: @only_validate_if_strict -def validate_nsid(v: str) -> str: +def validate_nsid(v: str, _: ValidationInfo) -> str: """Validate an AT Protocol NSID (Namespaced Identifier). An NSID must have: @@ -213,7 +193,7 @@ def validate_nsid(v: str) -> str: @only_validate_if_strict -def validate_language(v: str) -> str: +def validate_language(v: str, _: ValidationInfo) -> str: """Validate an ISO language code. Must match pattern: @@ -237,7 +217,7 @@ def validate_language(v: str) -> str: @only_validate_if_strict -def validate_record_key(v: str) -> str: +def validate_record_key(v: str, _: ValidationInfo) -> str: """Validate an AT Protocol record key. A record key must: @@ -266,7 +246,7 @@ def validate_record_key(v: str) -> str: @only_validate_if_strict -def validate_cid(v: str) -> str: +def validate_cid(v: str, _: ValidationInfo) -> str: """Validate a Content Identifier (CID). Must be: @@ -289,7 +269,7 @@ def validate_cid(v: str) -> str: @only_validate_if_strict -def validate_at_uri(v: str) -> str: +def validate_at_uri(v: str, _: ValidationInfo) -> str: """Validate an AT Protocol URI. Must follow pattern: @@ -322,7 +302,7 @@ def validate_at_uri(v: str) -> str: @only_validate_if_strict -def validate_datetime(v: str) -> str: +def validate_datetime(v: str, _: ValidationInfo) -> str: """Validate an ISO 8601/RFC 3339 datetime string. Requirements: @@ -380,7 +360,7 @@ def validate_datetime(v: str) -> str: @only_validate_if_strict -def validate_tid(v: str) -> str: +def validate_tid(v: str, _: ValidationInfo) -> str: """Validate an AT Protocol TID (Temporal ID). Must be: @@ -406,7 +386,7 @@ def validate_tid(v: str) -> str: @only_validate_if_strict -def validate_uri(v: str) -> str: +def validate_uri(v: str, _: ValidationInfo) -> str: """Validate a standard URI. Requirements: @@ -442,21 +422,16 @@ def validate_uri(v: str) -> str: return v -class _ReprBeforeValidator(BeforeValidator): - def __repr__(self) -> str: - return f'<{self.func.__name__} function>' - - -Handle = Annotated[str, _ReprBeforeValidator(validate_handle)] -Did = Annotated[str, _ReprBeforeValidator(validate_did)] -Nsid = Annotated[str, _ReprBeforeValidator(validate_nsid)] -Language = Annotated[str, _ReprBeforeValidator(validate_language)] -RecordKey = Annotated[str, _ReprBeforeValidator(validate_record_key)] -Cid = Annotated[str, _ReprBeforeValidator(validate_cid)] -AtUri = Annotated[str, _ReprBeforeValidator(validate_at_uri)] -DateTime = Annotated[str, _ReprBeforeValidator(validate_datetime)] # see pydantic-extra-types #239 -Tid = Annotated[str, _ReprBeforeValidator(validate_tid)] -Uri = Annotated[str, _ReprBeforeValidator(validate_uri)] +Handle = Annotated[str, BeforeValidator(validate_handle)] +Did = Annotated[str, BeforeValidator(validate_did)] +Nsid = Annotated[str, BeforeValidator(validate_nsid)] +Language = Annotated[str, BeforeValidator(validate_language)] +RecordKey = Annotated[str, BeforeValidator(validate_record_key)] +Cid = Annotated[str, BeforeValidator(validate_cid)] +AtUri = Annotated[str, BeforeValidator(validate_at_uri)] +DateTime = Annotated[str, BeforeValidator(validate_datetime)] # see pydantic-extra-types #239 +Tid = Annotated[str, BeforeValidator(validate_tid)] +Uri = Annotated[str, BeforeValidator(validate_uri)] # Any valid ATProto string format AtProtoString = Annotated[