Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
placate 3.8
  • Loading branch information
zzstoatzz committed Dec 2, 2024
1 parent 9807a81 commit a1f3e29
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 119 deletions.
42 changes: 26 additions & 16 deletions packages/atproto_client/models/string_formats.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import re
from datetime import datetime
from typing import Callable, Mapping, Set, Union
from inspect import signature
from typing import Mapping, Set, Union, cast
from urllib.parse import urlparse

from atproto_core.nsid import validate_nsid as atproto_core_validate_nsid
from pydantic import BeforeValidator, Field, ValidationInfo
from pydantic_core import core_schema
from typing_extensions import Annotated, Literal
from typing_extensions import Annotated, Literal, TypeAlias

_OPT_IN_KEY: Literal['strict_string_format'] = 'strict_string_format'

Expand Down Expand Up @@ -50,20 +51,29 @@
r')?$'
)

WithOrWithoutInfoValidator: TypeAlias = Union[
core_schema.WithInfoValidatorFunction, core_schema.NoInfoValidatorFunction
]


def only_validate_if_strict(validate_fn: core_schema.WithInfoValidatorFunction) -> Callable:
"""Skip validation if not opting into strict validation."""
def only_validate_if_strict(validate_fn: WithOrWithoutInfoValidator) -> WithOrWithoutInfoValidator:
"""Skip pydantic validation if not opting into strict validation via context."""
params = list(signature(validate_fn).parameters.values())
validator_wants_info = len(params) > 1 and params[1].annotation is ValidationInfo

def wrapper(v: str, info: ValidationInfo) -> str:
"""Could likely be generalized to support arbitrary signatures."""
if info and isinstance(info.context, Mapping) and info.context.get(_OPT_IN_KEY, False):
return validate_fn(v, info)
if validator_wants_info:
return cast(core_schema.WithInfoValidatorFunction, validate_fn)(v, info)
return cast(core_schema.NoInfoValidatorFunction, validate_fn)(v)
return v

return wrapper


@only_validate_if_strict
def validate_handle(v: str, info: ValidationInfo) -> str:
def validate_handle(v: str) -> str:
# Check ASCII first
if not v.isascii():
raise ValueError('Invalid handle: must contain only ASCII characters')
Expand All @@ -78,7 +88,7 @@ def validate_handle(v: str, info: ValidationInfo) -> str:


@only_validate_if_strict
def validate_did(v: str, info: ValidationInfo) -> str:
def validate_did(v: str) -> str:
# Check for invalid characters
if any(c in v for c in '/?#[]@'):
raise ValueError('Invalid DID: cannot contain /, ?, #, [, ], or @ characters')
Expand All @@ -98,7 +108,7 @@ def validate_did(v: str, info: ValidationInfo) -> str:


