Skip to content

Commit

Permalink
Add more type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
mfeurer committed Jul 11, 2023
1 parent 5d2128a commit 34d5a31
Show file tree
Hide file tree
Showing 11 changed files with 109 additions and 93 deletions.
27 changes: 17 additions & 10 deletions openml/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
class OpenMLBase(ABC):
"""Base object for functionality that is shared across entities."""

def __repr__(self):
def __repr__(self) -> str:
body_fields = self._get_repr_body_fields()
return self._apply_repr_template(body_fields)

Expand Down Expand Up @@ -59,7 +59,9 @@ def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str]]]]:
# Should be implemented in the base class.
pass

def _apply_repr_template(self, body_fields: List[Tuple[str, str]]) -> str:
def _apply_repr_template(
self, body_fields: List[Tuple[str, Union[str, int, List[str]]]]
) -> str:
"""Generates the header and formats the body for string representation of the object.
Parameters
Expand All @@ -80,7 +82,7 @@ def _apply_repr_template(self, body_fields: List[Tuple[str, str]]) -> str:
return header + body

@abstractmethod
def _to_dict(self) -> "OrderedDict[str, OrderedDict]":
def _to_dict(self) -> "OrderedDict[str, OrderedDict[str, str]]":
"""Creates a dictionary representation of self.
Uses OrderedDict to ensure consistent ordering when converting to xml.
Expand All @@ -107,7 +109,7 @@ def _to_xml(self) -> str:
encoding_specification, xml_body = xml_representation.split("\n", 1)
return xml_body

