Skip to content

Commit

Permalink
update mypy version + fix issues + remove deprecatedlist helper (#1628)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wauplin authored Aug 29, 2023
1 parent f42d926 commit 111fdbe
Show file tree
Hide file tree
Showing 7 changed files with 9 additions and 162 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def get_version() -> str:
extras["quality"] = [
"black==23.7",
"ruff>=0.0.241",
"mypy==0.982",
"mypy==1.5.1",
]

extras["all"] = extras["testing"] + extras["quality"] + extras["typing"]
Expand Down
4 changes: 2 additions & 2 deletions src/huggingface_hub/commands/delete_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
See discussions in https://github.com/huggingface/huggingface_hub/issues/1025.
"""
import os
from argparse import _SubParsersAction
from argparse import Namespace, _SubParsersAction
from functools import wraps
from tempfile import mkstemp
from typing import Any, Callable, Iterable, List, Optional, Union
Expand Down Expand Up @@ -121,7 +121,7 @@ def register_subcommand(parser: _SubParsersAction):

delete_cache_parser.set_defaults(func=DeleteCacheCommand)

def __init__(self, args):
def __init__(self, args: Namespace) -> None:
self.cache_dir: Optional[str] = args.dir
self.disable_tui: bool = args.disable_tui

Expand Down
4 changes: 2 additions & 2 deletions src/huggingface_hub/commands/lfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,10 @@ def read_msg() -> Optional[Dict]:


class LfsUploadCommand:
def __init__(self, args):
def __init__(self, args) -> None:
self.args = args

def run(self):
def run(self) -> None:
# Immediately after invoking a custom transfer process, git-lfs
# sends initiation data to the process over stdin.
# This tells the process useful information about the configuration.
Expand Down
4 changes: 2 additions & 2 deletions src/huggingface_hub/commands/scan_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
huggingface-cli scan-cache --dir ~/.cache/huggingface/hub
"""
import time
from argparse import _SubParsersAction
from argparse import Namespace, _SubParsersAction
from typing import Optional

from ..utils import CacheNotFound, HFCacheInfo, scan_cache_dir
Expand Down Expand Up @@ -49,7 +49,7 @@ def register_subcommand(parser: _SubParsersAction):
)
scan_cache_parser.set_defaults(func=ScanCacheCommand)

def __init__(self, args):
def __init__(self, args: Namespace) -> None:
self.verbosity: int = args.verbose
self.cache_dir: Optional[str] = args.dir

Expand Down
2 changes: 1 addition & 1 deletion src/huggingface_hub/utils/_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
class CacheNotFound(Exception):
"""Exception thrown when the Huggingface cache is not found."""

cache_dir = Union[str, Path]
cache_dir: Union[str, Path]

def __init__(self, msg: str, cache_dir: Union[str, Path], *args, **kwargs):
super().__init__(msg, *args, **kwargs)
Expand Down
97 changes: 1 addition & 96 deletions src/huggingface_hub/utils/_deprecation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
from functools import wraps
from inspect import Parameter, signature
from typing import Generator, Iterable, Optional
from typing import Iterable, Optional


def _deprecate_positional_args(*, version: str):
Expand Down Expand Up @@ -130,98 +130,3 @@ def inner_f(*args, **kwargs):
return inner_f

return _inner_deprecate_method


def _deprecate_list_output(*, version: str):
"""Decorator to deprecate the usage as a list of the output of a method.
To be used when a method currently returns a list of objects but is planned to return
an generator instead in the future. Output is still a list but tweaked to issue a
warning message when it is specifically used as a list (e.g. get/set/del item, get
length,...).
Args:
version (`str`):
The version when output will start to be an generator.
"""

def _inner_deprecate_method(f):
@wraps(f)
def inner_f(*args, **kwargs):
list_value = f(*args, **kwargs)
return DeprecatedList(
list_value,
warning_message=(
"'{f.__name__}' currently returns a list of objects but is planned"
" to be a generator starting from version {version} in order to"
" implement pagination. Please avoid to use"
" `{f.__name__}(...).{attr_name}` or explicitly convert the output"
" to a list first with `[item for item in {f.__name__}(...)]`.".format(
f=f,
version=version,
# Dumb but working workaround to render `attr_name` later
# Taken from https://stackoverflow.com/a/35300723
attr_name="{attr_name}",
)
),
)

