From 2ff846b1abb79a06ad703a6347f304189c5f1d26 Mon Sep 17 00:00:00 2001 From: Michael Milton Date: Wed, 14 Aug 2024 00:41:31 +1000 Subject: [PATCH] Use a dummy class to contain the type signature --- codegen/id_mapping/generate.py | 152 +++++++---- unipressed/id_mapping/core.py | 30 +- unipressed/id_mapping/types.py | 486 ++++++++++++++++++--------------- 3 files changed, 387 insertions(+), 281 deletions(-) diff --git a/codegen/id_mapping/generate.py b/codegen/id_mapping/generate.py index f595754..fe67c31 100644 --- a/codegen/id_mapping/generate.py +++ b/codegen/id_mapping/generate.py @@ -1,69 +1,107 @@ import ast from collections import defaultdict from dataclasses import dataclass, field -from typing import Sequence +from typing import List import black import requests import typer -from codegen.util import make_literal - app = typer.Typer(result_callback=lambda x: print(x)) +def make_function( + source_type: ast.expr, dest_type: ast.expr, taxon_id: bool, overload: bool = True +) -> ast.FunctionDef: + """ + Makes a `submit()` function definition, as used by the ID mapper + + Params: + source_type: Type of the `source` argument + dest_type: Type of the `dest` argument + taxon_id: If true, include the `taxon_id` parameter + overload: If true, this is a function overload + """ + args: List[ast.arg] = [ + ast.arg( + arg="cls", + ), + # source: Literal[...] + ast.arg( + arg="source", + annotation=source_type, + ), + # dest: Literal[...] 
+ ast.arg( + arg="dest", + annotation=dest_type, + ), + # ids: Iterable[str] + ast.arg( + "ids", + ast.Subscript(ast.Name("Iterable"), ast.Name("str")), + ), + ] + defaults: list[ast.expr | None] = [None, None, None] + + if taxon_id: + # taxon_id: Optional[str] = None + args.append( + # taxon_id: Optional[str] + ast.arg( + "taxon_id", + annotation=ast.Subscript(ast.Name("Optional"), ast.Name("str")), + ) + ) + defaults.append(ast.Constant(None)) + + decorator_list: list[ast.expr] = [ + ast.Name("classmethod"), + ] + if overload: + decorator_list.append( + ast.Name("overload"), + ) + + return ast.FunctionDef( + name=f"submit", + args=ast.arguments( + posonlyargs=[], args=args, kwonlyargs=[], kw_defaults=[], defaults=defaults # type: ignore + ), + body=[ast.Expr(ast.Constant(value=...))], + decorator_list=decorator_list, + ) + + @dataclass class Rule: + """ + Represents a "rule" in the Uniprot API terminology, which is a method overload + in the Unipressed world. A rule is a set of allowed conversions from one database + to another. + """ + + #: Rule ID id: int = 0 + #: List of databases that can be converted to, in this rule tos: list[ast.Constant] = field(default_factory=list) + #: List of databases that can be converted from, in this rule froms: list[ast.Constant] = field(default_factory=list) + #: Whether this rule supports specifying the taxon ID taxon_id: bool = False - def to_typed_dict(self) -> ast.ClassDef: - body: Sequence[ast.AnnAssign] = [ - # source: Literal[...] - ast.AnnAssign( - ast.Name("source"), - ast.Subscript( - value=ast.Name("Literal"), - slice=ast.Tuple(elts=self.froms), # type: ignore - ), - simple=1, - ), - # dest: Literal[...] 
- ast.AnnAssign( - ast.Name("dest"), - ast.Subscript( - value=ast.Name("Literal"), - slice=ast.Tuple(elts=self.tos), # type: ignore - ), - simple=1, + def to_function(self) -> ast.FunctionDef: + return make_function( + source_type=ast.Subscript( + value=ast.Name("Literal"), + slice=ast.Tuple(elts=self.froms), ), - # ids: Iterable[str] - ast.AnnAssign( - ast.Name("ids"), - ast.Subscript(ast.Name("Iterable"), ast.Name("str")), - simple=1, + dest_type=ast.Subscript( + value=ast.Name("Literal"), + slice=ast.Tuple(elts=self.tos), # type: ignore ), - ] - - if self.taxon_id: - body.append( - # taxon_id: bool - ast.AnnAssign( - target=ast.Name("taxon_id"), - annotation=ast.Name("bool"), - value=None, - simple=1, - ) - ) - - return ast.ClassDef( - name=f"Rule{self.id}", - bases=[ast.Name("TypedDict")], - body=body, # type: ignore - keywords=[], - decorator_list=[], + taxon_id=self.taxon_id, + overload=True, ) @@ -71,6 +109,7 @@ def to_typed_dict(self) -> ast.ClassDef: def main(): rules: defaultdict[int, Rule] = defaultdict(Rule) + # Build up a list of rules type_info = requests.get( "https://rest.uniprot.org/configure/idmapping/fields" ).json() @@ -85,13 +124,14 @@ def main(): rule.taxon_id = rule_info["taxonId"] rule.id = rule_info["ruleId"] + # Create a class that has one method overload per rule module = ast.Module( body=[ ast.ImportFrom( module="typing_extensions", names=[ ast.alias("Literal"), - ast.alias("TypedDict"), + ast.alias("overload"), ], level=0, ), @@ -99,14 +139,30 @@ def main(): module="typing", names=[ ast.alias("Iterable"), + ast.alias("Optional"), ], level=0, ), - *[rule.to_typed_dict() for rule in rules.values()], + ast.ClassDef( + name="SubmitDummyClass", + body=[ + *[rule.to_function() for rule in rules.values()], + make_function( + source_type=ast.Name("str"), + dest_type=ast.Name("str"), + taxon_id=True, + overload=False, + ), + ], + decorator_list=[], + bases=[], + keywords=[], + ), ], type_ignores=[], ) + # Produce the formatted output print( 
black.format_file_contents( ast.unparse(ast.fix_missing_locations(module)), diff --git a/unipressed/id_mapping/core.py b/unipressed/id_mapping/core.py index 54d5d2b..a6a8327 100644 --- a/unipressed/id_mapping/core.py +++ b/unipressed/id_mapping/core.py @@ -1,10 +1,10 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Iterable, Unpack +from typing import Callable, Iterable import requests -from typing_extensions import Literal, TypeAlias, TypedDict, overload +from typing_extensions import Literal, ParamSpec, TypeAlias, TypedDict, TypeVar import unipressed.id_mapping.types as id_types from unipressed.util import iter_pages @@ -17,6 +17,24 @@ "ERROR", ] +Param1 = ParamSpec("Param1") +Param2 = ParamSpec("Param2") +Ret1 = TypeVar("Ret1") +Ret2 = TypeVar("Ret2") + + +def copy_signature( + f: Callable[Param1, Ret1] +) -> Callable[[Callable[Param2, Ret2]], Callable[Param1, Ret2]]: + """ + Copies the argument signature from function f and applies it to the decorated function, but keeps the return value + """ + + def _inner(f: Callable[Param2, Ret2]): + return f + + return _inner # type: ignore + class IdMappingError(Exception): pass @@ -29,17 +47,13 @@ class IdMappingClient: """ @classmethod - def _submit(cls, source: From, dest: To, ids: Iterable[str]) -> requests.Response: + def _submit(cls, source: str, dest: str, ids: Iterable[str]) -> requests.Response: return requests.post( "https://rest.uniprot.org/idmapping/run", data={"ids": ",".join(ids), "from": source, "to": dest}, ) - @classmethod - @overload - def submit(cls, **kwargs: Unpack[id_types.Rule1]): - ... 
- + @copy_signature(id_types.SubmitDummyClass.submit) @classmethod def submit( cls, source: str, dest: str, ids: Iterable[str], taxon_id: str | None = None diff --git a/unipressed/id_mapping/types.py b/unipressed/id_mapping/types.py index 774cd00..6f99f69 100644 --- a/unipressed/id_mapping/types.py +++ b/unipressed/id_mapping/types.py @@ -1,235 +1,271 @@ -from typing import Iterable +from typing import Iterable, Optional -from typing_extensions import Literal, TypedDict +from typing_extensions import Literal, overload -class Rule1(TypedDict): - source: Literal["UniProtKB_AC-ID",] - dest: Literal[ - "CCDS", - "PIR", - "PDB", - "BioGRID", - "ComplexPortal", - "DIP", - "STRING", - "ChEMBL", - "DrugBank", - "GuidetoPHARMACOLOGY", - "SwissLipids", - "Allergome", - "ESTHER", - "MEROPS", - "PeroxiBase", - "REBASE", - "TCDB", - "GlyConnect", - "BioMuta", - "DMDM", - "CPTAC", - "ProteomicsDB", - "DNASU", - "Ensembl", - "GeneID", - "KEGG", - "PATRIC", - "UCSC", - "WBParaSite", - "ArachnoServer", - "Araport", - "CGD", - "ConoServer", - "dictyBase", - "EchoBASE", - "euHCVdb", - "VEuPathDB", - "FlyBase", - "GeneCards", - "GeneReviews", - "HGNC", - "LegioList", - "Leproma", - "MaizeGDB", - "MGI", - "MIM", - "neXtProt", - "OpenTargets", - "Orphanet", - "PharmGKB", - "PomBase", - "PseudoCAP", - "RGD", - "SGD", - "TubercuList", - "VGNC", - "WormBase", - "Xenbase", - "ZFIN", - "eggNOG", - "GeneTree", - "HOGENOM", - "OMA", - "OrthoDB", - "TreeFam", - "BioCyc", - "Reactome", - "UniPathway", - "PlantReactome", - "ChiTaRS", - "GeneWiki", - "GenomeRNAi", - "PHI-base", - "CollecTF", - "IDEAL", - "DisProt", - "UniProtKB", - "UniProtKB-Swiss-Prot", - "UniParc", - "UniRef50", - "UniRef90", - "UniRef100", - "Gene_Name", - "CRC64", - "EMBL-GenBank-DDBJ", - "EMBL-GenBank-DDBJ_CDS", - "GI_number", - "RefSeq_Nucleotide", - "RefSeq_Protein", - "Ensembl_Protein", - "Ensembl_Transcript", - "Ensembl_Genomes", - "Ensembl_Genomes_Protein", - "Ensembl_Genomes_Transcript", - 
"WBParaSite_Transcript-Protein", - "WormBase_Protein", - "WormBase_Transcript", - ] - ids: Iterable[str] +class SubmitDummyClass: + @classmethod + @overload + def submit( + cls, + source: Literal["UniProtKB_AC-ID",], + dest: Literal[ + "CCDS", + "PIR", + "PDB", + "BioGRID", + "ComplexPortal", + "DIP", + "STRING", + "ChEMBL", + "DrugBank", + "GuidetoPHARMACOLOGY", + "SwissLipids", + "Allergome", + "ESTHER", + "MEROPS", + "PeroxiBase", + "REBASE", + "TCDB", + "GlyConnect", + "BioMuta", + "DMDM", + "CPTAC", + "ProteomicsDB", + "DNASU", + "Ensembl", + "GeneID", + "KEGG", + "PATRIC", + "UCSC", + "WBParaSite", + "ArachnoServer", + "Araport", + "CGD", + "ConoServer", + "dictyBase", + "EchoBASE", + "euHCVdb", + "VEuPathDB", + "FlyBase", + "GeneCards", + "GeneReviews", + "HGNC", + "LegioList", + "Leproma", + "MaizeGDB", + "MGI", + "MIM", + "neXtProt", + "OpenTargets", + "Orphanet", + "PharmGKB", + "PomBase", + "PseudoCAP", + "RGD", + "SGD", + "TubercuList", + "VGNC", + "WormBase", + "Xenbase", + "ZFIN", + "eggNOG", + "GeneTree", + "HOGENOM", + "OMA", + "OrthoDB", + "TreeFam", + "BioCyc", + "Reactome", + "UniPathway", + "PlantReactome", + "ChiTaRS", + "GeneWiki", + "GenomeRNAi", + "PHI-base", + "CollecTF", + "IDEAL", + "DisProt", + "UniProtKB", + "UniProtKB-Swiss-Prot", + "UniParc", + "UniRef50", + "UniRef90", + "UniRef100", + "Gene_Name", + "CRC64", + "EMBL-GenBank-DDBJ", + "EMBL-GenBank-DDBJ_CDS", + "GI_number", + "RefSeq_Nucleotide", + "RefSeq_Protein", + "Ensembl_Protein", + "Ensembl_Transcript", + "Ensembl_Genomes", + "Ensembl_Genomes_Protein", + "Ensembl_Genomes_Transcript", + "WBParaSite_Transcript-Protein", + "WormBase_Protein", + "WormBase_Transcript", + ], + ids: Iterable[str], + ): + ... + @classmethod + @overload + def submit( + cls, + source: Literal["UniParc",], + dest: Literal["UniProtKB", "UniProtKB-Swiss-Prot", "UniParc"], + ids: Iterable[str], + ): + ... 
-class Rule2(TypedDict): - source: Literal["UniParc",] - dest: Literal["UniProtKB", "UniProtKB-Swiss-Prot", "UniParc"] - ids: Iterable[str] + @classmethod + @overload + def submit( + cls, + source: Literal["UniRef50",], + dest: Literal["UniProtKB", "UniProtKB-Swiss-Prot", "UniRef50"], + ids: Iterable[str], + ): + ... + @classmethod + @overload + def submit( + cls, + source: Literal["UniRef90",], + dest: Literal["UniProtKB", "UniProtKB-Swiss-Prot", "UniRef90"], + ids: Iterable[str], + ): + ... -class Rule3(TypedDict): - source: Literal["UniRef50",] - dest: Literal["UniProtKB", "UniProtKB-Swiss-Prot", "UniRef50"] - ids: Iterable[str] + @classmethod + @overload + def submit( + cls, + source: Literal["UniRef100",], + dest: Literal["UniProtKB", "UniProtKB-Swiss-Prot", "UniRef100"], + ids: Iterable[str], + ): + ... + @classmethod + @overload + def submit( + cls, + source: Literal["Gene_Name",], + dest: Literal["UniProtKB", "UniProtKB-Swiss-Prot"], + ids: Iterable[str], + taxon_id: Optional[str] = None, + ): + ... 
-class Rule4(TypedDict): - source: Literal["UniRef90",] - dest: Literal["UniProtKB", "UniProtKB-Swiss-Prot", "UniRef90"] - ids: Iterable[str] + @classmethod + @overload + def submit( + cls, + source: Literal[ + "CRC64", + "CCDS", + "EMBL-GenBank-DDBJ", + "EMBL-GenBank-DDBJ_CDS", + "GI_number", + "PIR", + "RefSeq_Nucleotide", + "RefSeq_Protein", + "PDB", + "BioGRID", + "ComplexPortal", + "DIP", + "STRING", + "ChEMBL", + "DrugBank", + "GuidetoPHARMACOLOGY", + "SwissLipids", + "Allergome", + "ESTHER", + "MEROPS", + "PeroxiBase", + "REBASE", + "TCDB", + "GlyConnect", + "BioMuta", + "DMDM", + "CPTAC", + "ProteomicsDB", + "DNASU", + "Ensembl", + "Ensembl_Genomes", + "Ensembl_Genomes_Protein", + "Ensembl_Genomes_Transcript", + "Ensembl_Protein", + "Ensembl_Transcript", + "GeneID", + "KEGG", + "PATRIC", + "UCSC", + "WBParaSite", + "WBParaSite_Transcript-Protein", + "ArachnoServer", + "Araport", + "CGD", + "ConoServer", + "dictyBase", + "EchoBASE", + "euHCVdb", + "FlyBase", + "GeneCards", + "GeneReviews", + "HGNC", + "LegioList", + "Leproma", + "MaizeGDB", + "MGI", + "MIM", + "neXtProt", + "OpenTargets", + "Orphanet", + "PharmGKB", + "PomBase", + "PseudoCAP", + "RGD", + "SGD", + "TubercuList", + "VEuPathDB", + "VGNC", + "WormBase", + "WormBase_Protein", + "WormBase_Transcript", + "Xenbase", + "ZFIN", + "eggNOG", + "GeneTree", + "HOGENOM", + "OMA", + "OrthoDB", + "TreeFam", + "BioCyc", + "PlantReactome", + "Reactome", + "UniPathway", + "ChiTaRS", + "GeneWiki", + "GenomeRNAi", + "PHI-base", + "CollecTF", + "DisProt", + "IDEAL", + ], + dest: Literal["UniProtKB", "UniProtKB-Swiss-Prot"], + ids: Iterable[str], + ): + ... 
- -class Rule5(TypedDict): - source: Literal["UniRef100",] - dest: Literal["UniProtKB", "UniProtKB-Swiss-Prot", "UniRef100"] - ids: Iterable[str] - - -class Rule6(TypedDict): - source: Literal["Gene_Name",] - dest: Literal["UniProtKB", "UniProtKB-Swiss-Prot"] - ids: Iterable[str] - taxon_id: bool - - -class Rule7(TypedDict): - source: Literal[ - "CRC64", - "CCDS", - "EMBL-GenBank-DDBJ", - "EMBL-GenBank-DDBJ_CDS", - "GI_number", - "PIR", - "RefSeq_Nucleotide", - "RefSeq_Protein", - "PDB", - "BioGRID", - "ComplexPortal", - "DIP", - "STRING", - "ChEMBL", - "DrugBank", - "GuidetoPHARMACOLOGY", - "SwissLipids", - "Allergome", - "ESTHER", - "MEROPS", - "PeroxiBase", - "REBASE", - "TCDB", - "GlyConnect", - "BioMuta", - "DMDM", - "CPTAC", - "ProteomicsDB", - "DNASU", - "Ensembl", - "Ensembl_Genomes", - "Ensembl_Genomes_Protein", - "Ensembl_Genomes_Transcript", - "Ensembl_Protein", - "Ensembl_Transcript", - "GeneID", - "KEGG", - "PATRIC", - "UCSC", - "WBParaSite", - "WBParaSite_Transcript-Protein", - "ArachnoServer", - "Araport", - "CGD", - "ConoServer", - "dictyBase", - "EchoBASE", - "euHCVdb", - "FlyBase", - "GeneCards", - "GeneReviews", - "HGNC", - "LegioList", - "Leproma", - "MaizeGDB", - "MGI", - "MIM", - "neXtProt", - "OpenTargets", - "Orphanet", - "PharmGKB", - "PomBase", - "PseudoCAP", - "RGD", - "SGD", - "TubercuList", - "VEuPathDB", - "VGNC", - "WormBase", - "WormBase_Protein", - "WormBase_Transcript", - "Xenbase", - "ZFIN", - "eggNOG", - "GeneTree", - "HOGENOM", - "OMA", - "OrthoDB", - "TreeFam", - "BioCyc", - "PlantReactome", - "Reactome", - "UniPathway", - "ChiTaRS", - "GeneWiki", - "GenomeRNAi", - "PHI-base", - "CollecTF", - "DisProt", - "IDEAL", - ] - dest: Literal["UniProtKB", "UniProtKB-Swiss-Prot"] - ids: Iterable[str] + @classmethod + def submit( + cls, source: str, dest: str, ids: Iterable[str], taxon_id: Optional[str] = None + ): + ...