Skip to content

Commit

Permalink
Merge "finish strict typing for most modules" into main
Browse files Browse the repository at this point in the history
  • Loading branch information
zzzeek authored and Gerrit Code Review committed Dec 19, 2023
2 parents 1b0e4bc + f443584 commit 4095eba
Show file tree
Hide file tree
Showing 38 changed files with 581 additions and 307 deletions.
20 changes: 10 additions & 10 deletions alembic/autogenerate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from .api import _render_migration_diffs
from .api import compare_metadata
from .api import produce_migrations
from .api import render_python_code
from .api import RevisionContext
from .compare import _produce_net_changes
from .compare import comparators
from .render import render_op_text
from .render import renderers
from .rewriter import Rewriter
from .api import _render_migration_diffs as _render_migration_diffs
from .api import compare_metadata as compare_metadata
from .api import produce_migrations as produce_migrations
from .api import render_python_code as render_python_code
from .api import RevisionContext as RevisionContext
from .compare import _produce_net_changes as _produce_net_changes
from .compare import comparators as comparators
from .render import render_op_text as render_op_text
from .render import renderers as renderers
from .rewriter import Rewriter as Rewriter
10 changes: 6 additions & 4 deletions alembic/autogenerate/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from sqlalchemy.engine import Inspector
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.schema import SchemaItem
from sqlalchemy.sql.schema import Table

from ..config import Config
from ..operations.ops import DowngradeOps
Expand Down Expand Up @@ -165,6 +166,7 @@ def compare_metadata(context: MigrationContext, metadata: MetaData) -> Any:
"""

migration_script = produce_migrations(context, metadata)
assert migration_script.upgrade_ops is not None
return migration_script.upgrade_ops.as_diffs()


Expand Down Expand Up @@ -331,7 +333,7 @@ def __init__(
self,
migration_context: MigrationContext,
metadata: Optional[MetaData] = None,
opts: Optional[dict] = None,
opts: Optional[Dict[str, Any]] = None,
autogenerate: bool = True,
) -> None:
if (
Expand Down Expand Up @@ -465,7 +467,7 @@ def run_object_filters(
run_filters = run_object_filters

@util.memoized_property
def sorted_tables(self):
def sorted_tables(self) -> List[Table]:
"""Return an aggregate of the :attr:`.MetaData.sorted_tables`
collection(s).
Expand All @@ -481,7 +483,7 @@ def sorted_tables(self):
return result

@util.memoized_property
def table_key_to_table(self):
def table_key_to_table(self) -> Dict[str, Table]:
"""Return an aggregate of the :attr:`.MetaData.tables` dictionaries.
The :attr:`.MetaData.tables` collection is a dictionary of table key
Expand All @@ -492,7 +494,7 @@ def table_key_to_table(self):
objects contain the same table key, an exception is raised.
"""
result = {}
result: Dict[str, Table] = {}
for m in util.to_list(self.metadata):
intersect = set(result).intersection(set(m.tables))
if intersect:
Expand Down
9 changes: 5 additions & 4 deletions alembic/autogenerate/compare.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics

from __future__ import annotations

import contextlib
Expand Down Expand Up @@ -577,9 +580,7 @@ def _compare_indexes_and_uniques(
# 5. index things by name, for those objects that have names
metadata_names = {
cast(str, c.md_name_to_sql_name(autogen_context)): c
for c in metadata_unique_constraints_sig.union(
metadata_indexes_sig # type:ignore[arg-type]
)
for c in metadata_unique_constraints_sig.union(metadata_indexes_sig)
if c.is_named
}

Expand Down Expand Up @@ -1240,7 +1241,7 @@ def _add_fk(obj, compare_to):
obj.const, obj.name, "foreign_key_constraint", False, compare_to
):
modify_table_ops.ops.append(
ops.CreateForeignKeyOp.from_constraint(const.const)
ops.CreateForeignKeyOp.from_constraint(const.const) # type: ignore[has-type] # noqa: E501
)

