Skip to content

Commit

Permalink
Use a dummy class to contain the type signature
Browse files Browse the repository at this point in the history
  • Loading branch information
multimeric committed Aug 13, 2024
1 parent 6a50fef commit 2ff846b
Show file tree
Hide file tree
Showing 3 changed files with 387 additions and 281 deletions.
152 changes: 104 additions & 48 deletions codegen/id_mapping/generate.py
Original file line number Diff line number Diff line change
@@ -1,76 +1,115 @@
import ast
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Sequence
from typing import List

import black
import requests
import typer

from codegen.util import make_literal

app = typer.Typer(result_callback=lambda x: print(x))


def make_function(
source_type: ast.expr, dest_type: ast.expr, taxon_id: bool, overload: bool = True
) -> ast.FunctionDef:
"""
Makes a `submit()` function definition, as used by the ID mapper
Params:
source_type: Type of the `source` argument
dest_type: Type of the `dest` argument
taxon_id: If true, include the `taxon_id` parameter
overload: If true, this is a function overload
"""
args: List[ast.arg] = [
ast.arg(
arg="cls",
),
# source: Literal[...]
ast.arg(
arg="source",
annotation=source_type,
),
# source: dest[...]
ast.arg(
arg="dest",
annotation=dest_type,
),
# ids: Iterable[str]
ast.arg(
"ids",
ast.Subscript(ast.Name("Iterable"), ast.Name("str")),
),
]
defaults: list[ast.expr | None] = [None, None, None]

if taxon_id:
# taxon_id: Optional[str] = None
args.append(
# taxon_id: bool
ast.arg(
"taxon_id",
annotation=ast.Subscript(ast.Name("Optional"), ast.Name("str")),
)
)
defaults.append(ast.Constant(None))

decorator_list: list[ast.expr] = [
ast.Name("classmethod"),
]
if overload:
decorator_list.append(
ast.Name("overload"),
)

return ast.FunctionDef(
name=f"submit",
args=ast.arguments(
posonlyargs=[], args=args, kwonlyargs=[], kw_defaults=[], defaults=defaults # type: ignore
),
body=[ast.Expr(ast.Constant(value=...))],
decorator_list=decorator_list,
)


@dataclass
class Rule:
"""
Represents a "rule" in the Uniprot API terminology, which is a method overload
in the Unipressed world. A rule is a set of allowed conversions from one database
to another.
"""

#: Rule ID
id: int = 0
#: List of databases that can be converted to, in this rule
tos: list[ast.Constant] = field(default_factory=list)
#: List of databases that can be converted from, in this rule
froms: list[ast.Constant] = field(default_factory=list)
#: Whether this rule supports specifying the taxon ID
taxon_id: bool = False

def to_typed_dict(self) -> ast.ClassDef:
body: Sequence[ast.AnnAssign] = [
# source: Literal[...]
ast.AnnAssign(
ast.Name("source"),
ast.Subscript(
value=ast.Name("Literal"),
slice=ast.Tuple(elts=self.froms), # type: ignore
),
simple=1,
),
# dest: Literal[...]
ast.AnnAssign(
ast.Name("dest"),
ast.Subscript(
value=ast.Name("Literal"),
slice=ast.Tuple(elts=self.tos), # type: ignore
),
simple=1,
def to_function(self) -> ast.FunctionDef:
return make_function(
source_type=ast.Subscript(
value=ast.Name("Literal"),
slice=ast.Tuple(elts=self.froms),
),
# ids: Iterable[str]
ast.AnnAssign(
ast.Name("ids"),
ast.Subscript(ast.Name("Iterable"), ast.Name("str")),
simple=1,
dest_type=ast.Subscript(
value=ast.Name("Literal"),
slice=ast.Tuple(elts=self.tos), # type: ignore
),
]

if self.taxon_id:
body.append(
# taxon_id: bool
ast.AnnAssign(
target=ast.Name("taxon_id"),
annotation=ast.Name("bool"),
value=None,
simple=1,
)
)

return ast.ClassDef(
name=f"Rule{self.id}",
bases=[ast.Name("TypedDict")],
body=body, # type: ignore
keywords=[],
decorator_list=[],
taxon_id=self.taxon_id,
overload=True,
)


@app.command()
def main():
rules: defaultdict[int, Rule] = defaultdict(Rule)

# Build up a list of rules
type_info = requests.get(
"https://rest.uniprot.org/configure/idmapping/fields"
).json()
Expand All @@ -85,28 +124,45 @@ def main():
rule.taxon_id = rule_info["taxonId"]
rule.id = rule_info["ruleId"]

# Create a class that has one method overload per rule
module = ast.Module(
body=[
ast.ImportFrom(
module="typing_extensions",
names=[
ast.alias("Literal"),
ast.alias("TypedDict"),
ast.alias("overload"),
],
level=0,
),
ast.ImportFrom(
module="typing",
names=[
ast.alias("Iterable"),
ast.alias("Optional"),
],
level=0,
),
*[rule.to_typed_dict() for rule in rules.values()],
ast.ClassDef(
name="SubmitDummyClass",
body=[
*[rule.to_function() for rule in rules.values()],
make_function(
source_type=ast.Name("str"),
dest_type=ast.Name("str"),
taxon_id=True,
overload=False,
),
],
decorator_list=[],
bases=[],
keywords=[],
),
],
type_ignores=[],
)

# Produce the formatted output
print(
black.format_file_contents(
ast.unparse(ast.fix_missing_locations(module)),
Expand Down
30 changes: 22 additions & 8 deletions unipressed/id_mapping/core.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, Unpack
from typing import Callable, Iterable

import requests
from typing_extensions import Literal, TypeAlias, TypedDict, overload
from typing_extensions import Literal, ParamSpec, TypeAlias, TypedDict, TypeVar

import unipressed.id_mapping.types as id_types
from unipressed.util import iter_pages
Expand All @@ -17,6 +17,24 @@
"ERROR",
]

Param1 = ParamSpec("Param1")
Param2 = ParamSpec("Param2")
Ret1 = TypeVar("Ret1")
Ret2 = TypeVar("Ret2")


def copy_signature(
f: Callable[Param1, Ret1]
) -> Callable[[Callable[Param2, Ret2]], Callable[Param1, Ret2]]:
"""
Copies the argument signature from function f and applies it to the decorated function, but keeps the return value
"""

def _inner(f: Callable[Param2, Ret2]):
return f

return _inner # type: ignore


class IdMappingError(Exception):
pass
Expand All @@ -29,17 +47,13 @@ class IdMappingClient:
"""

@classmethod
def _submit(cls, source: From, dest: To, ids: Iterable[str]) -> requests.Response:
def _submit(cls, source: str, dest: str, ids: Iterable[str]) -> requests.Response:
return requests.post(
"https://rest.uniprot.org/idmapping/run",
data={"ids": ",".join(ids), "from": source, "to": dest},
)

@classmethod
@overload
def submit(cls, **kwargs: Unpack[id_types.Rule1]):
...

@copy_signature(id_types.SubmitDummyClass.submit)
@classmethod
def submit(
cls, source: str, dest: str, ids: Iterable[str], taxon_id: str | None = None
Expand Down
Loading

0 comments on commit 2ff846b

Please sign in to comment.