From f443584f994a7a6426197f9f3700ed55f6f8458a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 13 Dec 2023 11:05:03 -0500 Subject: [PATCH] finish strict typing for most modules Updated pep-484 typing to pass mypy "strict" mode, however including per-module qualifications for specific typing elements not yet complete. This allows us to catch specific typing issues that have been ongoing such as import symbols not properly exported. Fixes: #1377 Change-Id: I69db4d23460f02161ac771d5d591b2bc802b8ab1 --- alembic/autogenerate/__init__.py | 20 +++--- alembic/autogenerate/api.py | 10 +-- alembic/autogenerate/compare.py | 9 +-- alembic/autogenerate/render.py | 18 ++--- alembic/autogenerate/rewriter.py | 15 +++-- alembic/command.py | 4 +- alembic/config.py | 29 +++++--- alembic/context.pyi | 9 +-- alembic/ddl/__init__.py | 2 +- alembic/ddl/_autogen.py | 6 +- alembic/ddl/base.py | 21 +++--- alembic/ddl/impl.py | 7 +- alembic/ddl/mssql.py | 5 +- alembic/ddl/mysql.py | 8 ++- alembic/ddl/oracle.py | 5 +- alembic/ddl/postgresql.py | 25 ++++--- alembic/ddl/sqlite.py | 14 ++-- alembic/op.pyi | 54 ++++++++++++--- alembic/operations/base.py | 84 +++++++++++++++++++---- alembic/operations/batch.py | 15 ++--- alembic/operations/ops.py | 81 +++++++++++++--------- alembic/operations/schemaobj.py | 9 +-- alembic/operations/toimpl.py | 3 + alembic/runtime/environment.py | 8 +-- alembic/runtime/migration.py | 25 ++++--- alembic/script/base.py | 28 ++++---- alembic/script/revision.py | 47 ++++++++----- alembic/script/write_hooks.py | 3 + alembic/util/__init__.py | 62 ++++++++--------- alembic/util/compat.py | 33 ++++++--- alembic/util/langhelpers.py | 111 ++++++++++++++++++++++--------- alembic/util/messaging.py | 9 ++- alembic/util/pyfiles.py | 10 ++- alembic/util/sqla_compat.py | 57 ++++++++++++---- docs/build/unreleased/1377.rst | 9 +++ pyproject.toml | 12 ++-- setup.cfg | 15 ----- tools/write_pyi.py | 6 +- 38 files changed, 581 insertions(+), 307 deletions(-) create mode 100644 docs/build/unreleased/1377.rst diff --git a/alembic/autogenerate/__init__.py b/alembic/autogenerate/__init__.py index cd2ed1c1..445ddb25 100644 --- a/alembic/autogenerate/__init__.py +++ b/alembic/autogenerate/__init__.py @@ -1,10 +1,10 @@ -from .api import _render_migration_diffs -from .api import compare_metadata -from .api import produce_migrations -from .api import render_python_code -from .api import RevisionContext -from .compare import _produce_net_changes -from .compare import comparators -from .render import render_op_text -from .render import renderers -from .rewriter import Rewriter +from .api import _render_migration_diffs as _render_migration_diffs +from .api import compare_metadata as compare_metadata +from .api import produce_migrations as produce_migrations +from .api import render_python_code as render_python_code +from .api import RevisionContext as RevisionContext +from .compare import _produce_net_changes as _produce_net_changes +from .compare import comparators as comparators +from .render import render_op_text as render_op_text +from .render import renderers as renderers +from .rewriter import Rewriter as Rewriter diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py index b7f43b19..aa8f32f6 100644 --- a/alembic/autogenerate/api.py +++ b/alembic/autogenerate/api.py @@ -28,6 +28,7 @@ from sqlalchemy.engine import Inspector from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.schema import SchemaItem + from sqlalchemy.sql.schema import Table from ..config import Config from ..operations.ops import DowngradeOps @@ -165,6 +166,7 @@ def compare_metadata(context: MigrationContext, metadata: MetaData) -> Any: """ migration_script = produce_migrations(context, metadata) + assert migration_script.upgrade_ops is not None return migration_script.upgrade_ops.as_diffs() @@ -331,7 +333,7 @@ def __init__( self, migration_context: MigrationContext, metadata: Optional[MetaData] = None, - opts: Optional[dict] = None, + opts: Optional[Dict[str, Any]] = None, autogenerate: bool = True, ) -> None: if ( @@ -465,7 +467,7 @@ def run_object_filters( run_filters = run_object_filters @util.memoized_property - def sorted_tables(self): + def sorted_tables(self) -> List[Table]: """Return an aggregate of the :attr:`.MetaData.sorted_tables` collection(s). @@ -481,7 +483,7 @@ def sorted_tables(self): return result @util.memoized_property - def table_key_to_table(self): + def table_key_to_table(self) -> Dict[str, Table]: """Return an aggregate of the :attr:`.MetaData.tables` dictionaries. The :attr:`.MetaData.tables` collection is a dictionary of table key @@ -492,7 +494,7 @@ def table_key_to_table(self): objects contain the same table key, an exception is raised. """ - result = {} + result: Dict[str, Table] = {} for m in util.to_list(self.metadata): intersect = set(result).intersection(set(m.tables)) if intersect: diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py index a50d8b81..fcef531a 100644 --- a/alembic/autogenerate/compare.py +++ b/alembic/autogenerate/compare.py @@ -1,3 +1,6 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + from __future__ import annotations import contextlib @@ -577,9 +580,7 @@ def _compare_indexes_and_uniques( # 5. index things by name, for those objects that have names metadata_names = { cast(str, c.md_name_to_sql_name(autogen_context)): c - for c in metadata_unique_constraints_sig.union( - metadata_indexes_sig # type:ignore[arg-type] - ) + for c in metadata_unique_constraints_sig.union(metadata_indexes_sig) if c.is_named } @@ -1240,7 +1241,7 @@ def _add_fk(obj, compare_to): obj.const, obj.name, "foreign_key_constraint", False, compare_to ): modify_table_ops.ops.append( - ops.CreateForeignKeyOp.from_constraint(const.const) + ops.CreateForeignKeyOp.from_constraint(const.const) # type: ignore[has-type] # noqa: E501 ) log.info( diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py index 67cc8c33..317a6dbe 100644 --- a/alembic/autogenerate/render.py +++ b/alembic/autogenerate/render.py @@ -1,3 +1,6 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + from __future__ import annotations from io import StringIO @@ -849,7 +852,7 @@ def _render_Variant_type( ) -> str: base_type, variant_mapping = sqla_compat._get_variant_mapping(type_) base = _repr_type(base_type, autogen_context, _skip_variants=True) - assert base is not None and base is not False + assert base is not None and base is not False # type: ignore[comparison-overlap] # noqa:E501 for dialect in sorted(variant_mapping): typ = variant_mapping[dialect] base += ".with_variant(%s, %r)" % ( @@ -946,7 +949,7 @@ def _fk_colspec( won't fail if the remote table can't be resolved. """ - colspec = fk._get_colspec() # type:ignore[attr-defined] + colspec = fk._get_colspec() tokens = colspec.split(".") tname, colname = tokens[-2:] @@ -1016,8 +1019,7 @@ def _render_foreign_key( % { "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), "cols": ", ".join( - "%r" % _ident(cast("Column", f.parent).name) - for f in constraint.elements + repr(_ident(f.parent.name)) for f in constraint.elements ), "refcols": ", ".join( repr(_fk_colspec(f, apply_metadata_schema, namespace_metadata)) @@ -1058,12 +1060,10 @@ def _render_check_constraint( # ideally SQLAlchemy would give us more of a first class # way to detect this. if ( - constraint._create_rule # type:ignore[attr-defined] - and hasattr( - constraint._create_rule, "target" # type:ignore[attr-defined] - ) + constraint._create_rule + and hasattr(constraint._create_rule, "target") and isinstance( - constraint._create_rule.target, # type:ignore[attr-defined] + constraint._create_rule.target, sqltypes.TypeEngine, ) ): diff --git a/alembic/autogenerate/rewriter.py b/alembic/autogenerate/rewriter.py index 68a93dd0..02ff431c 100644 --- a/alembic/autogenerate/rewriter.py +++ b/alembic/autogenerate/rewriter.py @@ -16,12 +16,14 @@ from ..operations.ops import AddColumnOp from ..operations.ops import AlterColumnOp from ..operations.ops import CreateTableOp + from ..operations.ops import DowngradeOps from ..operations.ops import MigrateOperation from ..operations.ops import MigrationScript from ..operations.ops import ModifyTableOps from ..operations.ops import OpContainer - from ..runtime.environment import _GetRevArg + from ..operations.ops import UpgradeOps from ..runtime.migration import MigrationContext + from ..script.revision import _GetRevArg class Rewriter: @@ -101,7 +103,7 @@ def rewrites( Type[CreateTableOp], Type[ModifyTableOps], ], - ) -> Callable: + ) -> Callable[..., Any]: """Register a function as rewriter for a given type. The function should receive three arguments, which are @@ -156,7 +158,7 @@ def _traverse_script( revision: _GetRevArg, directive: MigrationScript, ) -> None: - upgrade_ops_list = [] + upgrade_ops_list: List[UpgradeOps] = [] for upgrade_ops in directive.upgrade_ops_list: ret = self._traverse_for(context, revision, upgrade_ops) if len(ret) != 1: @@ -164,9 +166,10 @@ def _traverse_script( "Can only return single object for UpgradeOps traverse" ) upgrade_ops_list.append(ret[0]) - directive.upgrade_ops = upgrade_ops_list - downgrade_ops_list = [] + directive.upgrade_ops = upgrade_ops_list # type: ignore + + downgrade_ops_list: List[DowngradeOps] = [] for downgrade_ops in directive.downgrade_ops_list: ret = self._traverse_for(context, revision, downgrade_ops) if len(ret) != 1: @@ -174,7 +177,7 @@ def _traverse_script( "Can only return single object for DowngradeOps traverse" ) downgrade_ops_list.append(ret[0]) - directive.downgrade_ops = downgrade_ops_list + directive.downgrade_ops = downgrade_ops_list # type: ignore @_traverse.dispatch_for(ops.OpContainer) def _traverse_op_container( diff --git a/alembic/command.py b/alembic/command.py index c5233e72..37aa6e67 100644 --- a/alembic/command.py +++ b/alembic/command.py @@ -1,3 +1,5 @@ +# mypy: allow-untyped-defs, allow-untyped-calls + from __future__ import annotations import os @@ -18,7 +20,7 @@ from .runtime.environment import ProcessRevisionDirectiveFn -def list_templates(config: Config): +def list_templates(config: Config) -> None: """List available templates. :param config: a :class:`.Config` object. diff --git a/alembic/config.py b/alembic/config.py index 55b5811a..4b2263fd 100644 --- a/alembic/config.py +++ b/alembic/config.py @@ -12,6 +12,7 @@ from typing import Mapping from typing import Optional from typing import overload +from typing import Sequence from typing import TextIO from typing import Union @@ -104,7 +105,7 @@ def __init__( stdout: TextIO = sys.stdout, cmd_opts: Optional[Namespace] = None, config_args: Mapping[str, Any] = util.immutabledict(), - attributes: Optional[dict] = None, + attributes: Optional[Dict[str, Any]] = None, ) -> None: """Construct a new :class:`.Config`""" self.config_file_name = file_ @@ -140,7 +141,7 @@ def __init__( """ @util.memoized_property - def attributes(self): + def attributes(self) -> Dict[str, Any]: """A Python dictionary for storage of additional state. @@ -159,7 +160,7 @@ def attributes(self): """ return {} - def print_stdout(self, text: str, *arg) -> None: + def print_stdout(self, text: str, *arg: Any) -> None: """Render a message to standard out. When :meth:`.Config.print_stdout` is called with additional args @@ -183,7 +184,7 @@ def print_stdout(self, text: str, *arg) -> None: util.write_outstream(self.stdout, output, "\n", **self.messaging_opts) @util.memoized_property - def file_config(self): + def file_config(self) -> ConfigParser: """Return the underlying ``ConfigParser`` object. Direct access to the .ini file is available here, @@ -321,7 +322,9 @@ def get_main_option( ) -> Optional[str]: ... - def get_main_option(self, name, default=None): + def get_main_option( + self, name: str, default: Optional[str] = None + ) -> Optional[str]: """Return an option from the 'main' section of the .ini file. This defaults to being a key from the ``[alembic]`` @@ -351,7 +354,9 @@ def __init__(self, prog: Optional[str] = None) -> None: self._generate_args(prog) def _generate_args(self, prog: Optional[str]) -> None: - def add_options(fn, parser, positional, kwargs): + def add_options( + fn: Any, parser: Any, positional: Any, kwargs: Any + ) -> None: kwargs_opts = { "template": ( "-t", @@ -554,7 +559,9 @@ def add_options(fn, parser, positional, kwargs): ) subparsers = parser.add_subparsers() - positional_translations = {command.stamp: {"revision": "revisions"}} + positional_translations: Dict[Any, Any] = { + command.stamp: {"revision": "revisions"} + } for fn in [getattr(command, n) for n in dir(command)]: if ( @@ -609,7 +616,7 @@ def run_cmd(self, config: Config, options: Namespace) -> None: else: util.err(str(e), **config.messaging_opts) - def main(self, argv=None): + def main(self, argv: Optional[Sequence[str]] = None) -> None: options = self.parser.parse_args(argv) if not hasattr(options, "cmd"): # see http://bugs.python.org/issue9253, argparse @@ -624,7 +631,11 @@ def main(self, argv=None): self.run_cmd(cfg, options) -def main(argv=None, prog=None, **kwargs): +def main( + argv: Optional[Sequence[str]] = None, + prog: Optional[str] = None, + **kwargs: Any, +) -> None: """The console runner function for Alembic.""" CommandLine(prog=prog).main(argv=argv) diff --git a/alembic/context.pyi b/alembic/context.pyi index e8d98210..80619fb2 100644 --- a/alembic/context.pyi +++ b/alembic/context.pyi @@ -160,8 +160,8 @@ def configure( MigrationContext, Column[Any], Column[Any], - TypeEngine, - TypeEngine, + TypeEngine[Any], + TypeEngine[Any], ], Optional[bool], ], @@ -636,7 +636,8 @@ def configure( """ def execute( - sql: Union[Executable, str], execution_options: Optional[dict] = None + sql: Union[Executable, str], + execution_options: Optional[Dict[str, Any]] = None, ) -> None: """Execute the given SQL using the current change context. @@ -805,7 +806,7 @@ def is_offline_mode() -> bool: """ -def is_transactional_ddl(): +def is_transactional_ddl() -> bool: """Return True if the context is configured to expect a transactional DDL capable backend. diff --git a/alembic/ddl/__init__.py b/alembic/ddl/__init__.py index cfcc47e0..f2f72b3d 100644 --- a/alembic/ddl/__init__.py +++ b/alembic/ddl/__init__.py @@ -3,4 +3,4 @@ from . import oracle from . import postgresql from . import sqlite -from .impl import DefaultImpl +from .impl import DefaultImpl as DefaultImpl diff --git a/alembic/ddl/_autogen.py b/alembic/ddl/_autogen.py index cc1a1fc4..e22153c4 100644 --- a/alembic/ddl/_autogen.py +++ b/alembic/ddl/_autogen.py @@ -1,3 +1,6 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + from __future__ import annotations from typing import Any @@ -19,7 +22,6 @@ from sqlalchemy.sql.schema import UniqueConstraint from typing_extensions import TypeGuard -from alembic.ddl.base import _fk_spec from .. import util from ..util import sqla_compat @@ -275,7 +277,7 @@ def __init__( ondelete, deferrable, initially, - ) = _fk_spec(const) + ) = sqla_compat._fk_spec(const) self._sig: Tuple[Any, ...] = ( self.source_schema, diff --git a/alembic/ddl/base.py b/alembic/ddl/base.py index 339db0c4..7a85a5c1 100644 --- a/alembic/ddl/base.py +++ b/alembic/ddl/base.py @@ -1,3 +1,6 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + from __future__ import annotations import functools @@ -173,7 +176,7 @@ def __init__( self.comment = comment -@compiles(RenameTable) +@compiles(RenameTable) # type: ignore[misc] def visit_rename_table( element: RenameTable, compiler: DDLCompiler, **kw ) -> str: @@ -183,7 +186,7 @@ def visit_rename_table( ) -@compiles(AddColumn) +@compiles(AddColumn) # type: ignore[misc] def visit_add_column(element: AddColumn, compiler: DDLCompiler, **kw) -> str: return "%s %s" % ( alter_table(compiler, element.table_name, element.schema), @@ -191,7 +194,7 @@ def visit_add_column(element: AddColumn, compiler: DDLCompiler, **kw) -> str: ) -@compiles(DropColumn) +@compiles(DropColumn) # type: ignore[misc] def visit_drop_column(element: DropColumn, compiler: DDLCompiler, **kw) -> str: return "%s %s" % ( alter_table(compiler, element.table_name, element.schema), @@ -199,7 +202,7 @@ def visit_drop_column(element: DropColumn, compiler: DDLCompiler, **kw) -> str: ) -@compiles(ColumnNullable) +@compiles(ColumnNullable) # type: ignore[misc] def visit_column_nullable( element: ColumnNullable, compiler: DDLCompiler, **kw ) -> str: @@ -210,7 +213,7 @@ def visit_column_nullable( ) -@compiles(ColumnType) +@compiles(ColumnType) # type: ignore[misc] def visit_column_type(element: ColumnType, compiler: DDLCompiler, **kw) -> str: return "%s %s %s" % ( alter_table(compiler, element.table_name, element.schema), @@ -219,7 +222,7 @@ def visit_column_type(element: ColumnType, compiler: DDLCompiler, **kw) -> str: ) -@compiles(ColumnName) +@compiles(ColumnName) # type: ignore[misc] def visit_column_name(element: ColumnName, compiler: DDLCompiler, **kw) -> str: return "%s RENAME %s TO %s" % ( alter_table(compiler, element.table_name, element.schema), @@ -228,7 +231,7 @@ def visit_column_name(element: ColumnName, compiler: DDLCompiler, **kw) -> str: ) -@compiles(ColumnDefault) +@compiles(ColumnDefault) # type: ignore[misc] def visit_column_default( element: ColumnDefault, compiler: DDLCompiler, **kw ) -> str: @@ -241,7 +244,7 @@ def visit_column_default( ) -@compiles(ComputedColumnDefault) +@compiles(ComputedColumnDefault) # type: ignore[misc] def visit_computed_column( element: ComputedColumnDefault, compiler: DDLCompiler, **kw ): @@ -251,7 +254,7 @@ def visit_computed_column( ) -@compiles(IdentityColumnDefault) +@compiles(IdentityColumnDefault) # type: ignore[misc] def visit_identity_column( element: IdentityColumnDefault, compiler: DDLCompiler, **kw ): diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index 571a3041..2e4f1ae9 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -1,3 +1,6 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + from __future__ import annotations import logging @@ -23,8 +26,8 @@ from . import _autogen from . import base -from ._autogen import _constraint_sig -from ._autogen import ComparisonResult +from ._autogen import _constraint_sig as _constraint_sig +from ._autogen import ComparisonResult as ComparisonResult from .. import util from ..util import sqla_compat diff --git a/alembic/ddl/mssql.py b/alembic/ddl/mssql.py index 9b0fff88..baa43d5e 100644 --- a/alembic/ddl/mssql.py +++ b/alembic/ddl/mssql.py @@ -1,3 +1,6 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + from __future__ import annotations import re @@ -9,7 +12,6 @@ from typing import Union from sqlalchemy import types as sqltypes -from sqlalchemy.ext.compiler import compiles from sqlalchemy.schema import Column from sqlalchemy.schema import CreateIndex from sqlalchemy.sql.base import Executable @@ -30,6 +32,7 @@ from .impl import DefaultImpl from .. import util from ..util import sqla_compat +from ..util.sqla_compat import compiles if TYPE_CHECKING: from typing import Literal diff --git a/alembic/ddl/mysql.py b/alembic/ddl/mysql.py index 5a2af5ce..f312173e 100644 --- a/alembic/ddl/mysql.py +++ b/alembic/ddl/mysql.py @@ -1,3 +1,6 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + from __future__ import annotations import re @@ -8,7 +11,6 @@ from sqlalchemy import schema from sqlalchemy import types as sqltypes -from sqlalchemy.ext.compiler import compiles from .base import alter_table from .base import AlterColumn @@ -23,6 +25,7 @@ from ..util import sqla_compat from ..util.sqla_compat import _is_mariadb from ..util.sqla_compat import _is_type_bound +from ..util.sqla_compat import compiles if TYPE_CHECKING: from typing import Literal @@ -160,8 +163,7 @@ def _is_mysql_allowed_functional_default( ) -> bool: return ( type_ is not None - and type_._type_affinity # type:ignore[attr-defined] - is sqltypes.DateTime + and type_._type_affinity is sqltypes.DateTime and server_default is not None ) diff --git a/alembic/ddl/oracle.py b/alembic/ddl/oracle.py index e56bb210..54011740 100644 --- a/alembic/ddl/oracle.py +++ b/alembic/ddl/oracle.py @@ -1,3 +1,6 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + from __future__ import annotations import re @@ -5,7 +8,6 @@ from typing import Optional from typing import TYPE_CHECKING -from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql import sqltypes from .base import AddColumn @@ -22,6 +24,7 @@ from .base import IdentityColumnDefault from .base import RenameTable from .impl import DefaultImpl +from ..util.sqla_compat import compiles if TYPE_CHECKING: from sqlalchemy.dialects.oracle.base import OracleDDLCompiler diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py index 68628c8e..6507fcbd 100644 --- a/alembic/ddl/postgresql.py +++ b/alembic/ddl/postgresql.py @@ -1,3 +1,6 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + from __future__ import annotations import logging @@ -30,7 +33,6 @@ from .base import alter_table from .base import AlterColumn from .base import ColumnComment -from .base import compiles from .base import format_column_name from .base import format_table_name from .base import format_type @@ -45,6 +47,7 @@ from ..operations.base import BatchOperations from ..operations.base import Operations from ..util import sqla_compat +from ..util.sqla_compat import compiles if TYPE_CHECKING: from typing import Literal @@ -136,7 +139,9 @@ def compare_server_default( metadata_default = literal_column(metadata_default) # run a real compare against the server - return not self.connection.scalar( + conn = self.connection + assert conn is not None + return not conn.scalar( sqla_compat._select( literal_column(conn_col_default) == metadata_default ) @@ -623,9 +628,8 @@ def from_constraint( # type:ignore[override] return cls( constraint.name, constraint_table.name, - [ - (expr, op) - for expr, name, op in constraint._render_exprs # type:ignore[attr-defined] # noqa + [ # type: ignore + (expr, op) for expr, name, op in constraint._render_exprs ], where=cast("ColumnElement[bool] | None", constraint.where), schema=constraint_table.schema, @@ -652,7 +656,7 @@ def to_constraint( expr, name, oper, - ) in excl._render_exprs: # type:ignore[attr-defined] + ) in excl._render_exprs: t.append_column(Column(name, NULLTYPE)) t.append_constraint(excl) return excl @@ -710,7 +714,7 @@ def batch_create_exclude_constraint( constraint_name: str, *elements: Any, **kw: Any, - ): + ) -> Optional[Table]: """Issue a "create exclude constraint" instruction using the current batch migration context. @@ -782,10 +786,13 @@ def do_expr_where_opts(): args = [ "(%s, %r)" % ( - _render_potential_column(sqltext, autogen_context), + _render_potential_column( + sqltext, # type:ignore[arg-type] + autogen_context, + ), opstring, ) - for sqltext, name, opstring in constraint._render_exprs # type:ignore[attr-defined] # noqa + for sqltext, name, opstring in constraint._render_exprs ] if constraint.where is not None: args.append( diff --git a/alembic/ddl/sqlite.py b/alembic/ddl/sqlite.py index c6186c60..762e8ca1 100644 --- a/alembic/ddl/sqlite.py +++ b/alembic/ddl/sqlite.py @@ -1,3 +1,6 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + from __future__ import annotations import re @@ -11,13 +14,13 @@ from sqlalchemy import JSON from sqlalchemy import schema from sqlalchemy import sql -from sqlalchemy.ext.compiler import compiles from .base import alter_table from .base import format_table_name from .base import RenameTable from .impl import DefaultImpl from .. import util +from ..util.sqla_compat import compiles if TYPE_CHECKING: from sqlalchemy.engine.reflection import Inspector @@ -71,13 +74,13 @@ def requires_recreate_in_batch( def add_constraint(self, const: Constraint): # attempt to distinguish between an # auto-gen constraint and an explicit one - if const._create_rule is None: # type:ignore[attr-defined] + if const._create_rule is None: raise NotImplementedError( "No support for ALTER of constraints in SQLite dialect. " "Please refer to the batch mode feature which allows for " "SQLite migrations using a copy-and-move strategy." ) - elif const._create_rule(self): # type:ignore[attr-defined] + elif const._create_rule(self): util.warn( "Skipping unsupported ALTER for " "creation of implicit constraint. " @@ -86,7 +89,7 @@ def add_constraint(self, const: Constraint): ) def drop_constraint(self, const: Constraint): - if const._create_rule is None: # type:ignore[attr-defined] + if const._create_rule is None: raise NotImplementedError( "No support for ALTER of constraints in SQLite dialect. " "Please refer to the batch mode feature which allows for " @@ -177,8 +180,7 @@ def cast_for_batch_migrate( new_type: TypeEngine, ) -> None: if ( - existing.type._type_affinity # type:ignore[attr-defined] - is not new_type._type_affinity # type:ignore[attr-defined] + existing.type._type_affinity is not new_type._type_affinity and not isinstance(new_type, JSON) ): existing_transfer["expr"] = cast( diff --git a/alembic/op.pyi b/alembic/op.pyi index 944b5ae1..83deac1e 100644 --- a/alembic/op.pyi +++ b/alembic/op.pyi @@ -12,6 +12,7 @@ from typing import List from typing import Literal from typing import Mapping from typing import Optional +from typing import overload from typing import Sequence from typing import Tuple from typing import Type @@ -35,12 +36,28 @@ if TYPE_CHECKING: from sqlalchemy.sql.type_api import TypeEngine from sqlalchemy.util import immutabledict - from .operations.ops import BatchOperations + from .operations.base import BatchOperations + from .operations.ops import AddColumnOp + from .operations.ops import AddConstraintOp + from .operations.ops import AlterColumnOp + from .operations.ops import AlterTableOp + from .operations.ops import BulkInsertOp + from .operations.ops import CreateIndexOp + from .operations.ops import CreateTableCommentOp + from .operations.ops import CreateTableOp + from .operations.ops import DropColumnOp + from .operations.ops import DropConstraintOp + from .operations.ops import DropIndexOp + from .operations.ops import DropTableCommentOp + from .operations.ops import DropTableOp + from .operations.ops import ExecuteSQLOp from .operations.ops import MigrateOperation from .runtime.migration import MigrationContext from .util.sqla_compat import _literal_bindparam _T = TypeVar("_T") +_C = TypeVar("_C", bound=Callable[..., Any]) + ### end imports ### def add_column( @@ -132,8 +149,8 @@ def alter_column( comment: Union[str, Literal[False], None] = False, server_default: Any = False, new_column_name: Optional[str] = None, - type_: Union[TypeEngine, Type[TypeEngine], None] = None, - existing_type: Union[TypeEngine, Type[TypeEngine], None] = None, + type_: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None, + existing_type: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None, existing_server_default: Union[ str, bool, Identity, Computed, None ] = False, @@ -230,7 +247,7 @@ def batch_alter_table( table_name: str, schema: Optional[str] = None, recreate: Literal["auto", "always", "never"] = "auto", - partial_reordering: Optional[tuple] = None, + partial_reordering: Optional[Tuple[Any, ...]] = None, copy_from: Optional[Table] = None, table_args: Tuple[Any, ...] = (), table_kwargs: Mapping[str, Any] = immutabledict({}), @@ -377,7 +394,7 @@ def batch_alter_table( def bulk_insert( table: Union[Table, TableClause], - rows: List[dict], + rows: List[Dict[str, Any]], *, multiinsert: bool = True, ) -> None: @@ -1162,7 +1179,7 @@ def get_context() -> MigrationContext: """ -def implementation_for(op_cls: Any) -> Callable[..., Any]: +def implementation_for(op_cls: Any) -> Callable[[_C], _C]: """Register an implementation for a given :class:`.MigrateOperation`. This is part of the operation extensibility API. @@ -1174,7 +1191,7 @@ def implementation_for(op_cls: Any) -> Callable[..., Any]: """ def inline_literal( - value: Union[str, int], type_: Optional[TypeEngine] = None + value: Union[str, int], type_: Optional[TypeEngine[Any]] = None ) -> _literal_bindparam: r"""Produce an 'inline literal' expression, suitable for using in an INSERT, UPDATE, or DELETE statement. @@ -1218,6 +1235,27 @@ def inline_literal( """ +@overload +def invoke(operation: CreateTableOp) -> Table: ... +@overload +def invoke( + operation: Union[ + AddConstraintOp, + DropConstraintOp, + CreateIndexOp, + DropIndexOp, + AddColumnOp, + AlterColumnOp, + AlterTableOp, + CreateTableCommentOp, + DropTableCommentOp, + DropColumnOp, + BulkInsertOp, + DropTableOp, + ExecuteSQLOp, + ] +) -> None: ... +@overload def invoke(operation: MigrateOperation) -> Any: """Given a :class:`.MigrateOperation`, invoke it in terms of this :class:`.Operations` instance. @@ -1226,7 +1264,7 @@ def invoke(operation: MigrateOperation) -> Any: def register_operation( name: str, sourcename: Optional[str] = None -) -> Callable[[_T], _T]: +) -> Callable[[Type[_T]], Type[_T]]: """Register a new operation for this class. This method is normally used to add new operations diff --git a/alembic/operations/base.py b/alembic/operations/base.py index e3207be7..bafe441a 100644 --- a/alembic/operations/base.py +++ b/alembic/operations/base.py @@ -1,3 +1,5 @@ +# mypy: allow-untyped-calls + from __future__ import annotations from contextlib import contextmanager @@ -10,7 +12,9 @@ from typing import Iterator from typing import List # noqa from typing import Mapping +from typing import NoReturn from typing import Optional +from typing import overload from typing import Sequence # noqa from typing import Tuple from typing import Type # noqa @@ -47,12 +51,28 @@ from sqlalchemy.types import TypeEngine from .batch import BatchOperationsImpl + from .ops import AddColumnOp + from .ops import AddConstraintOp + from .ops import AlterColumnOp + from .ops import AlterTableOp + from .ops import BulkInsertOp + from .ops import CreateIndexOp + from .ops import CreateTableCommentOp + from .ops import CreateTableOp + from .ops import DropColumnOp + from .ops import DropConstraintOp + from .ops import DropIndexOp + from .ops import DropTableCommentOp + from .ops import DropTableOp + from .ops import ExecuteSQLOp from .ops import MigrateOperation from ..ddl import DefaultImpl from ..runtime.migration import MigrationContext __all__ = ("Operations", "BatchOperations") _T = TypeVar("_T") +_C = TypeVar("_C", bound=Callable[..., Any]) + class AbstractOperations(util.ModuleClsProxy): """Base class for Operations and BatchOperations. @@ -86,7 +106,7 @@ def __init__( @classmethod def register_operation( cls, name: str, sourcename: Optional[str] = None - ) -> Callable[[_T], _T]: + ) -> Callable[[Type[_T]], Type[_T]]: """Register a new operation for this class. This method is normally used to add new operations @@ -103,7 +123,7 @@ def register_operation( """ - def register(op_cls): + def register(op_cls: Type[_T]) -> Type[_T]: if sourcename is None: fn = getattr(op_cls, name) source_name = fn.__name__ @@ -122,8 +142,11 @@ def register(op_cls): *spec, formatannotation=formatannotation_fwdref ) num_defaults = len(spec[3]) if spec[3] else 0 + + defaulted_vals: Tuple[Any, ...] + if num_defaults: - defaulted_vals = name_args[0 - num_defaults :] + defaulted_vals = tuple(name_args[0 - num_defaults :]) else: defaulted_vals = () @@ -164,7 +187,7 @@ def %(name)s%(args)s: globals_ = dict(globals()) globals_.update({"op_cls": op_cls}) - lcl = {} + lcl: Dict[str, Any] = {} exec(func_text, globals_, lcl) setattr(cls, name, lcl[name]) @@ -180,7 +203,7 @@ def %(name)s%(args)s: return register @classmethod - def implementation_for(cls, op_cls: Any) -> Callable[..., Any]: + def implementation_for(cls, op_cls: Any) -> Callable[[_C], _C]: """Register an implementation for a given :class:`.MigrateOperation`. This is part of the operation extensibility API. @@ -191,7 +214,7 @@ def implementation_for(cls, op_cls: Any) -> Callable[..., Any]: """ - def decorate(fn): + def decorate(fn: _C) -> _C: cls._to_impl.dispatch_for(op_cls)(fn) return fn @@ -213,7 +236,7 @@ def batch_alter_table( table_name: str, schema: Optional[str] = None, recreate: Literal["auto", "always", "never"] = "auto", - partial_reordering: Optional[tuple] = None, + partial_reordering: Optional[Tuple[Any, ...]] = None, copy_from: Optional[Table] = None, table_args: Tuple[Any, ...] = (), table_kwargs: Mapping[str, Any] = util.immutabledict(), @@ -382,6 +405,35 @@ def get_context(self) -> MigrationContext: return self.migration_context + @overload + def invoke(self, operation: CreateTableOp) -> Table: + ... + + @overload + def invoke( + self, + operation: Union[ + AddConstraintOp, + DropConstraintOp, + CreateIndexOp, + DropIndexOp, + AddColumnOp, + AlterColumnOp, + AlterTableOp, + CreateTableCommentOp, + DropTableCommentOp, + DropColumnOp, + BulkInsertOp, + DropTableOp, + ExecuteSQLOp, + ], + ) -> None: + ... + + @overload + def invoke(self, operation: MigrateOperation) -> Any: + ... + def invoke(self, operation: MigrateOperation) -> Any: """Given a :class:`.MigrateOperation`, invoke it in terms of this :class:`.Operations` instance. @@ -659,8 +711,10 @@ def alter_column( comment: Union[str, Literal[False], None] = False, server_default: Any = False, new_column_name: Optional[str] = None, - type_: Union[TypeEngine, Type[TypeEngine], None] = None, - existing_type: Union[TypeEngine, Type[TypeEngine], None] = None, + type_: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None, + existing_type: Union[ + TypeEngine[Any], Type[TypeEngine[Any]], None + ] = None, existing_server_default: Union[ str, bool, Identity, Computed, None ] = False, @@ -756,7 +810,7 @@ def alter_column( def bulk_insert( self, table: Union[Table, TableClause], - rows: List[dict], + rows: List[Dict[str, Any]], *, multiinsert: bool = True, ) -> None: @@ -1560,7 +1614,7 @@ class BatchOperations(AbstractOperations): impl: BatchOperationsImpl - def _noop(self, operation): + def _noop(self, operation: Any) -> NoReturn: raise NotImplementedError( "The %s method does not apply to a batch table alter operation." % operation @@ -1596,8 +1650,10 @@ def alter_column( comment: Union[str, Literal[False], None] = False, server_default: Any = False, new_column_name: Optional[str] = None, - type_: Union[TypeEngine, Type[TypeEngine], None] = None, - existing_type: Union[TypeEngine, Type[TypeEngine], None] = None, + type_: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None, + existing_type: Union[ + TypeEngine[Any], Type[TypeEngine[Any]], None + ] = None, existing_server_default: Union[ str, bool, Identity, Computed, None ] = False, @@ -1652,7 +1708,7 @@ def create_check_constraint( def create_exclude_constraint( self, constraint_name: str, *elements: Any, **kw: Any - ): + ) -> Optional[Table]: """Issue a "create exclude constraint" instruction using the current batch migration context. diff --git a/alembic/operations/batch.py b/alembic/operations/batch.py index 8c88e885..fd7ab990 100644 --- a/alembic/operations/batch.py +++ b/alembic/operations/batch.py @@ -1,3 +1,6 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + from __future__ import annotations from typing import Any @@ -17,7 +20,7 @@ from sqlalchemy import schema as sql_schema from sqlalchemy import Table from sqlalchemy import types as sqltypes -from sqlalchemy.events import SchemaEventTarget +from sqlalchemy.sql.schema import SchemaEventTarget from sqlalchemy.util import OrderedDict from sqlalchemy.util import topological @@ -374,7 +377,7 @@ def _gather_indexes_from_both_tables(self) -> List[Index]: for idx_existing in self.indexes.values(): # this is a lift-and-move from Table.to_metadata - if idx_existing._column_flag: # type: ignore + if idx_existing._column_flag: continue idx_copy = Index( @@ -403,9 +406,7 @@ def _gather_indexes_from_both_tables(self) -> List[Index]: def _setup_referent( self, metadata: MetaData, constraint: ForeignKeyConstraint ) -> None: - spec = constraint.elements[ - 0 - ]._get_colspec() # type:ignore[attr-defined] + spec = constraint.elements[0]._get_colspec() parts = spec.split(".") tname = parts[-2] if len(parts) == 3: @@ -546,9 +547,7 @@ def alter_column( else: sql_schema.DefaultClause( server_default # type: ignore[arg-type] - )._set_parent( # type:ignore[attr-defined] - existing - ) + )._set_parent(existing) if autoincrement is not None: existing.autoincrement = bool(autoincrement) diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py index 07b3e574..7b65191c 100644 --- a/alembic/operations/ops.py +++ b/alembic/operations/ops.py @@ -5,6 +5,7 @@ from typing import Any from typing import Callable from typing import cast +from typing import Dict from typing import FrozenSet from typing import Iterator from typing import List @@ -15,6 +16,7 @@ from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from sqlalchemy.types import NULLTYPE @@ -53,6 +55,9 @@ from ..runtime.migration import MigrationContext from ..script.revision import _RevIdType +_T = TypeVar("_T", bound=Any) +_AC = TypeVar("_AC", bound="AddConstraintOp") + class MigrateOperation: """base class for migration command and organization objects. @@ -70,7 +75,7 @@ class MigrateOperation: """ @util.memoized_property - def info(self): + def info(self) -> Dict[Any, Any]: """A dictionary that may be used to store arbitrary information along with this :class:`.MigrateOperation` object. @@ -92,12 +97,14 @@ class AddConstraintOp(MigrateOperation): add_constraint_ops = util.Dispatcher() @property - def constraint_type(self): + def constraint_type(self) -> str: raise NotImplementedError() @classmethod - def register_add_constraint(cls, type_: str) -> Callable: - def go(klass): + def register_add_constraint( + cls, type_: str + ) -> Callable[[Type[_AC]], Type[_AC]]: + def go(klass: Type[_AC]) -> Type[_AC]: cls.add_constraint_ops.dispatch_for(type_)(klass.from_constraint) return klass @@ -105,7 +112,7 @@ def go(klass): @classmethod def from_constraint(cls, constraint: Constraint) -> AddConstraintOp: - return cls.add_constraint_ops.dispatch(constraint.__visit_name__)( + return cls.add_constraint_ops.dispatch(constraint.__visit_name__)( # type: ignore[no-any-return] # noqa: E501 constraint ) @@ -398,7 +405,7 @@ def from_constraint( uq_constraint = cast("UniqueConstraint", constraint) - kw: dict = {} + kw: Dict[str, Any] = {} if uq_constraint.deferrable: kw["deferrable"] = uq_constraint.deferrable if uq_constraint.initially: @@ -532,7 +539,7 @@ def to_diff_tuple(self) -> Tuple[str, ForeignKeyConstraint]: @classmethod def from_constraint(cls, constraint: Constraint) -> CreateForeignKeyOp: fk_constraint = cast("ForeignKeyConstraint", constraint) - kw: dict = {} + kw: Dict[str, Any] = {} if fk_constraint.onupdate: kw["onupdate"] = fk_constraint.onupdate if fk_constraint.ondelete: @@ -897,7 +904,7 @@ def to_diff_tuple(self) -> Tuple[str, Index]: def from_index(cls, index: Index) -> CreateIndexOp: assert index.table is not None return cls( - index.name, # type: ignore[arg-type] + index.name, index.table.name, index.expressions, schema=index.table.schema, @@ -1183,7 +1190,7 @@ def from_table( return cls( table.name, - list(table.c) + list(table.constraints), # type:ignore[arg-type] + list(table.c) + list(table.constraints), schema=table.schema, _namespace_metadata=_namespace_metadata, # given a Table() object, this Table will contain full Index() @@ -1535,7 +1542,7 @@ def batch_create_table_comment( ) return operations.invoke(op) - def reverse(self): + def reverse(self) -> Union[CreateTableCommentOp, DropTableCommentOp]: """Reverses the COMMENT ON operation against a table.""" if self.existing_comment is None: return DropTableCommentOp( @@ -1551,14 +1558,16 @@ def reverse(self): schema=self.schema, ) - def to_table(self, migration_context=None): + def to_table( + self, migration_context: Optional[MigrationContext] = None + ) -> Table: schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.table( self.table_name, schema=self.schema, comment=self.comment ) - def to_diff_tuple(self): + def to_diff_tuple(self) -> Tuple[Any, ...]: return ("add_table_comment", self.to_table(), self.existing_comment) @@ -1630,18 +1639,20 @@ def batch_drop_table_comment( ) return operations.invoke(op) - def reverse(self): + def reverse(self) -> CreateTableCommentOp: """Reverses the COMMENT ON operation against a table.""" return CreateTableCommentOp( self.table_name, self.existing_comment, schema=self.schema ) - def to_table(self, migration_context=None): + def to_table( + self, migration_context: Optional[MigrationContext] = None + ) -> Table: schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.table(self.table_name, schema=self.schema) - def to_diff_tuple(self): + def to_diff_tuple(self) -> Tuple[Any, ...]: return ("remove_table_comment", self.to_table()) @@ -1818,8 +1829,10 @@ def alter_column( comment: Optional[Union[str, Literal[False]]] = False, server_default: Any = False, new_column_name: Optional[str] = None, - type_: Optional[Union[TypeEngine, Type[TypeEngine]]] = None, - existing_type: Optional[Union[TypeEngine, Type[TypeEngine]]] = None, + type_: Optional[Union[TypeEngine[Any], Type[TypeEngine[Any]]]] = None, + existing_type: Optional[ + Union[TypeEngine[Any], Type[TypeEngine[Any]]] + ] = None, existing_server_default: Optional[ Union[str, bool, Identity, Computed] ] = False, @@ -1939,8 +1952,10 @@ def batch_alter_column( comment: Optional[Union[str, Literal[False]]] = False, server_default: Any = False, new_column_name: Optional[str] = None, - type_: Optional[Union[TypeEngine, Type[TypeEngine]]] = None, - existing_type: Optional[Union[TypeEngine, Type[TypeEngine]]] = None, + type_: Optional[Union[TypeEngine[Any], Type[TypeEngine[Any]]]] = None, + existing_type: Optional[ + Union[TypeEngine[Any], Type[TypeEngine[Any]]] + ] = None, existing_server_default: Optional[ Union[str, bool, Identity, Computed] ] = False, @@ -2020,11 +2035,11 @@ def to_diff_tuple( ) -> Tuple[str, Optional[str], str, Column[Any]]: return ("add_column", self.schema, self.table_name, self.column) - def to_column(self) -> Column: + def to_column(self) -> Column[Any]: return self.column @classmethod - def from_column(cls, col: Column) -> AddColumnOp: + def from_column(cls, col: Column[Any]) -> AddColumnOp: return cls(col.table.name, col, schema=col.table.schema) @classmethod @@ -2215,7 +2230,7 @@ def from_column_and_tablename( def to_column( self, migration_context: Optional[MigrationContext] = None - ) -> Column: + ) -> Column[Any]: if self._reverse is not None: return self._reverse.column schema_obj = schemaobj.SchemaObjects(migration_context) @@ -2299,7 +2314,7 @@ class BulkInsertOp(MigrateOperation): def __init__( self, table: Union[Table, TableClause], - rows: List[dict], + rows: List[Dict[str, Any]], *, multiinsert: bool = True, ) -> None: @@ -2312,7 +2327,7 @@ def bulk_insert( cls, operations: Operations, table: Union[Table, TableClause], - rows: List[dict], + rows: List[Dict[str, Any]], *, multiinsert: bool = True, ) -> None: @@ -2608,7 +2623,7 @@ def __init__( self.upgrade_token = upgrade_token def reverse_into(self, downgrade_ops: DowngradeOps) -> DowngradeOps: - downgrade_ops.ops[:] = list( # type:ignore[index] + downgrade_ops.ops[:] = list( reversed([op.reverse() for op in self.ops]) ) return downgrade_ops @@ -2635,7 +2650,7 @@ def __init__( super().__init__(ops=ops) self.downgrade_token = downgrade_token - def reverse(self): + def reverse(self) -> UpgradeOps: return UpgradeOps( ops=list(reversed([op.reverse() for op in self.ops])) ) @@ -2666,6 +2681,8 @@ class MigrationScript(MigrateOperation): """ _needs_render: Optional[bool] + _upgrade_ops: List[UpgradeOps] + _downgrade_ops: List[DowngradeOps] def __init__( self, @@ -2693,7 +2710,7 @@ def __init__( self.downgrade_ops = downgrade_ops @property - def upgrade_ops(self): + def upgrade_ops(self) -> Optional[UpgradeOps]: """An instance of :class:`.UpgradeOps`. .. seealso:: @@ -2712,13 +2729,15 @@ def upgrade_ops(self): return self._upgrade_ops[0] @upgrade_ops.setter - def upgrade_ops(self, upgrade_ops): + def upgrade_ops( + self, upgrade_ops: Union[UpgradeOps, List[UpgradeOps]] + ) -> None: self._upgrade_ops = util.to_list(upgrade_ops) for elem in self._upgrade_ops: assert isinstance(elem, UpgradeOps) @property - def downgrade_ops(self): + def downgrade_ops(self) -> Optional[DowngradeOps]: """An instance of :class:`.DowngradeOps`. .. seealso:: @@ -2737,7 +2756,9 @@ def downgrade_ops(self): return self._downgrade_ops[0] @downgrade_ops.setter - def downgrade_ops(self, downgrade_ops): + def downgrade_ops( + self, downgrade_ops: Union[DowngradeOps, List[DowngradeOps]] + ) -> None: self._downgrade_ops = util.to_list(downgrade_ops) for elem in self._downgrade_ops: assert isinstance(elem, DowngradeOps) diff --git a/alembic/operations/schemaobj.py b/alembic/operations/schemaobj.py index 799f1139..32b26e9b 100644 --- a/alembic/operations/schemaobj.py +++ b/alembic/operations/schemaobj.py @@ -1,3 +1,6 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + from __future__ import annotations from typing import Any @@ -274,10 +277,8 @@ def _ensure_table_for_fk(self, metadata: MetaData, fk: ForeignKey) -> None: ForeignKey. """ - if isinstance(fk._colspec, str): # type:ignore[attr-defined] - table_key, cname = fk._colspec.rsplit( # type:ignore[attr-defined] - ".", 1 - ) + if isinstance(fk._colspec, str): + table_key, cname = fk._colspec.rsplit(".", 1) sname, tname = self._parse_table_key(table_key) if table_key not in metadata.tables: rel_t = sa_schema.Table(tname, metadata, schema=sname) diff --git a/alembic/operations/toimpl.py b/alembic/operations/toimpl.py index ff77ab75..4759f7fd 100644 --- a/alembic/operations/toimpl.py +++ b/alembic/operations/toimpl.py @@ -1,3 +1,6 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + from typing import TYPE_CHECKING from sqlalchemy import schema as sa_schema diff --git a/alembic/runtime/environment.py b/alembic/runtime/environment.py index 34ae1847..d64b2adc 100644 --- a/alembic/runtime/environment.py +++ b/alembic/runtime/environment.py @@ -228,9 +228,9 @@ def is_offline_mode(self) -> bool: has been configured. """ - return self.context_opts.get("as_sql", False) + return self.context_opts.get("as_sql", False) # type: ignore[no-any-return] # noqa: E501 - def is_transactional_ddl(self): + def is_transactional_ddl(self) -> bool: """Return True if the context is configured to expect a transactional DDL capable backend. @@ -339,7 +339,7 @@ def get_tag_argument(self) -> Optional[str]: line. """ - return self.context_opts.get("tag", None) + return self.context_opts.get("tag", None) # type: ignore[no-any-return] # noqa: E501 @overload def get_x_argument(self, as_dictionary: Literal[False]) -> List[str]: @@ -950,7 +950,7 @@ def run_migrations(self, **kw: Any) -> None: def execute( self, sql: Union[Executable, str], - execution_options: Optional[dict] = None, + execution_options: Optional[Dict[str, Any]] = None, ) -> None: """Execute the given SQL using the current change context. diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py index 24e3d644..10a632bb 100644 --- a/alembic/runtime/migration.py +++ b/alembic/runtime/migration.py @@ -1,3 +1,6 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + from __future__ import annotations from contextlib import contextmanager @@ -521,7 +524,7 @@ def get_current_heads(self) -> Tuple[str, ...]: start_from_rev = None elif start_from_rev is not None and self.script: start_from_rev = [ - cast("Script", self.script.get_revision(sfr)).revision + self.script.get_revision(sfr).revision for sfr in util.to_list(start_from_rev) if sfr not in (None, "base") ] @@ -652,7 +655,7 @@ def _in_connection_transaction(self) -> bool: def execute( self, sql: Union[Executable, str], - execution_options: Optional[dict] = None, + execution_options: Optional[Dict[str, Any]] = None, ) -> None: """Execute a SQL construct or string statement. @@ -1000,6 +1003,12 @@ class MigrationStep: is_upgrade: bool migration_fn: Any + if TYPE_CHECKING: + + @property + def doc(self) -> Optional[str]: + ... + @property def name(self) -> str: return self.migration_fn.__name__ @@ -1048,13 +1057,9 @@ def __init__( self.revision = revision self.is_upgrade = is_upgrade if is_upgrade: - self.migration_fn = ( - revision.module.upgrade # type:ignore[attr-defined] - ) + self.migration_fn = revision.module.upgrade else: - self.migration_fn = ( - revision.module.downgrade # type:ignore[attr-defined] - ) + self.migration_fn = revision.module.downgrade def __repr__(self): return "RevisionStep(%r, is_upgrade=%r)" % ( @@ -1070,7 +1075,7 @@ def __eq__(self, other: object) -> bool: ) @property - def doc(self) -> str: + def doc(self) -> Optional[str]: return self.revision.doc @property @@ -1283,7 +1288,7 @@ def stamp_revision(self, **kw: Any) -> None: def __eq__(self, other): return ( isinstance(other, StampStep) - and other.from_revisions == self.revisions + and other.from_revisions == self.from_revisions and other.to_revisions == self.to_revisions and other.branch_move == self.branch_move and self.is_upgrade == other.is_upgrade diff --git a/alembic/script/base.py b/alembic/script/base.py index 5766d838..5945ca59 100644 --- a/alembic/script/base.py +++ b/alembic/script/base.py @@ -41,7 +41,7 @@ from zoneinfo import ZoneInfoNotFoundError else: from backports.zoneinfo import ZoneInfo # type: ignore[import-not-found,no-redef] # noqa: E501 - from backports.zoneinfo import ZoneInfoNotFoundError # type: ignore[import-not-found,no-redef] # noqa: E501 + from backports.zoneinfo import ZoneInfoNotFoundError # type: ignore[no-redef] # noqa: E501 except ImportError: ZoneInfo = None # type: ignore[assignment, misc] @@ -119,7 +119,7 @@ def versions(self) -> str: return loc[0] @util.memoized_property - def _version_locations(self): + def _version_locations(self) -> Sequence[str]: if self.version_locations: return [ os.path.abspath(util.coerce_resource_to_filename(location)) @@ -303,24 +303,22 @@ def walk_revisions( ): yield cast(Script, rev) - def get_revisions(self, id_: _GetRevArg) -> Tuple[Optional[Script], ...]: + def get_revisions(self, id_: _GetRevArg) -> Tuple[Script, ...]: """Return the :class:`.Script` instance with the given rev identifier, symbolic name, or sequence of identifiers. """ with self._catch_revision_errors(): return cast( - Tuple[Optional[Script], ...], + Tuple[Script, ...], self.revision_map.get_revisions(id_), ) - def get_all_current(self, id_: Tuple[str, ...]) -> Set[Optional[Script]]: + def get_all_current(self, id_: Tuple[str, ...]) -> Set[Script]: with self._catch_revision_errors(): - return cast( - Set[Optional[Script]], self.revision_map._get_all_current(id_) - ) + return cast(Set[Script], self.revision_map._get_all_current(id_)) - def get_revision(self, id_: str) -> Optional[Script]: + def get_revision(self, id_: str) -> Script: """Return the :class:`.Script` instance with the given rev id. .. seealso:: @@ -330,7 +328,7 @@ def get_revision(self, id_: str) -> Optional[Script]: """ with self._catch_revision_errors(): - return cast(Optional[Script], self.revision_map.get_revision(id_)) + return cast(Script, self.revision_map.get_revision(id_)) def as_revision_number( self, id_: Optional[str] @@ -585,7 +583,7 @@ def run_env(self) -> None: util.load_python_file(self.dir, "env.py") @property - def env_py_location(self): + def env_py_location(self) -> str: return os.path.abspath(os.path.join(self.dir, "env.py")) def _generate_template(self, src: str, dest: str, **kw: Any) -> None: @@ -684,7 +682,7 @@ def generate_revision( self.revision_map.get_revisions(head), ) for h in heads: - assert h != "base" + assert h != "base" # type: ignore[comparison-overlap] if len(set(heads)) != len(heads): raise util.CommandError("Duplicate head revisions specified") @@ -823,7 +821,7 @@ def __init__(self, module: ModuleType, rev_id: str, path: str): self.path = path super().__init__( rev_id, - module.down_revision, # type: ignore[attr-defined] + module.down_revision, branch_labels=util.to_tuple( getattr(module, "branch_labels", None), default=() ), @@ -856,7 +854,7 @@ def longdoc(self) -> str: if doc: if hasattr(self.module, "_alembic_source_encoding"): doc = doc.decode( # type: ignore[attr-defined] - self.module._alembic_source_encoding # type: ignore[attr-defined] # noqa + self.module._alembic_source_encoding ) return doc.strip() # type: ignore[union-attr] else: @@ -898,7 +896,7 @@ def log_entry(self) -> str: ) return entry - def __str__(self): + def __str__(self) -> str: return "%s -> %s%s%s%s, %s" % ( self._format_down_revision(), self.revision, diff --git a/alembic/script/revision.py b/alembic/script/revision.py index 03502644..77a802cd 100644 --- a/alembic/script/revision.py +++ b/alembic/script/revision.py @@ -14,6 +14,7 @@ from typing import List from typing import Optional from typing import overload +from typing import Protocol from typing import Sequence from typing import Set from typing import Tuple @@ -47,6 +48,18 @@ _revision_illegal_chars = ["@", "-", "+"] +class _CollectRevisionsProtocol(Protocol): + def __call__( + self, + upper: _RevisionIdentifierType, + lower: _RevisionIdentifierType, + inclusive: bool, + implicit_base: bool, + assert_relative_length: bool, + ) -> Tuple[Set[Revision], Tuple[Optional[_RevisionOrBase], ...]]: + ... + + class RevisionError(Exception): pass @@ -396,7 +409,7 @@ def _normalize_depends_on( for rev in self._get_ancestor_nodes( [revision], include_dependencies=False, - map_=cast(_RevisionMapType, map_), + map_=map_, ): if rev is revision: continue @@ -791,7 +804,7 @@ def iterate_revisions( The iterator yields :class:`.Revision` objects. """ - fn: Callable + fn: _CollectRevisionsProtocol if select_for_downgrade: fn = self._collect_downgrade_revisions else: @@ -818,7 +831,7 @@ def _get_descendant_nodes( ) -> Iterator[Any]: if omit_immediate_dependencies: - def fn(rev): + def fn(rev: Revision) -> Iterable[str]: if rev not in targets: return rev._all_nextrev else: @@ -826,12 +839,12 @@ def fn(rev): elif include_dependencies: - def fn(rev): + def fn(rev: Revision) -> Iterable[str]: return rev._all_nextrev else: - def fn(rev): + def fn(rev: Revision) -> Iterable[str]: return rev.nextrev return self._iterate_related_revisions( @@ -847,12 +860,12 @@ def _get_ancestor_nodes( ) -> Iterator[Revision]: if include_dependencies: - def fn(rev): + def fn(rev: Revision) -> Iterable[str]: return rev._normalized_down_revisions else: - def fn(rev): + def fn(rev: Revision) -> Iterable[str]: return rev._versioned_down_revisions return self._iterate_related_revisions( @@ -861,7 +874,7 @@ def fn(rev): def _iterate_related_revisions( self, - fn: Callable, + fn: Callable[[Revision], Iterable[str]], targets: Collection[Optional[_RevisionOrBase]], map_: Optional[_RevisionMapType], check: bool = False, @@ -923,7 +936,7 @@ def _topological_sort( id_to_rev = self._revision_map - def get_ancestors(rev_id): + def get_ancestors(rev_id: str) -> Set[str]: return { r.revision for r in self._get_ancestor_nodes([id_to_rev[rev_id]]) @@ -1041,7 +1054,7 @@ def _walk( children: Sequence[Optional[_RevisionOrBase]] for _ in range(abs(steps)): if steps > 0: - assert initial != "base" + assert initial != "base" # type: ignore[comparison-overlap] # Walk up walk_up = [ is_revision(rev) @@ -1055,7 +1068,7 @@ def _walk( children = walk_up else: # Walk down - if initial == "base": + if initial == "base": # type: ignore[comparison-overlap] children = () else: children = self.get_revisions( @@ -1189,7 +1202,7 @@ def _parse_downgrade_target( # No relative destination given, revision specified is absolute. branch_label, _, symbol = target.rpartition("@") if not branch_label: - branch_label = None # type:ignore[assignment] + branch_label = None return branch_label, self.get_revision(symbol) def _parse_upgrade_target( @@ -1301,11 +1314,11 @@ def _parse_upgrade_target( def _collect_downgrade_revisions( self, upper: _RevisionIdentifierType, - target: _RevisionIdentifierType, + lower: _RevisionIdentifierType, inclusive: bool, implicit_base: bool, assert_relative_length: bool, - ) -> Any: + ) -> Tuple[Set[Revision], Tuple[Optional[_RevisionOrBase], ...]]: """ Compute the set of current revisions specified by :upper, and the downgrade target specified by :target. Return all dependents of target @@ -1316,7 +1329,7 @@ def _collect_downgrade_revisions( branch_label, target_revision = self._parse_downgrade_target( current_revisions=upper, - target=target, + target=lower, assert_relative_length=assert_relative_length, ) if target_revision == "base": @@ -1408,7 +1421,7 @@ def _collect_upgrade_revisions( inclusive: bool, implicit_base: bool, assert_relative_length: bool, - ) -> Tuple[Set[Revision], Tuple[Optional[_RevisionOrBase]]]: + ) -> Tuple[Set[Revision], Tuple[Revision, ...]]: """ Compute the set of required revisions specified by :upper, and the current set of active revisions specified by :lower. Find the @@ -1500,7 +1513,7 @@ def _collect_upgrade_revisions( ) needs.intersection_update(lower_descendents) - return needs, tuple(targets) # type:ignore[return-value] + return needs, tuple(targets) def _get_all_current( self, id_: Tuple[str, ...] diff --git a/alembic/script/write_hooks.py b/alembic/script/write_hooks.py index b44ce644..99771479 100644 --- a/alembic/script/write_hooks.py +++ b/alembic/script/write_hooks.py @@ -1,3 +1,6 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + from __future__ import annotations import shlex diff --git a/alembic/util/__init__.py b/alembic/util/__init__.py index 3c1e27ca..4724e1f0 100644 --- a/alembic/util/__init__.py +++ b/alembic/util/__init__.py @@ -1,34 +1,34 @@ -from .editor import open_in_editor -from .exc import AutogenerateDiffsDetected -from .exc import CommandError -from .langhelpers import _with_legacy_names -from .langhelpers import asbool -from .langhelpers import dedupe_tuple -from .langhelpers import Dispatcher -from .langhelpers import EMPTY_DICT -from .langhelpers import immutabledict -from .langhelpers import memoized_property -from .langhelpers import ModuleClsProxy -from .langhelpers import not_none -from .langhelpers import rev_id -from .langhelpers import to_list -from .langhelpers import to_tuple -from .langhelpers import unique_list -from .messaging import err -from .messaging import format_as_comma -from .messaging import msg -from .messaging import obfuscate_url_pw -from .messaging import status -from .messaging import warn -from .messaging import write_outstream -from .pyfiles import coerce_resource_to_filename -from .pyfiles import load_python_file -from .pyfiles import pyc_file_from_path -from .pyfiles import template_to_file -from .sqla_compat import has_computed -from .sqla_compat import sqla_13 -from .sqla_compat import sqla_14 -from .sqla_compat import sqla_2 +from .editor import open_in_editor as open_in_editor +from .exc import AutogenerateDiffsDetected as AutogenerateDiffsDetected +from .exc import CommandError as CommandError +from .langhelpers import _with_legacy_names as _with_legacy_names +from .langhelpers import asbool as asbool +from .langhelpers import dedupe_tuple as dedupe_tuple +from .langhelpers import Dispatcher as Dispatcher +from .langhelpers import EMPTY_DICT as EMPTY_DICT +from .langhelpers import immutabledict as immutabledict +from .langhelpers import memoized_property as memoized_property +from .langhelpers import ModuleClsProxy as ModuleClsProxy +from .langhelpers import not_none as not_none +from .langhelpers import rev_id as rev_id +from .langhelpers import to_list as to_list +from .langhelpers import to_tuple as to_tuple +from .langhelpers import unique_list as unique_list +from .messaging import err as err +from .messaging import format_as_comma as format_as_comma +from .messaging import msg as msg +from .messaging import obfuscate_url_pw as obfuscate_url_pw +from .messaging import status as status +from .messaging import warn as warn +from .messaging import write_outstream as write_outstream +from .pyfiles import coerce_resource_to_filename as coerce_resource_to_filename +from .pyfiles import load_python_file as load_python_file +from .pyfiles import pyc_file_from_path as pyc_file_from_path +from .pyfiles import template_to_file as template_to_file +from .sqla_compat import has_computed as has_computed +from .sqla_compat import sqla_13 as sqla_13 +from .sqla_compat import sqla_14 as sqla_14 +from .sqla_compat import sqla_2 as sqla_2 if not sqla_13: diff --git a/alembic/util/compat.py b/alembic/util/compat.py index 5b8f3d95..e185cc41 100644 --- a/alembic/util/compat.py +++ b/alembic/util/compat.py @@ -1,3 +1,5 @@ +# mypy: no-warn-unused-ignores + from __future__ import annotations from configparser import ConfigParser @@ -5,11 +7,20 @@ import os import sys import typing +from typing import Any +from typing import List +from typing import Optional from typing import Sequence from typing import Union -from sqlalchemy.util import inspect_getfullargspec # noqa -from sqlalchemy.util.compat import inspect_formatargspec # noqa +if True: + # zimports hack for too-long names + from sqlalchemy.util import ( # noqa: F401 + inspect_getfullargspec as inspect_getfullargspec, + ) + from sqlalchemy.util.compat import ( # noqa: F401 + inspect_formatargspec as inspect_formatargspec, + ) is_posix = os.name == "posix" @@ -27,9 +38,13 @@ def close(self) -> None: if py39: - from importlib import resources as importlib_resources - from importlib import metadata as importlib_metadata - from importlib.metadata import EntryPoint + from importlib import resources as _resources + + importlib_resources = _resources + from importlib import metadata as _metadata + + importlib_metadata = _metadata + from importlib.metadata import EntryPoint as EntryPoint else: import importlib_resources # type:ignore # noqa import importlib_metadata # type:ignore # noqa @@ -39,12 +54,14 @@ def close(self) -> None: def importlib_metadata_get(group: str) -> Sequence[EntryPoint]: ep = importlib_metadata.entry_points() if hasattr(ep, "select"): - return ep.select(group=group) # type: ignore + return ep.select(group=group) else: return ep.get(group, ()) # type: ignore -def formatannotation_fwdref(annotation, base_module=None): +def formatannotation_fwdref( + annotation: Any, base_module: Optional[Any] = None +) -> str: """vendored from python 3.7""" # copied over _formatannotation from sqlalchemy 2.0 @@ -65,7 +82,7 @@ def formatannotation_fwdref(annotation, base_module=None): def read_config_parser( file_config: ConfigParser, file_argument: Sequence[Union[str, os.PathLike[str]]], -) -> list[str]: +) -> List[str]: if py310: return file_config.read(file_argument, encoding="locale") else: diff --git a/alembic/util/langhelpers.py b/alembic/util/langhelpers.py index 34d48bc6..4a5bf09a 100644 --- a/alembic/util/langhelpers.py +++ b/alembic/util/langhelpers.py @@ -5,33 +5,46 @@ import textwrap from typing import Any from typing import Callable +from typing import cast from typing import Dict from typing import List from typing import Mapping +from typing import MutableMapping +from typing import NoReturn from typing import Optional from typing import overload from typing import Sequence +from typing import Set from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from typing import Union import uuid import warnings -from sqlalchemy.util import asbool # noqa -from sqlalchemy.util import immutabledict # noqa -from sqlalchemy.util import memoized_property # noqa -from sqlalchemy.util import to_list # noqa -from sqlalchemy.util import unique_list # noqa +from sqlalchemy.util import asbool as asbool # noqa: F401 +from sqlalchemy.util import immutabledict as immutabledict # noqa: F401 +from sqlalchemy.util import to_list as to_list # noqa: F401 +from sqlalchemy.util import unique_list as unique_list from .compat import inspect_getfullargspec +if True: + # zimports workaround :( + from sqlalchemy.util import ( # noqa: F401 + memoized_property as memoized_property, + ) + EMPTY_DICT: Mapping[Any, Any] = immutabledict() -_T = TypeVar("_T") +_T = TypeVar("_T", bound=Any) + +_C = TypeVar("_C", bound=Callable[..., Any]) class _ModuleClsMeta(type): - def __setattr__(cls, key: str, value: Callable) -> None: + def __setattr__(cls, key: str, value: Callable[..., Any]) -> None: super().__setattr__(key, value) cls._update_module_proxies(key) # type: ignore @@ -45,9 +58,13 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta): """ - _setups: Dict[type, Tuple[set, list]] = collections.defaultdict( - lambda: (set(), []) - ) + _setups: Dict[ + Type[Any], + Tuple[ + Set[str], + List[Tuple[MutableMapping[str, Any], MutableMapping[str, Any]]], + ], + ] = collections.defaultdict(lambda: (set(), [])) @classmethod def _update_module_proxies(cls, name: str) -> None: @@ -70,18 +87,33 @@ def _remove_proxy(self) -> None: del globals_[attr_name] @classmethod - def create_module_class_proxy(cls, globals_, locals_): + def create_module_class_proxy( + cls, + globals_: MutableMapping[str, Any], + locals_: MutableMapping[str, Any], + ) -> None: attr_names, modules = cls._setups[cls] modules.append((globals_, locals_)) cls._setup_proxy(globals_, locals_, attr_names) @classmethod - def _setup_proxy(cls, globals_, locals_, attr_names): + def _setup_proxy( + cls, + globals_: MutableMapping[str, Any], + locals_: MutableMapping[str, Any], + attr_names: Set[str], + ) -> None: for methname in dir(cls): cls._add_proxied_attribute(methname, globals_, locals_, attr_names) @classmethod - def _add_proxied_attribute(cls, methname, globals_, locals_, attr_names): + def _add_proxied_attribute( + cls, + methname: str, + globals_: MutableMapping[str, Any], + locals_: MutableMapping[str, Any], + attr_names: Set[str], + ) -> None: if not methname.startswith("_"): meth = getattr(cls, methname) if callable(meth): @@ -92,10 +124,15 @@ def _add_proxied_attribute(cls, methname, globals_, locals_, attr_names): attr_names.add(methname) @classmethod - def _create_method_proxy(cls, name, globals_, locals_): + def _create_method_proxy( + cls, + name: str, + globals_: MutableMapping[str, Any], + locals_: MutableMapping[str, Any], + ) -> Callable[..., Any]: fn = getattr(cls, name) - def _name_error(name, from_): + def _name_error(name: str, from_: Exception) -> NoReturn: raise NameError( "Can't invoke function '%s', as the proxy object has " "not yet been " @@ -119,7 +156,9 @@ def _name_error(name, from_): translations, ) - def translate(fn_name, spec, translations, args, kw): + def translate( + fn_name: str, spec: Any, translations: Any, args: Any, kw: Any + ) -> Any: return_kw = {} return_args = [] @@ -176,15 +215,15 @@ def %(name)s(%(args)s): "doc": fn.__doc__, } ) - lcl = {} + lcl: MutableMapping[str, Any] = {} - exec(func_text, globals_, lcl) - return lcl[name] + exec(func_text, cast("Dict[str, Any]", globals_), lcl) + return cast("Callable[..., Any]", lcl[name]) -def _with_legacy_names(translations): - def decorate(fn): - fn._legacy_translations = translations +def _with_legacy_names(translations: Any) -> Any: + def decorate(fn: _C) -> _C: + fn._legacy_translations = translations # type: ignore[attr-defined] return fn return decorate @@ -195,21 +234,25 @@ def rev_id() -> str: @overload -def to_tuple(x: Any, default: tuple) -> tuple: +def to_tuple(x: Any, default: Tuple[Any, ...]) -> Tuple[Any, ...]: ... @overload -def to_tuple(x: None, default: Optional[_T] = None) -> _T: +def to_tuple(x: None, default: Optional[_T] = ...) -> _T: ... @overload -def to_tuple(x: Any, default: Optional[tuple] = None) -> tuple: +def to_tuple( + x: Any, default: Optional[Tuple[Any, ...]] = None +) -> Tuple[Any, ...]: ... -def to_tuple(x, default=None): +def to_tuple( + x: Any, default: Optional[Tuple[Any, ...]] = None +) -> Optional[Tuple[Any, ...]]: if x is None: return default elif isinstance(x, str): @@ -226,13 +269,13 @@ def dedupe_tuple(tup: Tuple[str, ...]) -> Tuple[str, ...]: class Dispatcher: def __init__(self, uselist: bool = False) -> None: - self._registry: Dict[tuple, Any] = {} + self._registry: Dict[Tuple[Any, ...], Any] = {} self.uselist = uselist def dispatch_for( self, target: Any, qualifier: str = "default" - ) -> Callable: - def decorate(fn): + ) -> Callable[[_C], _C]: + def decorate(fn: _C) -> _C: if self.uselist: self._registry.setdefault((target, qualifier), []).append(fn) else: @@ -244,7 +287,7 @@ def decorate(fn): def dispatch(self, obj: Any, qualifier: str = "default") -> Any: if isinstance(obj, str): - targets: Sequence = [obj] + targets: Sequence[Any] = [obj] elif isinstance(obj, type): targets = obj.__mro__ else: @@ -259,11 +302,13 @@ def dispatch(self, obj: Any, qualifier: str = "default") -> Any: raise ValueError("no dispatch function for object: %s" % obj) def _fn_or_list( - self, fn_or_list: Union[List[Callable], Callable] - ) -> Callable: + self, fn_or_list: Union[List[Callable[..., Any]], Callable[..., Any]] + ) -> Callable[..., Any]: if self.uselist: - def go(*arg, **kw): + def go(*arg: Any, **kw: Any) -> None: + if TYPE_CHECKING: + assert isinstance(fn_or_list, Sequence) for fn in fn_or_list: fn(*arg, **kw) diff --git a/alembic/util/messaging.py b/alembic/util/messaging.py index 35592c0e..5f14d597 100644 --- a/alembic/util/messaging.py +++ b/alembic/util/messaging.py @@ -5,6 +5,7 @@ import logging import sys import textwrap +from typing import Iterator from typing import Optional from typing import TextIO from typing import Union @@ -53,7 +54,9 @@ def write_outstream( @contextmanager -def status(status_msg: str, newline: bool = False, quiet: bool = False): +def status( + status_msg: str, newline: bool = False, quiet: bool = False +) -> Iterator[None]: msg(status_msg + " ...", newline, flush=True, quiet=quiet) try: yield @@ -66,7 +69,7 @@ def status(status_msg: str, newline: bool = False, quiet: bool = False): write_outstream(sys.stdout, " done\n") -def err(message: str, quiet: bool = False): +def err(message: str, quiet: bool = False) -> None: log.error(message) msg(f"FAILED: {message}", quiet=quiet) sys.exit(-1) @@ -74,7 +77,7 @@ def err(message: str, quiet: bool = False): def obfuscate_url_pw(input_url: str) -> str: u = url.make_url(input_url) - return sqla_compat.url_render_as_string(u, hide_password=True) + return sqla_compat.url_render_as_string(u, hide_password=True) # type: ignore # noqa: E501 def warn(msg: str, stacklevel: int = 2) -> None: diff --git a/alembic/util/pyfiles.py b/alembic/util/pyfiles.py index e7576731..973bd458 100644 --- a/alembic/util/pyfiles.py +++ b/alembic/util/pyfiles.py @@ -8,6 +8,8 @@ import os import re import tempfile +from types import ModuleType +from typing import Any from typing import Optional from mako import exceptions @@ -18,7 +20,7 @@ def template_to_file( - template_file: str, dest: str, output_encoding: str, **kw + template_file: str, dest: str, output_encoding: str, **kw: Any ) -> None: template = Template(filename=template_file) try: @@ -82,7 +84,7 @@ def pyc_file_from_path(path: str) -> Optional[str]: return None -def load_python_file(dir_: str, filename: str): +def load_python_file(dir_: str, filename: str) -> ModuleType: """Load a file from the given path as a Python module.""" module_id = re.sub(r"\W", "_", filename) @@ -99,10 +101,12 @@ def load_python_file(dir_: str, filename: str): module = load_module_py(module_id, pyc_path) elif ext in (".pyc", ".pyo"): module = load_module_py(module_id, path) + else: + assert False return module -def load_module_py(module_id: str, path: str): +def load_module_py(module_id: str, path: str) -> ModuleType: spec = importlib.util.spec_from_file_location(module_id, path) assert spec module = importlib.util.module_from_spec(spec) diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py index 9332a062..8489c19f 100644 --- a/alembic/util/sqla_compat.py +++ b/alembic/util/sqla_compat.py @@ -1,13 +1,20 @@ +# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls +# mypy: no-warn-return-any, allow-any-generics + from __future__ import annotations import contextlib import re from typing import Any +from typing import Callable from typing import Dict from typing import Iterable from typing import Iterator from typing import Mapping from typing import Optional +from typing import Protocol +from typing import Set +from typing import Type from typing import TYPE_CHECKING from typing import TypeVar from typing import Union @@ -18,7 +25,6 @@ from sqlalchemy import sql from sqlalchemy import types as sqltypes from sqlalchemy.engine import url -from sqlalchemy.ext.compiler import compiles from sqlalchemy.schema import CheckConstraint from sqlalchemy.schema import Column from sqlalchemy.schema import ForeignKeyConstraint @@ -33,6 +39,7 @@ from typing_extensions import TypeGuard if TYPE_CHECKING: + from sqlalchemy import ClauseElement from sqlalchemy import Index from sqlalchemy import Table from sqlalchemy.engine import Connection @@ -51,6 +58,11 @@ _CE = TypeVar("_CE", bound=Union["ColumnElement[Any]", "SchemaItem"]) +class _CompilerProtocol(Protocol): + def __call__(self, element: Any, compiler: Any, **kw: Any) -> str: + ... + + def _safe_int(value: str) -> Union[int, str]: try: return int(value) @@ -70,7 +82,7 @@ def _safe_int(value: str) -> Union[int, str]: sqlalchemy_version = __version__ try: - from sqlalchemy.sql.naming import _NONE_NAME as _NONE_NAME + from sqlalchemy.sql.naming import _NONE_NAME as _NONE_NAME # type: ignore[attr-defined] # noqa: E501 except ImportError: from sqlalchemy.sql.elements import _NONE_NAME as _NONE_NAME # type: ignore # noqa: E501 @@ -79,8 +91,18 @@ class _Unsupported: "Placeholder for unsupported SQLAlchemy classes" +if TYPE_CHECKING: + + def compiles( + element: Type[ClauseElement], *dialects: str + ) -> Callable[[_CompilerProtocol], _CompilerProtocol]: + ... + +else: + from sqlalchemy.ext.compiler import compiles + try: - from sqlalchemy import Computed + from sqlalchemy import Computed as Computed except ImportError: if not TYPE_CHECKING: @@ -94,7 +116,7 @@ class Computed(_Unsupported): has_computed_reflection = _vers >= (1, 3, 16) try: - from sqlalchemy import Identity + from sqlalchemy import Identity as Identity except ImportError: if not TYPE_CHECKING: @@ -250,7 +272,7 @@ def _idx_table_bound_expressions(idx: Index) -> Iterable[ColumnElement[Any]]: def _copy(schema_item: _CE, **kw) -> _CE: if hasattr(schema_item, "_copy"): - return schema_item._copy(**kw) # type: ignore[union-attr] + return schema_item._copy(**kw) else: return schema_item.copy(**kw) # type: ignore[union-attr] @@ -368,7 +390,12 @@ def _get_variant_mapping(type_): return type_.impl, type_.mapping -def _fk_spec(constraint): +def _fk_spec(constraint: ForeignKeyConstraint) -> Any: + if TYPE_CHECKING: + assert constraint.columns is not None + assert constraint.elements is not None + assert isinstance(constraint.parent, Table) + source_columns = [ constraint.columns[key].name for key in constraint.column_keys ] @@ -397,7 +424,7 @@ def _fk_spec(constraint): def _fk_is_self_referential(constraint: ForeignKeyConstraint) -> bool: - spec = constraint.elements[0]._get_colspec() # type: ignore[attr-defined] + spec = constraint.elements[0]._get_colspec() tokens = spec.split(".") tokens.pop(-1) # colname tablekey = ".".join(tokens) @@ -409,13 +436,13 @@ def _is_type_bound(constraint: Constraint) -> bool: # this deals with SQLAlchemy #3260, don't copy CHECK constraints # that will be generated by the type. # new feature added for #3260 - return constraint._type_bound # type: ignore[attr-defined] + return constraint._type_bound def _find_columns(clause): """locate Column objects within the given expression.""" - cols = set() + cols: Set[ColumnElement[Any]] = set() traverse(clause, {}, {"column": cols.add}) return cols @@ -562,9 +589,7 @@ def _get_constraint_final_name( if isinstance(constraint, schema.Index): # name should not be quoted. d = dialect.ddl_compiler(dialect, None) # type: ignore[arg-type] - return d._prepared_index_name( # type: ignore[attr-defined] - constraint - ) + return d._prepared_index_name(constraint) else: # name should not be quoted. return dialect.identifier_preparer.format_constraint(constraint) @@ -608,7 +633,11 @@ def _insert_inline(table: Union[TableClause, Table]) -> Insert: if sqla_14: from sqlalchemy import create_mock_engine - from sqlalchemy import select as _select + + # weird mypy workaround + from sqlalchemy import select as _sa_select + + _select = _sa_select else: from sqlalchemy import create_engine @@ -617,7 +646,7 @@ def create_mock_engine(url, executor, **kw): # type: ignore[misc] "postgresql://", strategy="mock", executor=executor ) - def _select(*columns, **kw) -> Select: # type: ignore[no-redef] + def _select(*columns, **kw) -> Select: return sql.select(list(columns), **kw) # type: ignore[call-overload] diff --git a/docs/build/unreleased/1377.rst b/docs/build/unreleased/1377.rst new file mode 100644 index 00000000..a8bb6c14 --- /dev/null +++ b/docs/build/unreleased/1377.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, typing + :tickets: 1377 + + Updated pep-484 typing to pass mypy "strict" mode, however including + per-module qualifications for specific typing elements not yet complete. + This allows us to catch specific typing issues that have been ongoing + such as import symbols not properly exported. + diff --git a/pyproject.toml b/pyproject.toml index f66269af..b9b1f44a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,15 +16,15 @@ exclude = [ show_error_codes = true [[tool.mypy.overrides]] + module = [ - 'alembic.operations.ops', - 'alembic.op', - 'alembic.context', - 'alembic.autogenerate.api', - 'alembic.runtime.*', + "alembic.*" ] -disallow_incomplete_defs = true +warn_unused_ignores = true +strict = true + + [[tool.mypy.overrides]] module = [ diff --git a/setup.cfg b/setup.cfg index 5c330383..fa957eca 100644 --- a/setup.cfg +++ b/setup.cfg @@ -125,18 +125,3 @@ python_files=tests/test_*.py markers = backend: tests that should run on all backends; typically dialect-sensitive -[mypy] -show_error_codes = True -allow_redefinition = True - -[mypy-mako.*] -ignore_missing_imports = True - -[mypy-sqlalchemy.testing.*] -ignore_missing_imports = True - -[mypy-importlib_resources.*] -ignore_missing_imports = True - -[mypy-importlib_metadata.*] -ignore_missing_imports = True diff --git a/tools/write_pyi.py b/tools/write_pyi.py index 5abb26ef..363d727e 100644 --- a/tools/write_pyi.py +++ b/tools/write_pyi.py @@ -127,9 +127,7 @@ def generate_pyi_for_proxy( {"entrypoint": "zimports", "options": "-e"}, ignore_output=ignore_output, ) - # note that we do not distribute pyproject.toml with the distribution - # right now due to user complaints, so we can't refer to it here because - # this all has to run as part of the test suite + console_scripts( str(destination_path), {"entrypoint": "black", "options": "-l79"}, @@ -190,6 +188,8 @@ def _formatannotation(annotation, base_module=None): else: retval = annotation + retval = re.sub(r"TypeEngine\b", "TypeEngine[Any]", retval) + retval = retval.replace("~", "") # typevar repr as "~T" for trim in TRIM_MODULE: retval = retval.replace(trim, "")