log.info(
Expand Down
18 changes: 9 additions & 9 deletions alembic/autogenerate/render.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics

from __future__ import annotations

from io import StringIO
Expand Down Expand Up @@ -849,7 +852,7 @@ def _render_Variant_type(
) -> str:
base_type, variant_mapping = sqla_compat._get_variant_mapping(type_)
base = _repr_type(base_type, autogen_context, _skip_variants=True)
assert base is not None and base is not False
assert base is not None and base is not False # type: ignore[comparison-overlap] # noqa:E501
for dialect in sorted(variant_mapping):
typ = variant_mapping[dialect]
base += ".with_variant(%s, %r)" % (
Expand Down Expand Up @@ -946,7 +949,7 @@ def _fk_colspec(
won't fail if the remote table can't be resolved.
"""
colspec = fk._get_colspec() # type:ignore[attr-defined]
colspec = fk._get_colspec()
tokens = colspec.split(".")
tname, colname = tokens[-2:]

Expand Down Expand Up @@ -1016,8 +1019,7 @@ def _render_foreign_key(
% {
"prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
"cols": ", ".join(
"%r" % _ident(cast("Column", f.parent).name)
for f in constraint.elements
repr(_ident(f.parent.name)) for f in constraint.elements
),
"refcols": ", ".join(
repr(_fk_colspec(f, apply_metadata_schema, namespace_metadata))
Expand Down Expand Up @@ -1058,12 +1060,10 @@ def _render_check_constraint(
# ideally SQLAlchemy would give us more of a first class
# way to detect this.
if (
constraint._create_rule # type:ignore[attr-defined]
and hasattr(
constraint._create_rule, "target" # type:ignore[attr-defined]
)
constraint._create_rule
and hasattr(constraint._create_rule, "target")
and isinstance(
constraint._create_rule.target, # type:ignore[attr-defined]
constraint._create_rule.target,
sqltypes.TypeEngine,
)
):
Expand Down
15 changes: 9 additions & 6 deletions alembic/autogenerate/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
from ..operations.ops import AddColumnOp
from ..operations.ops import AlterColumnOp
from ..operations.ops import CreateTableOp
from ..operations.ops import DowngradeOps
from ..operations.ops import MigrateOperation
from ..operations.ops import MigrationScript
from ..operations.ops import ModifyTableOps
from ..operations.ops import OpContainer
from ..runtime.environment import _GetRevArg
from ..operations.ops import UpgradeOps
from ..runtime.migration import MigrationContext
from ..script.revision import _GetRevArg


class Rewriter:
Expand Down Expand Up @@ -101,7 +103,7 @@ def rewrites(
Type[CreateTableOp],
Type[ModifyTableOps],
],
) -> Callable:
) -> Callable[..., Any]:
"""Register a function as rewriter for a given type.
The function should receive three arguments, which are
Expand Down Expand Up @@ -156,25 +158,26 @@ def _traverse_script(
revision: _GetRevArg,
directive: MigrationScript,
) -> None:
upgrade_ops_list = []
upgrade_ops_list: List[UpgradeOps] = []
for upgrade_ops in directive.upgrade_ops_list:
ret = self._traverse_for(context, revision, upgrade_ops)
if len(ret) != 1:
raise ValueError(
"Can only return single object for UpgradeOps traverse"
)
upgrade_ops_list.append(ret[0])
directive.upgrade_ops = upgrade_ops_list

downgrade_ops_list = []
directive.upgrade_ops = upgrade_ops_list # type: ignore

downgrade_ops_list: List[DowngradeOps] = []
for downgrade_ops in directive.downgrade_ops_list:
ret = self._traverse_for(context, revision, downgrade_ops)
if len(ret) != 1:
raise ValueError(
"Can only return single object for DowngradeOps traverse"
)
downgrade_ops_list.append(ret[0])
directive.downgrade_ops = downgrade_ops_list
directive.downgrade_ops = downgrade_ops_list # type: ignore

@_traverse.dispatch_for(ops.OpContainer)
def _traverse_op_container(
Expand Down
4 changes: 3 additions & 1 deletion alembic/command.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# mypy: allow-untyped-defs, allow-untyped-calls

from __future__ import annotations

import os
Expand All @@ -18,7 +20,7 @@
from .runtime.environment import ProcessRevisionDirectiveFn


def list_templates(config: Config):
def list_templates(config: Config) -> None:
"""List available templates.
:param config: a :class:`.Config` object.
Expand Down
29 changes: 20 additions & 9 deletions alembic/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Mapping
from typing import Optional
from typing import overload
from typing import Sequence
from typing import TextIO
from typing import Union

Expand Down Expand Up @@ -104,7 +105,7 @@ def __init__(
stdout: TextIO = sys.stdout,
cmd_opts: Optional[Namespace] = None,
config_args: Mapping[str, Any] = util.immutabledict(),
attributes: Optional[dict] = None,
attributes: Optional[Dict[str, Any]] = None,
) -> None:
"""Construct a new :class:`.Config`"""
self.config_file_name = file_
Expand Down Expand Up @@ -140,7 +141,7 @@ def __init__(
"""

@util.memoized_property
def attributes(self):
def attributes(self) -> Dict[str, Any]:
"""A Python dictionary for storage of additional state.
Expand All @@ -159,7 +160,7 @@ def attributes(self):
"""
return {}

def print_stdout(self, text: str, *arg) -> None:
def print_stdout(self, text: str, *arg: Any) -> None:
"""Render a message to standard out.
When :meth:`.Config.print_stdout` is called with additional args
Expand All @@ -183,7 +184,7 @@ def print_stdout(self, text: str, *arg) -> None:
util.write_outstream(self.stdout, output, "\n", **self.messaging_opts)

@util.memoized_property
def file_config(self):
def file_config(self) -> ConfigParser:
"""Return the underlying ``ConfigParser`` object.
Direct access to the .ini file is available here,
Expand Down Expand Up @@ -321,7 +322,9 @@ def get_main_option(
) -> Optional[str]:
...

def get_main_option(self, name, default=None):
def get_main_option(
self, name: str, default: Optional[str] = None
) -> Optional[str]:
"""Return an option from the 'main' section of the .ini file.
This defaults to being a key from the ``[alembic]``
Expand Down Expand Up @@ -351,7 +354,9 @@ def __init__(self, prog: Optional[str] = None) -> None:
self._generate_args(prog)

def _generate_args(self, prog: Optional[str]) -> None:
def add_options(fn, parser, positional, kwargs):
def add_options(
fn: Any, parser: Any, positional: Any, kwargs: Any
) -> None:
kwargs_opts = {
"template": (
"-t",
Expand Down Expand Up @@ -554,7 +559,9 @@ def add_options(fn, parser, positional, kwargs):
)
subparsers = parser.add_subparsers()

positional_translations = {command.stamp: {"revision": "revisions"}}
positional_translations: Dict[Any, Any] = {
command.stamp: {"revision": "revisions"}
}

for fn in [getattr(command, n) for n in dir(command)]:
if (
Expand Down Expand Up @@ -609,7 +616,7 @@ def run_cmd(self, config: Config, options: Namespace) -> None:
else:
util.err(str(e), **config.messaging_opts)

def main(self, argv=None):
def main(self, argv: Optional[Sequence[str]] = None) -> None:
options = self.parser.parse_args(argv)
if not hasattr(options, "cmd"):
# see http://bugs.python.org/issue9253, argparse
Expand All @@ -624,7 +631,11 @@ def main(self, argv=None):
self.run_cmd(cfg, options)


def main(argv=None, prog=None, **kwargs):
def main(
argv: Optional[Sequence[str]] = None,
prog: Optional[str] = None,
**kwargs: Any,
) -> None:
"""The console runner function for Alembic."""

CommandLine(prog=prog).main(argv=argv)
Expand Down
9 changes: 5 additions & 4 deletions alembic/context.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ def configure(
MigrationContext,
Column[Any],
Column[Any],
TypeEngine,
TypeEngine,
TypeEngine[Any],
TypeEngine[Any],
],
Optional[bool],
],
Expand Down Expand Up @@ -636,7 +636,8 @@ def configure(
"""

def execute(
sql: Union[Executable, str], execution_options: Optional[dict] = None
sql: Union[Executable, str],
execution_options: Optional[Dict[str, Any]] = None,
) -> None:
"""Execute the given SQL using the current change context.
Expand Down Expand Up @@ -805,7 +806,7 @@ def is_offline_mode() -> bool:
"""

def is_transactional_ddl():
def is_transactional_ddl() -> bool:
"""Return True if the context is configured to expect a
transactional DDL capable backend.
Expand Down
2 changes: 1 addition & 1 deletion alembic/ddl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from . import oracle
from . import postgresql
from . import sqlite
from .impl import DefaultImpl
from .impl import DefaultImpl as DefaultImpl
6 changes: 4 additions & 2 deletions alembic/ddl/_autogen.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics

from __future__ import annotations

from typing import Any
Expand All @@ -19,7 +22,6 @@
from sqlalchemy.sql.schema import UniqueConstraint
from typing_extensions import TypeGuard

from alembic.ddl.base import _fk_spec
from .. import util
from ..util import sqla_compat

Expand Down Expand Up @@ -275,7 +277,7 @@ def __init__(
ondelete,
deferrable,
initially,
) = _fk_spec(const)
) = sqla_compat._fk_spec(const)

self._sig: Tuple[Any, ...] = (
self.source_schema,
Expand Down
Loading

0 comments on commit 4095eba

Please sign in to comment.