@only_validate_if_strict
def validate_nsid(v: str, info: ValidationInfo) -> str:
def validate_nsid(v: str) -> str:
if (
not atproto_core_validate_nsid(v, soft_fail=True)
or len(v) > MAX_NSID_LENGTH
Expand All @@ -113,14 +123,14 @@ def validate_nsid(v: str, info: ValidationInfo) -> str:


@only_validate_if_strict
def validate_language(v: str, info: ValidationInfo) -> str:
def validate_language(v: str) -> str:
if not LANG_RE.match(v):
raise ValueError('Invalid language code: must be ISO language code (e.g. en or en-US)')
return v


@only_validate_if_strict
def validate_record_key(v: str, info: ValidationInfo) -> str:
def validate_record_key(v: str) -> str:
if v in INVALID_RECORD_KEYS or not RKEY_RE.match(v):
raise ValueError(
'Invalid record key: must contain only alphanumeric, dot, underscore, colon, tilde, or hyphen characters'
Expand All @@ -129,14 +139,14 @@ def validate_record_key(v: str, info: ValidationInfo) -> str:


@only_validate_if_strict
def validate_cid(v: str, info: ValidationInfo) -> str:
def validate_cid(v: str) -> str:
if not CID_RE.match(v):
raise ValueError('Invalid CID: must be a valid Content Identifier with minimum length 8')
return v


@only_validate_if_strict
def validate_at_uri(v: str, info: ValidationInfo) -> str:
def validate_at_uri(v: str) -> str:
if len(v) >= MAX_AT_URI_LENGTH:
raise ValueError(f'Invalid AT-URI: must be under {MAX_AT_URI_LENGTH} chars')

Expand All @@ -147,7 +157,7 @@ def validate_at_uri(v: str, info: ValidationInfo) -> str:


@only_validate_if_strict
def validate_datetime(v: str, info: ValidationInfo) -> str:
def validate_datetime(v: str) -> str:
# Must contain uppercase T and Z if used
if v != v.strip():
raise ValueError('Invalid datetime: no whitespace allowed')
Expand Down Expand Up @@ -180,14 +190,14 @@ def validate_datetime(v: str, info: ValidationInfo) -> str:


@only_validate_if_strict
def validate_tid(v: str, info: ValidationInfo) -> str:
def validate_tid(v: str) -> str:
if not TID_RE.match(v) or (ord(v[0]) & 0x40):
raise ValueError(f'Invalid TID: must be exactly {TID_LENGTH} lowercase letters/numbers')
return v


@only_validate_if_strict
def validate_uri(v: str, info: ValidationInfo) -> str:
def validate_uri(v: str) -> str:
if len(v) >= MAX_URI_LENGTH or ' ' in v:
raise ValueError(f'Invalid URI: must be under {MAX_URI_LENGTH} chars and not contain spaces')
parsed = urlparse(v)
Expand All @@ -212,7 +222,7 @@ def validate_uri(v: str, info: ValidationInfo) -> str:
Uri = Annotated[str, BeforeValidator(validate_uri)]

# Any valid ATProto string format
ATProtoString = Annotated[
AtProtoString = Annotated[
Union[Handle, Did, Nsid, AtUri, Cid, DateTime, Tid, RecordKey, Uri, Language],
Field(description='ATProto string format'),
]
8 changes: 6 additions & 2 deletions packages/atproto_client/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import typing as t

import typing_extensions as te
from pydantic import ValidationError
from pydantic import BaseModel, ValidationError
from pydantic_core import from_json, to_json

from atproto_client import models
Expand Down Expand Up @@ -108,7 +108,11 @@ def _get_or_create(
return model_data

try:
return model.model_validate(model_data, context={'strict_string_format': strict_string_format})
if issubclass(model, BaseModel):
return model.model_validate(model_data, context={'strict_string_format': strict_string_format})
if not isinstance(model_data, t.Mapping):
raise ModelError(f'Cannot parse model of type {model}')
return model(**model_data)
except ValidationError as e:
raise ModelError(str(e)) from e

Expand Down
159 changes: 61 additions & 98 deletions tests/test_atproto_client/models/tests/test_string_formats.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import lru_cache
from pathlib import Path
from typing import List

Expand All @@ -13,32 +14,6 @@
INTEROP_TEST_FILES_DIR: Path = Path('tests/test_atproto_client/interop-test-files/syntax')


# TODO: 230 passed, 11 xfailed
# These cases appear in both _valid.txt and _invalid.txt files.
# Need investigation to determine if our validation is incorrect or if test data needs updating:
SKIP_THESE_VALUES = [
(
string_formats.AtUri,
'at://did:plc:asdf123',
), # Listed as both valid and invalid in AT-URI files under "enforces spec basics"
(
string_formats.AtUri,
'at://did:plc:asdf123/com.atproto.feed.post',
), # Same AT-URI pattern - appears in both valid/invalid files
(
string_formats.DateTime,
'1985-04-12T23:20:50.123Z',
), # Listed as "preferred" in valid but also appears in invalid under RFC-3339 section
(
string_formats.DateTime,
'1985-04-12T23:20:50.123-00:00',
), # Listed as "supported" in valid but marked invalid under timezone formats
(string_formats.DateTime, '1985-04-12T23:20Z'), # Similar timezone format discrepancy between valid/invalid files
(string_formats.Handle, 'john.test'), # Base pattern appears valid but numeric suffix versions are marked invalid
(string_formats.Nsid, 'one.two.three'), # Same pattern - base form valid but numeric suffixes marked invalid
]


def get_test_cases(filename: str) -> List[str]:
"""Get non-comment, non-empty lines from an interop test file.
Expand All @@ -60,90 +35,82 @@ def get_test_cases(filename: str) -> List[str]:
]


@pytest.fixture
def valid_handles() -> List[str]:
return get_test_cases('handle_syntax_valid.txt')


@pytest.fixture
def valid_dids() -> List[str]:
return get_test_cases('did_syntax_valid.txt')


@pytest.fixture
def valid_nsids() -> List[str]:
return get_test_cases('nsid_syntax_valid.txt')


@pytest.fixture
def valid_aturis() -> List[str]:
return get_test_cases('aturi_syntax_valid.txt')


@pytest.fixture
def valid_datetimes() -> List[str]:
return get_test_cases('datetime_syntax_valid.txt')


@pytest.fixture
def valid_tids() -> List[str]:
return get_test_cases('tid_syntax_valid.txt')


@pytest.fixture
def valid_record_keys() -> List[str]:
return get_test_cases('recordkey_syntax_valid.txt')
@lru_cache
def read_test_data() -> dict:
"""Load all test data once at session start"""
return {
'valid': {
'handle': get_test_cases('handle_syntax_valid.txt'),
'did': get_test_cases('did_syntax_valid.txt'),
'nsid': get_test_cases('nsid_syntax_valid.txt'),
'at_uri': get_test_cases('aturi_syntax_valid.txt'),
'datetime': get_test_cases('datetime_syntax_valid.txt'),
'tid': get_test_cases('tid_syntax_valid.txt'),
'record_key': get_test_cases('recordkey_syntax_valid.txt'),
},
'invalid': {
'handle': get_test_cases('handle_syntax_invalid.txt'),
'did': get_test_cases('did_syntax_invalid.txt'),
'nsid': get_test_cases('nsid_syntax_invalid.txt'),
'at_uri': get_test_cases('aturi_syntax_invalid.txt'),
'datetime': get_test_cases('datetime_syntax_invalid.txt'),
'tid': get_test_cases('tid_syntax_invalid.txt'),
'record_key': get_test_cases('recordkey_syntax_invalid.txt'),
},
}


@pytest.fixture
def valid_data(
valid_handles: List[str],
valid_dids: List[str],
valid_nsids: List[str],
valid_aturis: List[str],
valid_datetimes: List[str],
valid_tids: List[str],
valid_record_keys: List[str],
) -> dict:
def valid_data() -> dict:
"""Get first valid example of each type plus constants"""
test_data = read_test_data()
return {
'handle': valid_handles[0],
'did': valid_dids[0],
'nsid': valid_nsids[0],
'at_uri': valid_aturis[0],
'handle': test_data['valid']['handle'][0],
'did': test_data['valid']['did'][0],
'nsid': test_data['valid']['nsid'][0],
'at_uri': test_data['valid']['at_uri'][0],
'cid': 'bafyreidfayvfuwqa2beehqn7axeeeaej5aqvaowxgwcdt2rw', # No interop test file for CID
'datetime': valid_datetimes[0],
'tid': valid_tids[0],
'record_key': valid_record_keys[0],
'datetime': test_data['valid']['datetime'][0],
'tid': test_data['valid']['tid'][0],
'record_key': test_data['valid']['record_key'][0],
'uri': 'https://example.com', # No interop test file for URI
'language': 'en-US', # No interop test file for language
}


@pytest.fixture
def invalid_data() -> dict:
"""Get first invalid example of each type plus constants"""
test_data = read_test_data()
return {
'handle': get_test_cases('handle_syntax_invalid.txt')[0],
'did': get_test_cases('did_syntax_invalid.txt')[0],
'nsid': get_test_cases('nsid_syntax_invalid.txt')[0],
'at_uri': get_test_cases('aturi_syntax_invalid.txt')[0],
'handle': test_data['invalid']['handle'][0],
'did': test_data['invalid']['did'][0],
'nsid': test_data['invalid']['nsid'][0],
'at_uri': test_data['invalid']['at_uri'][0],
'cid': 'short', # No interop test file for CID
'datetime': get_test_cases('datetime_syntax_invalid.txt')[0],
'tid': get_test_cases('tid_syntax_invalid.txt')[0],
'record_key': get_test_cases('recordkey_syntax_invalid.txt')[0],
'datetime': test_data['invalid']['datetime'][0],
'tid': test_data['invalid']['tid'][0],
'record_key': test_data['invalid']['record_key'][0],
'uri': 'invalid-uri-no-scheme', # No interop test file for URI
'language': 'invalid!', # No interop test file for language
}


@pytest.mark.parametrize(
'validator_type,field_name,invalid_value',
[(string_formats.AtUri, 'at_uri', c) for c in get_test_cases('aturi_syntax_invalid.txt')]
+ [(string_formats.DateTime, 'datetime', c) for c in get_test_cases('datetime_syntax_invalid.txt')]
+ [(string_formats.Handle, 'handle', c) for c in get_test_cases('handle_syntax_invalid.txt')]
+ [(string_formats.Did, 'did', c) for c in get_test_cases('did_syntax_invalid.txt')]
+ [(string_formats.Nsid, 'nsid', c) for c in get_test_cases('nsid_syntax_invalid.txt')]
+ [(string_formats.Tid, 'tid', c) for c in get_test_cases('tid_syntax_invalid.txt')]
+ [(string_formats.RecordKey, 'record_key', c) for c in get_test_cases('recordkey_syntax_invalid.txt')],
[
(validator_type, field_name, invalid_value)
for validator_type, field_name in [
(string_formats.AtUri, 'at_uri'),
(string_formats.DateTime, 'datetime'),
(string_formats.Handle, 'handle'),
(string_formats.Did, 'did'),
(string_formats.Nsid, 'nsid'),
(string_formats.Tid, 'tid'),
(string_formats.RecordKey, 'record_key'),
]
for invalid_value in read_test_data()['invalid'][field_name]
],
)
def test_string_format_validation(validator_type: type, field_name: str, invalid_value: str, valid_data: dict) -> None:
"""Test validation for each string format type."""
Expand Down Expand Up @@ -179,7 +146,7 @@ def test_string_format_validation(validator_type: type, field_name: str, invalid
def test_generic_string_format_validation(valid_value: str) -> None:
"""Test that ATProtoString accepts each valid string format."""

validated = TypeAdapter(string_formats.ATProtoString).validate_python(valid_value, context={_OPT_IN_KEY: True})
validated = TypeAdapter(string_formats.AtProtoString).validate_python(valid_value, context={_OPT_IN_KEY: True})
assert validated == valid_value


Expand All @@ -199,18 +166,14 @@ class FooModel(BaseModel):
assert instance.did == valid_data['did']

# Test invalid handle fails
try:
with pytest.raises(ModelError) as exc_info:
get_or_create({'handle': invalid_data['handle'], 'did': valid_data['did']}, FooModel, strict_string_format=True)
pytest.fail('Handle validation should have failed')
except ModelError as e:
assert 'must be a domain name' in str(e)
assert 'must be a domain name' in str(exc_info.value)

# Test invalid did fails
try:
with pytest.raises(ModelError) as exc_info:
get_or_create({'handle': valid_data['handle'], 'did': invalid_data['did']}, FooModel, strict_string_format=True)
pytest.fail('Did validation should have failed')
except ModelError as e:
assert 'must be in format did:method:identifier' in str(e)
assert 'must be in format did:method:identifier' in str(exc_info.value)

# Test that validation is skipped when strict_string_format=False
instance = get_or_create(
Expand Down
Loading

0 comments on commit a1f3e29

Please sign in to comment.