Skip to content

Commit

Permalink
Add pydantic compatibility submodule
Browse files Browse the repository at this point in the history
  • Loading branch information
cthoyt committed Feb 4, 2024
1 parent 960215f commit 927f2ec
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 58 deletions.
24 changes: 24 additions & 0 deletions src/curies/_pydantic_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""A compatibility layer for pydantic 1 and 2."""

from pydantic import __version__ as pydantic_version

__all__ = [
"PYDANTIC_V1",
"field_validator",
"get_field_validator_values",
]

PYDANTIC_V1 = pydantic_version.startswith("1.")

if PYDANTIC_V1:
from pydantic import validator as field_validator
else:
from pydantic import field_validator


def get_field_validator_values(values, key: str):
"""Get the value for the key from a field validator object, cross-compatible with Pydantic 1 and 2."""
if PYDANTIC_V1:
return values[key]
else:
return values.data[key] # type:ignore
87 changes: 29 additions & 58 deletions src/curies/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,10 @@

import requests
from pydantic import BaseModel, Field
from pydantic import __version__ as pydantic_version

# Check if the major version of Pydantic is 1 or 2
if pydantic_version.startswith("1."):
from pydantic import validator as field_validator
else:
from pydantic import field_validator

from pytrie import StringTrie

from ._pydantic_compat import PYDANTIC_V1, field_validator, get_field_validator_values

if TYPE_CHECKING: # pragma: no cover
import pandas
import rdflib
Expand Down Expand Up @@ -278,21 +272,15 @@ class Records(BaseModel):
@field_validator("prefix_synonyms") # type:ignore
def prefix_not_in_synonyms(cls, v: str, values: Mapping[str, Any]) -> str: # noqa:N805
"""Check that the canonical prefix does not apper in the prefix synonym list."""
if pydantic_version.startswith("1."):
prefix = values["prefix"]
else:
prefix = values.data["prefix"] # type:ignore
prefix = get_field_validator_values(values, "prefix")
if prefix in v:
raise ValueError(f"Duplicate of canonical prefix `{prefix}` in prefix synonyms")
return v

@field_validator("uri_prefix_synonyms") # type:ignore
def uri_prefix_not_in_synonyms(cls, v: str, values: Mapping[str, Any]) -> str: # noqa:N805
"""Check that the canonical URI prefix does not apper in the URI prefix synonym list."""
if pydantic_version.startswith("1."):
uri_prefix = values["uri_prefix"]
else:
uri_prefix = values.data["uri_prefix"] # type:ignore
uri_prefix = get_field_validator_values(values, "uri_prefix")
if uri_prefix in v:
raise ValueError(
f"Duplicate of canonical URI prefix `{uri_prefix}` in URI prefix synonyms"
Expand Down Expand Up @@ -1048,8 +1036,7 @@ def is_uri(self, s: str) -> bool:
@overload
def compress_or_standardize(
self, uri_or_curie: str, *, strict: Literal[True] = True, passthrough: bool = False
) -> str:
...
) -> str: ...

# docstr-coverage:excused `overload`
@overload
Expand All @@ -1059,8 +1046,7 @@ def compress_or_standardize(
*,
strict: Literal[False] = False,
passthrough: Literal[True] = True,
) -> str:
...
) -> str: ...

# docstr-coverage:excused `overload`
@overload
Expand All @@ -1070,8 +1056,7 @@ def compress_or_standardize(
*,
strict: Literal[False] = False,
passthrough: Literal[False] = False,
) -> Optional[str]:
...
) -> Optional[str]: ...

def compress_or_standardize(
self, uri_or_curie: str, *, strict: bool = False, passthrough: bool = False
Expand Down Expand Up @@ -1128,22 +1113,21 @@ def compress_strict(self, uri: str) -> str:

# docstr-coverage:excused `overload`
@overload
def compress(self, uri: str, *, strict: Literal[True] = True, passthrough: bool = False) -> str:
...
def compress(
self, uri: str, *, strict: Literal[True] = True, passthrough: bool = False
) -> str: ...

# docstr-coverage:excused `overload`
@overload
def compress(
self, uri: str, *, strict: Literal[False] = False, passthrough: Literal[True] = True
) -> str:
...
) -> str: ...

# docstr-coverage:excused `overload`
@overload
def compress(
self, uri: str, *, strict: Literal[False] = False, passthrough: Literal[False] = False
) -> Optional[str]:
...
) -> Optional[str]: ...

def compress(
self, uri: str, *, strict: bool = False, passthrough: bool = False
Expand Down Expand Up @@ -1238,8 +1222,7 @@ def is_curie(self, s: str) -> bool:
@overload
def expand_or_standardize(
self, curie_or_uri: str, *, strict: Literal[True] = True, passthrough: bool = False
) -> str:
...
) -> str: ...

# docstr-coverage:excused `overload`
@overload
Expand All @@ -1249,8 +1232,7 @@ def expand_or_standardize(
*,
strict: Literal[False] = False,
passthrough: Literal[True] = True,
) -> str:
...
) -> str: ...

# docstr-coverage:excused `overload`
@overload
Expand All @@ -1260,8 +1242,7 @@ def expand_or_standardize(
*,
strict: Literal[False] = False,
passthrough: Literal[False] = False,
) -> Optional[str]:
...
) -> Optional[str]: ...

def expand_or_standardize(
self, curie_or_uri: str, *, strict: bool = False, passthrough: bool = False
Expand Down Expand Up @@ -1318,22 +1299,21 @@ def expand_strict(self, curie: str) -> str:

# docstr-coverage:excused `overload`
@overload
def expand(self, curie: str, *, strict: Literal[True] = True, passthrough: bool = False) -> str:
...
def expand(
self, curie: str, *, strict: Literal[True] = True, passthrough: bool = False
) -> str: ...

# docstr-coverage:excused `overload`
@overload
def expand(
self, curie: str, *, strict: Literal[False] = False, passthrough: Literal[True] = True
) -> str:
...
) -> str: ...

# docstr-coverage:excused `overload`
@overload
def expand(
self, curie: str, *, strict: Literal[False] = False, passthrough: Literal[False] = False
) -> Optional[str]:
...
) -> Optional[str]: ...

def expand(
self, curie: str, *, strict: bool = False, passthrough: bool = False
Expand Down Expand Up @@ -1473,22 +1453,19 @@ def expand_pair_all(self, prefix: str, identifier: str) -> Optional[Collection[s
@overload
def standardize_prefix(
self, prefix: str, *, strict: Literal[True] = True, passthrough: bool = False
) -> str:
...
) -> str: ...

# docstr-coverage:excused `overload`
@overload
def standardize_prefix(
self, prefix: str, *, strict: Literal[False] = False, passthrough: Literal[True] = True
) -> str:
...
) -> str: ...

# docstr-coverage:excused `overload`
@overload
def standardize_prefix(
self, prefix: str, *, strict: Literal[False] = False, passthrough: Literal[False] = False
) -> Optional[str]:
...
) -> Optional[str]: ...

def standardize_prefix(
self, prefix: str, *, strict: bool = False, passthrough: bool = False
Expand Down Expand Up @@ -1532,22 +1509,19 @@ def standardize_prefix(
@overload
def standardize_curie(
self, curie: str, *, strict: Literal[True] = True, passthrough: bool = False
) -> str:
...
) -> str: ...

# docstr-coverage:excused `overload`
@overload
def standardize_curie(
self, curie: str, *, strict: Literal[False] = False, passthrough: Literal[True] = True
) -> str:
...
) -> str: ...

# docstr-coverage:excused `overload`
@overload
def standardize_curie(
self, curie: str, *, strict: Literal[False] = False, passthrough: Literal[False] = False
) -> Optional[str]:
...
) -> Optional[str]: ...

def standardize_curie(
self, curie: str, *, strict: bool = False, passthrough: bool = False
Expand Down Expand Up @@ -1594,22 +1568,19 @@ def standardize_curie(
@overload
def standardize_uri(
self, uri: str, *, strict: Literal[True] = True, passthrough: bool = False
) -> str:
...
) -> str: ...

# docstr-coverage:excused `overload`
@overload
def standardize_uri(
self, uri: str, *, strict: Literal[False] = False, passthrough: Literal[True] = True
) -> str:
...
) -> str: ...

# docstr-coverage:excused `overload`
@overload
def standardize_uri(
self, uri: str, *, strict: Literal[False] = False, passthrough: Literal[False] = False
) -> Optional[str]:
...
) -> Optional[str]: ...

def standardize_uri(
self, uri: str, *, strict: bool = False, passthrough: bool = False
Expand Down

0 comments on commit 927f2ec

Please sign in to comment.