def _get_file_elements(self) -> Dict:
def _get_file_elements(self) -> openml._api_calls.FILE_ELEMENTS_TYPE:
"""Get file_elements to upload to the server, called during Publish.
Derived child classes should overwrite this method as necessary.
Expand All @@ -116,7 +118,7 @@ def _get_file_elements(self) -> Dict:
return {}

@abstractmethod
def _parse_publish_response(self, xml_response: Dict):
def _parse_publish_response(self, xml_response: Dict[str, str]) -> None:
"""Parse the id from the xml_response and assign it to self."""
pass

Expand All @@ -135,11 +137,16 @@ def publish(self) -> "OpenMLBase":
self._parse_publish_response(xml_response)
return self

def open_in_browser(self):
def open_in_browser(self) -> None:
"""Opens the OpenML web page corresponding to this object in your default browser."""
webbrowser.open(self.openml_url)

def push_tag(self, tag: str):
if self.openml_url is None:
raise ValueError(
"Cannot open element on OpenML.org when attribute `openml_url` is `None`"
)
else:
webbrowser.open(self.openml_url)

def push_tag(self, tag: str) -> None:
"""Annotates this entity with a tag on the server.
Parameters
Expand All @@ -149,7 +156,7 @@ def push_tag(self, tag: str):
"""
_tag_openml_base(self, tag)

def remove_tag(self, tag: str):
def remove_tag(self, tag: str) -> None:
"""Removes a tag from this entity on the server.
Parameters
Expand Down
8 changes: 4 additions & 4 deletions openml/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def wait_until_valid_input(
return response


def print_configuration():
def print_configuration() -> None:
file = config.determine_config_file_path()
header = f"File '{file}' contains (or defaults to):"
print(header)
Expand All @@ -65,7 +65,7 @@ def print_configuration():
print(f"{field.ljust(max_key_length)}: {value}")


def verbose_set(field, value):
def verbose_set(field: str, value: str) -> None:
config.set_field_in_config_file(field, value)
print(f"{field} set to '{value}'.")

Expand Down Expand Up @@ -295,7 +295,7 @@ def configure_field(
verbose_set(field, value)


def configure(args: argparse.Namespace):
def configure(args: argparse.Namespace) -> None:
"""Calls the right submenu(s) to edit `args.field` in the configuration file."""
set_functions = {
"apikey": configure_apikey,
Expand All @@ -307,7 +307,7 @@ def configure(args: argparse.Namespace):
"verbosity": configure_verbosity,
}

def not_supported_yet(_):
def not_supported_yet(_: str) -> None:
print(f"Setting '{args.field}' is not supported yet.")

if args.field not in ["all", "none"]:
Expand Down
72 changes: 34 additions & 38 deletions openml/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os
from pathlib import Path
import platform
from typing import Tuple, cast, Any, Optional
from typing import Dict, Optional, Tuple, Union, cast
import warnings

from io import StringIO
Expand All @@ -19,10 +19,10 @@
logger = logging.getLogger(__name__)
openml_logger = logging.getLogger("openml")
console_handler = None
file_handler = None
file_handler = None # type: Optional[logging.Handler]


def _create_log_handlers(create_file_handler=True):
def _create_log_handlers(create_file_handler: bool = True) -> None:
"""Creates but does not attach the log handlers."""
global console_handler, file_handler
if console_handler is not None or file_handler is not None:
Expand Down Expand Up @@ -61,7 +61,7 @@ def _convert_log_levels(log_level: int) -> Tuple[int, int]:
return openml_level, python_level


def _set_level_register_and_store(handler: logging.Handler, log_level: int):
def _set_level_register_and_store(handler: logging.Handler, log_level: int) -> None:
"""Set handler log level, register it if needed, save setting to config file if specified."""
oml_level, py_level = _convert_log_levels(log_level)
handler.setLevel(py_level)
Expand All @@ -73,13 +73,13 @@ def _set_level_register_and_store(handler: logging.Handler, log_level: int):
openml_logger.addHandler(handler)


def set_console_log_level(console_output_level: int):
def set_console_log_level(console_output_level: int) -> None:
"""Set console output to the desired level and register it with openml logger if needed."""
global console_handler
_set_level_register_and_store(cast(logging.Handler, console_handler), console_output_level)


def set_file_log_level(file_output_level: int):
def set_file_log_level(file_output_level: int) -> None:
"""Set file output to the desired level and register it with openml logger if needed."""
global file_handler
_set_level_register_and_store(cast(logging.Handler, file_handler), file_output_level)
Expand Down Expand Up @@ -139,7 +139,8 @@ def set_retry_policy(value: str, n_retries: Optional[int] = None) -> None:

if value not in default_retries_by_policy:
raise ValueError(
f"Detected retry_policy '{value}' but must be one of {default_retries_by_policy}"
f"Detected retry_policy '{value}' but must be one of "
f"{list(default_retries_by_policy.keys())}"
)
if n_retries is not None and not isinstance(n_retries, int):
raise TypeError(f"`n_retries` must be of type `int` or `None` but is `{type(n_retries)}`.")
Expand All @@ -160,7 +161,7 @@ class ConfigurationForExamples:
_test_apikey = "c0c42819af31e706efe1f4b88c23c6c1"

@classmethod
def start_using_configuration_for_example(cls):
def start_using_configuration_for_example(cls) -> None:
"""Sets the configuration to connect to the test server with valid apikey.
The configuration as it was before this call is stored, and can be recovered
Expand All @@ -187,7 +188,7 @@ def start_using_configuration_for_example(cls):
)

@classmethod
def stop_using_configuration_for_example(cls):
def stop_using_configuration_for_example(cls) -> None:
"""Return to configuration as it was before `start_use_example_configuration`."""
if not cls._start_last_called:
# We don't want to allow this because it will (likely) result in the `server` and
Expand All @@ -200,8 +201,8 @@ def stop_using_configuration_for_example(cls):
global server
global apikey

server = cls._last_used_server
apikey = cls._last_used_key
server = cast(str, cls._last_used_server)
apikey = cast(str, cls._last_used_key)
cls._start_last_called = False


Expand All @@ -215,7 +216,7 @@ def determine_config_file_path() -> Path:
return config_dir / "config"


def _setup(config=None):
def _setup(config: Optional[Dict[str, Union[str, int, bool]]] = None) -> None:
"""Setup openml package. Called on first import.
Reads the config file and sets up apikey, server, cache appropriately.
Expand Down Expand Up @@ -243,28 +244,22 @@ def _setup(config=None):
cache_exists = True

if config is None:
config = _parse_config(config_file)
config = cast(Dict[str, Union[str, int, bool]], _parse_config(config_file))
config = cast(Dict[str, Union[str, int, bool]], config)

def _get(config, key):
return config.get("FAKE_SECTION", key)
avoid_duplicate_runs = bool(config.get("avoid_duplicate_runs"))

avoid_duplicate_runs = config.getboolean("FAKE_SECTION", "avoid_duplicate_runs")
else:

def _get(config, key):
return config.get(key)

avoid_duplicate_runs = config.get("avoid_duplicate_runs")
apikey = cast(str, config["apikey"])
server = cast(str, config["server"])
short_cache_dir = cast(str, config["cachedir"])

apikey = _get(config, "apikey")
server = _get(config, "server")
short_cache_dir = _get(config, "cachedir")

n_retries = _get(config, "connection_n_retries")
if n_retries is not None:
n_retries = int(n_retries)
tmp_n_retries = config["connection_n_retries"]
if tmp_n_retries is not None:
n_retries = int(tmp_n_retries)
else:
n_retries = None

set_retry_policy(_get(config, "retry_policy"), n_retries)
set_retry_policy(cast(str, config["retry_policy"]), n_retries)

_root_cache_directory = os.path.expanduser(short_cache_dir)
# create the cache subdirectory
Expand All @@ -287,10 +282,10 @@ def _get(config, key):
)


def set_field_in_config_file(field: str, value: Any):
def set_field_in_config_file(field: str, value: str) -> None:
"""Overwrites the `field` in the configuration file with the new `value`."""
if field not in _defaults:
return ValueError(f"Field '{field}' is not valid and must be one of '{_defaults.keys()}'.")
raise ValueError(f"Field '{field}' is not valid and must be one of '{_defaults.keys()}'.")

globals()[field] = value
config_file = determine_config_file_path()
Expand All @@ -308,7 +303,7 @@ def set_field_in_config_file(field: str, value: Any):
fh.write(f"{f} = {value}\n")


def _parse_config(config_file: str):
def _parse_config(config_file: Union[str, Path]) -> Dict[str, str]:
"""Parse the config file, set up defaults."""
config = configparser.RawConfigParser(defaults=_defaults)

Expand All @@ -326,11 +321,12 @@ def _parse_config(config_file: str):
logger.info("Error opening file %s: %s", config_file, e.args[0])
config_file_.seek(0)
config.read_file(config_file_)
return config
config_as_dict = {key: value for key, value in config.items("FAKE_SECTION")}
return config_as_dict


def get_config_as_dict():
config = dict()
def get_config_as_dict() -> Dict[str, Union[str, int, bool]]:
config = dict() # type: Dict[str, Union[str, int, bool]]
config["apikey"] = apikey
config["server"] = server
config["cachedir"] = _root_cache_directory
Expand All @@ -340,7 +336,7 @@ def get_config_as_dict():
return config


def get_cache_directory():
def get_cache_directory() -> str:
"""Get the current cache directory.
This gets the cache directory for the current server relative
Expand All @@ -366,7 +362,7 @@ def get_cache_directory():
return _cachedir


def set_root_cache_directory(root_cache_directory):
def set_root_cache_directory(root_cache_directory: str) -> None:
"""Set module-wide base cache directory.
Sets the root cache directory, wherein the cache directories are
Expand Down
6 changes: 3 additions & 3 deletions openml/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# License: BSD 3-Clause

from typing import Optional
from typing import Optional, Set


class PyOpenMLError(Exception):
Expand Down Expand Up @@ -28,7 +28,7 @@ def __init__(self, message: str, code: Optional[int] = None, url: Optional[str]
self.url = url
super().__init__(message)

def __str__(self):
def __str__(self) -> str:
return f"{self.url} returned code {self.code}: {self.message}"


Expand Down Expand Up @@ -59,7 +59,7 @@ class OpenMLPrivateDatasetError(PyOpenMLError):
class OpenMLRunsExistError(PyOpenMLError):
"""Indicates run(s) already exists on the server when they should not be duplicated."""

def __init__(self, run_ids: set, message: str):
def __init__(self, run_ids: Set[int], message: str) -> None:
if len(run_ids) < 1:
raise ValueError("Set of run ids must be non-empty.")
self.run_ids = run_ids
Expand Down
Loading

0 comments on commit 34d5a31

Please sign in to comment.