return inner_f

return _inner_deprecate_method


def _empty_gen() -> Generator:
# Create an empty generator
# Taken from https://stackoverflow.com/a/13243870
return
yield


# Build the set of attributes that are specific to a List object (and will be deprecated)
_LIST_ONLY_ATTRS = frozenset(set(dir([])) - set(dir(_empty_gen())))


class DeprecateListMetaclass(type):
"""Metaclass that overwrites all list-only methods, including magic ones."""

def __new__(cls, clsname, bases, attrs):
# Check consistency
if "_deprecate" not in attrs:
raise TypeError("A `_deprecate` method must be implemented to use `DeprecateListMetaclass`.")
if list not in bases:
raise TypeError("Class must inherit from `list` to use `DeprecateListMetaclass`.")

# Create decorator to deprecate list-only methods, including magic ones
def _with_deprecation(f, name):
@wraps(f)
def _inner(self, *args, **kwargs):
self._deprecate(name) # Use the `_deprecate`
return f(self, *args, **kwargs)

return _inner

# Deprecate list-only methods
for attr in _LIST_ONLY_ATTRS:
attrs[attr] = _with_deprecation(getattr(list, attr), attr)

return super().__new__(cls, clsname, bases, attrs)


class DeprecatedList(list, metaclass=DeprecateListMetaclass):
"""Custom List class for which all calls to a list-specific method is deprecated.
Methods that are shared with a generator are not deprecated.
See `_deprecate_list_output` for more details.
"""

def __init__(self, iterable, warning_message: str):
"""Initialize the list with a default warning message.
Warning message will be formatted at runtime with a "{attr_name}" value.
"""
super().__init__(iterable)
self._deprecation_msg = warning_message

def _deprecate(self, attr_name: str) -> None:
warnings.warn(self._deprecation_msg.format(attr_name=attr_name), FutureWarning)
58 changes: 0 additions & 58 deletions tests/test_utils_deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from huggingface_hub.utils._deprecation import (
_deprecate_arguments,
_deprecate_list_output,
_deprecate_method,
_deprecate_positional_args,
)
Expand Down Expand Up @@ -129,60 +128,3 @@ def dummy_deprecated() -> None:
"'dummy_deprecated' (from 'tests.test_utils_deprecation') is deprecated"
" and will be removed from version 'xxx'. This is a custom message.",
)

def test_deprecate_list_output(self) -> None:
"""Test test_deprecate_list_output throw warnings."""

@_deprecate_list_output(version="xxx")
def dummy_deprecated() -> None:
return [1, 2, 3]

output = dummy_deprecated()

# Still a list !
self.assertIsInstance(output, list)

# __getitem__
with pytest.warns(FutureWarning) as record:
self.assertEqual(output[0], 1)

# (check real message once)
self.assertEqual(
record[0].message.args[0],
"'dummy_deprecated' currently returns a list of objects but is planned to be a generator starting from"
" version xxx in order to implement pagination. Please avoid to use"
" `dummy_deprecated(...).__getitem__` or explicitly convert the output to a list first with `[item for"
" item in dummy_deprecated(...)]`.",
)

# __setitem__
with pytest.warns(FutureWarning):
output[0] = 10

# __delitem__
with pytest.warns(FutureWarning):
del output[1]

# `.append` deprecated (as .index, .pop, .insert, .remove, .sort,...)
with pytest.warns(FutureWarning):
output.append(5)

# Magic __len__ deprecated (as __add__, __mul__, __contains__,...)
with pytest.warns(FutureWarning):
self.assertEqual(len(output), 3)

with pytest.warns(FutureWarning):
self.assertTrue(3 in output)

# List has been modified
# Iterating over items is NOT deprecated !
for item, expected in zip(output, [10, 3, 5]):
self.assertEqual(item, expected)

# If output is converted to list, no warning is raised.
output_as_list = list(iter(dummy_deprecated()))
self.assertEqual(output_as_list[0], 1)
del output_as_list[1]
output_as_list.append(5)
self.assertEqual(len(output_as_list), 3)
self.assertTrue(3 in output_as_list)

0 comments on commit 111fdbe

Please sign in to comment.