diff --git a/openml/base.py b/openml/base.py
index 35a9ce58f..565318132 100644
--- a/openml/base.py
+++ b/openml/base.py
@@ -15,7 +15,7 @@ class OpenMLBase(ABC):
     """Base object for functionality that is shared across entities."""
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         body_fields = self._get_repr_body_fields()
         return self._apply_repr_template(body_fields)
 
@@ -59,7 +59,9 @@ def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str]]]]:
         # Should be implemented in the base class.
         pass
 
-    def _apply_repr_template(self, body_fields: List[Tuple[str, str]]) -> str:
+    def _apply_repr_template(
+        self, body_fields: List[Tuple[str, Union[str, int, List[str]]]]
+    ) -> str:
         """Generates the header and formats the body for string representation of the object.
 
         Parameters
@@ -80,7 +82,7 @@ def _apply_repr_template(self, body_fields: List[Tuple[str, str]]) -> str:
         return header + body
 
     @abstractmethod
-    def _to_dict(self) -> "OrderedDict[str, OrderedDict]":
+    def _to_dict(self) -> "OrderedDict[str, OrderedDict[str, str]]":
         """Creates a dictionary representation of self.
 
         Uses OrderedDict to ensure consistent ordering when converting to xml.
@@ -107,7 +109,7 @@ def _to_xml(self) -> str:
         encoding_specification, xml_body = xml_representation.split("\n", 1)
         return xml_body
 
-    def _get_file_elements(self) -> Dict:
+    def _get_file_elements(self) -> openml._api_calls.FILE_ELEMENTS_TYPE:
         """Get file_elements to upload to the server, called during Publish.
 
         Derived child classes should overwrite this method as necessary.
@@ -116,7 +118,7 @@ def _get_file_elements(self) -> Dict:
         return {}
 
     @abstractmethod
-    def _parse_publish_response(self, xml_response: Dict):
+    def _parse_publish_response(self, xml_response: Dict[str, str]) -> None:
         """Parse the id from the xml_response and assign it to self."""
         pass
 
@@ -135,11 +137,16 @@ def publish(self) -> "OpenMLBase":
         self._parse_publish_response(xml_response)
         return self
 
-    def open_in_browser(self):
+    def open_in_browser(self) -> None:
         """Opens the OpenML web page corresponding to this object in your default browser."""
-        webbrowser.open(self.openml_url)
-
-    def push_tag(self, tag: str):
+        if self.openml_url is None:
+            raise ValueError(
+                "Cannot open element on OpenML.org when attribute `openml_url` is `None`"
+            )
+        else:
+            webbrowser.open(self.openml_url)
+
+    def push_tag(self, tag: str) -> None:
         """Annotates this entity with a tag on the server.
 
         Parameters
@@ -149,7 +156,7 @@ def push_tag(self, tag: str):
         """
         _tag_openml_base(self, tag)
 
-    def remove_tag(self, tag: str):
+    def remove_tag(self, tag: str) -> None:
         """Removes a tag from this entity on the server.
 
         Parameters
diff --git a/openml/cli.py b/openml/cli.py
index 039ac227c..83539cda5 100644
--- a/openml/cli.py
+++ b/openml/cli.py
@@ -55,7 +55,7 @@ def wait_until_valid_input(
     return response
 
 
-def print_configuration():
+def print_configuration() -> None:
     file = config.determine_config_file_path()
     header = f"File '{file}' contains (or defaults to):"
     print(header)
@@ -65,7 +65,7 @@ def print_configuration():
         print(f"{field.ljust(max_key_length)}: {value}")
 
 
-def verbose_set(field, value):
+def verbose_set(field: str, value: str) -> None:
     config.set_field_in_config_file(field, value)
     print(f"{field} set to '{value}'.")
 
@@ -295,7 +295,7 @@ def configure_field(
         verbose_set(field, value)
 
 
-def configure(args: argparse.Namespace):
+def configure(args: argparse.Namespace) -> None:
     """Calls the right submenu(s) to edit `args.field` in the configuration file."""
     set_functions = {
         "apikey": configure_apikey,
@@ -307,7 +307,7 @@ def configure(args: argparse.Namespace):
         "verbosity": configure_verbosity,
     }
 
-    def not_supported_yet(_):
+    def not_supported_yet(_: str) -> None:
         print(f"Setting '{args.field}' is not supported yet.")
 
     if args.field not in ["all", "none"]:
diff --git a/openml/config.py b/openml/config.py
index b68455a9b..fc1f9770e 100644
--- a/openml/config.py
+++ b/openml/config.py
@@ -9,7 +9,7 @@
 import os
 from pathlib import Path
 import platform
-from typing import Tuple, cast, Any, Optional
+from typing import Dict, Optional, Tuple, Union, cast
 import warnings
 
 from io import StringIO
@@ -19,10 +19,10 @@
 logger = logging.getLogger(__name__)
 openml_logger = logging.getLogger("openml")
 console_handler = None
-file_handler = None
+file_handler = None  # type: Optional[logging.Handler]
 
 
-def _create_log_handlers(create_file_handler=True):
+def _create_log_handlers(create_file_handler: bool = True) -> None:
     """Creates but does not attach the log handlers."""
     global console_handler, file_handler
     if console_handler is not None or file_handler is not None:
@@ -61,7 +61,7 @@ def _convert_log_levels(log_level: int) -> Tuple[int, int]:
     return openml_level, python_level
 
 
-def _set_level_register_and_store(handler: logging.Handler, log_level: int):
+def _set_level_register_and_store(handler: logging.Handler, log_level: int) -> None:
     """Set handler log level, register it if needed, save setting to config file if specified."""
     oml_level, py_level = _convert_log_levels(log_level)
     handler.setLevel(py_level)
@@ -73,13 +73,13 @@ def _set_level_register_and_store(handler: logging.Handler, log_level: int):
     openml_logger.addHandler(handler)
 
 
-def set_console_log_level(console_output_level: int):
+def set_console_log_level(console_output_level: int) -> None:
     """Set console output to the desired level and register it with openml logger if needed."""
     global console_handler
     _set_level_register_and_store(cast(logging.Handler, console_handler), console_output_level)
 
 
-def set_file_log_level(file_output_level: int):
+def set_file_log_level(file_output_level: int) -> None:
     """Set file output to the desired level and register it with openml logger if needed."""
     global file_handler
     _set_level_register_and_store(cast(logging.Handler, file_handler), file_output_level)
@@ -139,7 +139,8 @@ def set_retry_policy(value: str, n_retries: Optional[int] = None) -> None:
     if value not in default_retries_by_policy:
         raise ValueError(
-            f"Detected retry_policy '{value}' but must be one of {default_retries_by_policy}"
+            f"Detected retry_policy '{value}' but must be one of "
+            f"{list(default_retries_by_policy.keys())}"
         )
     if n_retries is not None and not isinstance(n_retries, int):
         raise TypeError(f"`n_retries` must be of type `int` or `None` but is `{type(n_retries)}`.")
 
@@ -160,7 +161,7 @@ class ConfigurationForExamples:
     _test_apikey = "c0c42819af31e706efe1f4b88c23c6c1"
 
     @classmethod
-    def start_using_configuration_for_example(cls):
+    def start_using_configuration_for_example(cls) -> None:
         """Sets the configuration to connect to the test server with valid apikey.
 
         To configuration as was before this call is stored, and can be recovered
@@ -187,7 +188,7 @@ def start_using_configuration_for_example(cls):
         )
 
     @classmethod
-    def stop_using_configuration_for_example(cls):
+    def stop_using_configuration_for_example(cls) -> None:
         """Return to configuration as it was before `start_use_example_configuration`."""
         if not cls._start_last_called:
             # We don't want to allow this because it will (likely) result in the `server` and
@@ -200,8 +201,8 @@ def stop_using_configuration_for_example(cls):
 
         global server
         global apikey
-        server = cls._last_used_server
-        apikey = cls._last_used_key
+        server = cast(str, cls._last_used_server)
+        apikey = cast(str, cls._last_used_key)
         cls._start_last_called = False
 
@@ -215,7 +216,7 @@ def determine_config_file_path() -> Path:
     return config_dir / "config"
 
 
-def _setup(config=None):
+def _setup(config: Optional[Dict[str, Union[str, int, bool]]] = None) -> None:
     """Setup openml package. Called on first import.
 
     Reads the config file and sets up apikey, server, cache appropriately.
@@ -243,28 +244,22 @@ def _setup(config=None):
         cache_exists = True
 
     if config is None:
-        config = _parse_config(config_file)
+        config = cast(Dict[str, Union[str, int, bool]], _parse_config(config_file))
+    config = cast(Dict[str, Union[str, int, bool]], config)
 
-        def _get(config, key):
-            return config.get("FAKE_SECTION", key)
+    avoid_duplicate_runs = bool(config.get("avoid_duplicate_runs"))
 
-        avoid_duplicate_runs = config.getboolean("FAKE_SECTION", "avoid_duplicate_runs")
-    else:
-
-        def _get(config, key):
-            return config.get(key)
-
-        avoid_duplicate_runs = config.get("avoid_duplicate_runs")
+    apikey = cast(str, config["apikey"])
+    server = cast(str, config["server"])
+    short_cache_dir = cast(str, config["cachedir"])
 
-    apikey = _get(config, "apikey")
-    server = _get(config, "server")
-    short_cache_dir = _get(config, "cachedir")
-
-    n_retries = _get(config, "connection_n_retries")
-    if n_retries is not None:
-        n_retries = int(n_retries)
+    tmp_n_retries = config["connection_n_retries"]
+    if tmp_n_retries is not None:
+        n_retries = int(tmp_n_retries)
+    else:
+        n_retries = None
 
-    set_retry_policy(_get(config, "retry_policy"), n_retries)
+    set_retry_policy(cast(str, config["retry_policy"]), n_retries)
 
     _root_cache_directory = os.path.expanduser(short_cache_dir)
     # create the cache subdirectory
@@ -287,10 +282,10 @@ def _get(config, key):
     )
 
 
-def set_field_in_config_file(field: str, value: Any):
+def set_field_in_config_file(field: str, value: str) -> None:
     """Overwrites the `field` in the configuration file with the new `value`."""
     if field not in _defaults:
-        return ValueError(f"Field '{field}' is not valid and must be one of '{_defaults.keys()}'.")
+        raise ValueError(f"Field '{field}' is not valid and must be one of '{_defaults.keys()}'.")
     globals()[field] = value
 
     config_file = determine_config_file_path()
@@ -308,7 +303,7 @@ def set_field_in_config_file(field: str, value: Any):
             fh.write(f"{f} = {value}\n")
 
 
-def _parse_config(config_file: str):
+def _parse_config(config_file: Union[str, Path]) -> Dict[str, str]:
     """Parse the config file, set up defaults."""
defaults.""" config = configparser.RawConfigParser(defaults=_defaults) @@ -326,11 +321,12 @@ def _parse_config(config_file: str): logger.info("Error opening file %s: %s", config_file, e.args[0]) config_file_.seek(0) config.read_file(config_file_) - return config + config_as_dict = {key: value for key, value in config.items("FAKE_SECTION")} + return config_as_dict -def get_config_as_dict(): - config = dict() +def get_config_as_dict() -> Dict[str, Union[str, int, bool]]: + config = dict() # type: Dict[str, Union[str, int, bool]] config["apikey"] = apikey config["server"] = server config["cachedir"] = _root_cache_directory @@ -340,7 +336,7 @@ def get_config_as_dict(): return config -def get_cache_directory(): +def get_cache_directory() -> str: """Get the current cache directory. This gets the cache directory for the current server relative @@ -366,7 +362,7 @@ def get_cache_directory(): return _cachedir -def set_root_cache_directory(root_cache_directory): +def set_root_cache_directory(root_cache_directory: str) -> None: """Set module-wide base cache directory. Sets the root cache directory, wherin the cache directories are diff --git a/openml/exceptions.py b/openml/exceptions.py index a86434f51..d403cccdd 100644 --- a/openml/exceptions.py +++ b/openml/exceptions.py @@ -1,6 +1,6 @@ # License: BSD 3-Clause -from typing import Optional +from typing import Optional, Set class PyOpenMLError(Exception): @@ -28,7 +28,7 @@ def __init__(self, message: str, code: Optional[int] = None, url: Optional[str] self.url = url super().__init__(message) - def __str__(self): + def __str__(self) -> str: return f"{self.url} returned code {self.code}: {self.message}" @@ -59,7 +59,7 @@ class OpenMLPrivateDatasetError(PyOpenMLError): class OpenMLRunsExistError(PyOpenMLError): """Indicates run(s) already exists on the server when they should not be duplicated.""" - def __init__(self, run_ids: set, message: str): + def __init__(self, run_ids: Set[int], message: str) -> None: if len(run_ids) < 1: raise ValueError("Set of run ids must be non-empty.") self.run_ids = run_ids diff --git a/openml/testing.py b/openml/testing.py index ecb9620e1..806516a8d 100644 --- a/openml/testing.py +++ b/openml/testing.py @@ -7,7 +7,7 @@ import shutil import sys import time -from typing import Dict, Union, cast +from typing import Dict, List, Optional, Tuple, Union, cast # noqa: F401 import unittest import pandas as pd import requests @@ -35,7 +35,8 @@ class TestBase(unittest.TestCase): "task": [], "study": [], "user": [], - } # type: dict + } # type: Dict[str, List[int]] + flow_name_tracker = [] # type: List[str] test_server = "https://test.openml.org/api/v1/xml" # amueller's read/write key that he will throw away later apikey = "610344db6388d9ba34f6db45a3cf71de" @@ -44,7 +45,7 @@ class TestBase(unittest.TestCase): logger = logging.getLogger("unit_tests_published_entities") logger.setLevel(logging.DEBUG) - def setUp(self, n_levels: int = 1): + def setUp(self, n_levels: int = 1) -> None: """Setup variables and temporary directories. 
 
         In particular, this methods:
@@ -100,7 +101,7 @@ def setUp(self, n_levels: int = 1):
         self.connection_n_retries = openml.config.connection_n_retries
         openml.config.set_retry_policy("robot", n_retries=20)
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         os.chdir(self.cwd)
         try:
             shutil.rmtree(self.workdir)
@@ -115,7 +116,9 @@ def tearDown(self):
         openml.config.retry_policy = self.retry_policy
 
     @classmethod
-    def _mark_entity_for_removal(self, entity_type, entity_id):
+    def _mark_entity_for_removal(
+        self, entity_type: str, entity_id: int, entity_name: Optional[str] = None
+    ) -> None:
         """Static record of entities uploaded to test server
 
         Dictionary of lists where the keys are 'entity_type'.
@@ -127,9 +130,12 @@ def _mark_entity_for_removal(self, entity_type, entity_id):
             TestBase.publish_tracker[entity_type] = [entity_id]
         else:
             TestBase.publish_tracker[entity_type].append(entity_id)
+        if entity_type == "flow":
+            assert entity_name is not None
+            self.flow_name_tracker.append(entity_name)
 
     @classmethod
-    def _delete_entity_from_tracker(self, entity_type, entity):
+    def _delete_entity_from_tracker(self, entity_type: str, entity: int) -> None:
         """Deletes entity records from the static file_tracker
 
         Given an entity type and corresponding ID, deletes all entries, including
@@ -141,7 +147,9 @@ def _delete_entity_from_tracker(self, entity_type, entity):
             if entity_type == "flow":
                 delete_index = [
                     i
-                    for i, (id_, _) in enumerate(TestBase.publish_tracker[entity_type])
+                    for i, (id_, _) in enumerate(
+                        zip(TestBase.publish_tracker[entity_type], TestBase.flow_name_tracker)
+                    )
                     if id_ == entity
                 ][0]
             else:
@@ -152,7 +160,7 @@ def _delete_entity_from_tracker(self, entity_type, entity):
                 ][0]
             TestBase.publish_tracker[entity_type].pop(delete_index)
 
-    def _get_sentinel(self, sentinel=None):
+    def _get_sentinel(self, sentinel: Optional[str] = None) -> str:
         if sentinel is None:
             # Create a unique prefix for the flow. Necessary because the flow
             # is identified by its name and external version online. Having a
@@ -164,7 +172,9 @@ def _get_sentinel(self, sentinel=None):
             sentinel = "TEST%s" % sentinel
         return sentinel
 
-    def _add_sentinel_to_flow_name(self, flow, sentinel=None):
+    def _add_sentinel_to_flow_name(
+        self, flow: openml.flows.OpenMLFlow, sentinel: Optional[str] = None
+    ) -> Tuple[openml.flows.OpenMLFlow, str]:
         sentinel = self._get_sentinel(sentinel=sentinel)
         flows_to_visit = list()
         flows_to_visit.append(flow)
@@ -176,7 +186,7 @@ def _add_sentinel_to_flow_name(self, flow, sentinel=None):
 
         return flow, sentinel
 
-    def _check_dataset(self, dataset):
+    def _check_dataset(self, dataset: Dict[str, Union[str, int]]) -> None:
         self.assertEqual(type(dataset), dict)
         self.assertGreaterEqual(len(dataset), 2)
         self.assertIn("did", dataset)
@@ -187,13 +197,13 @@ def _check_dataset(self, dataset):
 
     def _check_fold_timing_evaluations(
         self,
-        fold_evaluations: Dict,
+        fold_evaluations: Dict[str, Dict[int, Dict[int, float]]],
         num_repeats: int,
         num_folds: int,
         max_time_allowed: float = 60000.0,
         task_type: TaskType = TaskType.SUPERVISED_CLASSIFICATION,
         check_scores: bool = True,
-    ):
+    ) -> None:
         """
         Checks whether the right timing measures are attached to the run
         (before upload).
 
         Test is only performed for versions >= Python3.3
@@ -245,7 +255,10 @@ def _check_fold_timing_evaluations(
 
 
 def check_task_existence(
-    task_type: TaskType, dataset_id: int, target_name: str, **kwargs
+    task_type: TaskType,
+    dataset_id: int,
+    target_name: str,
+    **kwargs: Dict[str, Union[str, int, Dict[str, Union[str, int, openml.tasks.TaskType]]]]
 ) -> Union[int, None]:
     """Checks if any task with exists on test server that matches the meta data.
diff --git a/openml/utils.py b/openml/utils.py
index ffcc308dd..80d9cf68c 100644
--- a/openml/utils.py
+++ b/openml/utils.py
@@ -91,7 +91,7 @@ def _tag_openml_base(oml_object: "OpenMLBase", tag: str, untag: bool = False):
     _tag_entity(api_type_alias, oml_object.id, tag, untag)
 
 
-def _tag_entity(entity_type, entity_id, tag, untag=False):
+def _tag_entity(entity_type, entity_id, tag, untag=False) -> List[str]:
     """
     Function that tags or untags a given entity on OpenML. As the OpenML
     API tag functions all consist of the same format, this function covers
diff --git a/tests/conftest.py b/tests/conftest.py
index 43e2cc3ee..1962c5085 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -74,7 +74,7 @@ def compare_delete_files(old_list: List[pathlib.Path], new_list: List[pathlib.Pa
             logger.info("Deleted from local: {}".format(file))
 
 
-def delete_remote_files(tracker) -> None:
+def delete_remote_files(tracker, flow_names) -> None:
     """Function that deletes the entities passed as input, from the OpenML test server
 
     The TestBase class in openml/testing.py has an attribute called publish_tracker.
@@ -94,11 +94,11 @@ def delete_remote_files(tracker) -> None:
     # reordering to delete sub flows at the end of flows
     # sub-flows have shorter names, hence, sorting by descending order of flow name length
     if "flow" in tracker:
+        to_sort = list(zip(tracker["flow"], flow_names))
         flow_deletion_order = [
-            entity_id
-            for entity_id, _ in sorted(tracker["flow"], key=lambda x: len(x[1]), reverse=True)
+            entity_id for entity_id, _ in sorted(to_sort, key=lambda x: len(x[1]), reverse=True)
         ]
         tracker["flow"] = flow_deletion_order
 
     # deleting all collected entities published to test server
     # 'run's are deleted first to prevent dependency issue of entities on deletion
@@ -158,7 +158,7 @@ def pytest_sessionfinish() -> None:
 
     # Test file deletion
     logger.info("Deleting files uploaded to test server for worker {}".format(worker))
-    delete_remote_files(TestBase.publish_tracker)
+    delete_remote_files(TestBase.publish_tracker, TestBase.flow_name_tracker)
 
     if worker == "master":
         # Local file deletion
diff --git a/tests/test_flows/test_flow.py b/tests/test_flows/test_flow.py
index c3c72f267..115735944 100644
--- a/tests/test_flows/test_flow.py
+++ b/tests/test_flows/test_flow.py
@@ -190,7 +190,7 @@ def test_publish_flow(self):
 
         flow, _ = self._add_sentinel_to_flow_name(flow, None)
         flow.publish()
-        TestBase._mark_entity_for_removal("flow", (flow.flow_id, flow.name))
+        TestBase._mark_entity_for_removal("flow", flow.flow_id, flow.name)
         TestBase.logger.info("collected from {}: {}".format(__file__.split("/")[-1], flow.flow_id))
 
         self.assertIsInstance(flow.flow_id, int)
@@ -203,7 +203,7 @@ def test_publish_existing_flow(self, flow_exists_mock):
 
         with self.assertRaises(openml.exceptions.PyOpenMLError) as context_manager:
             flow.publish(raise_error_if_exists=True)
-            TestBase._mark_entity_for_removal("flow", (flow.flow_id, flow.name))
+            TestBase._mark_entity_for_removal("flow", flow.flow_id, flow.name)
            TestBase.logger.info(
"collected from {}: {}".format(__file__.split("/")[-1], flow.flow_id) ) @@ -218,7 +218,7 @@ def test_publish_flow_with_similar_components(self): flow = self.extension.model_to_flow(clf) flow, _ = self._add_sentinel_to_flow_name(flow, None) flow.publish() - TestBase._mark_entity_for_removal("flow", (flow.flow_id, flow.name)) + TestBase._mark_entity_for_removal("flow", flow.flow_id, flow.name) TestBase.logger.info("collected from {}: {}".format(__file__.split("/")[-1], flow.flow_id)) # For a flow where both components are published together, the upload # date should be equal @@ -237,7 +237,7 @@ def test_publish_flow_with_similar_components(self): flow1 = self.extension.model_to_flow(clf1) flow1, sentinel = self._add_sentinel_to_flow_name(flow1, None) flow1.publish() - TestBase._mark_entity_for_removal("flow", (flow.flow_id, flow.name)) + TestBase._mark_entity_for_removal("flow", flow.flow_id, flow.name) TestBase.logger.info("collected from {}: {}".format(__file__.split("/")[-1], flow1.flow_id)) # In order to assign different upload times to the flows! @@ -249,7 +249,7 @@ def test_publish_flow_with_similar_components(self): flow2 = self.extension.model_to_flow(clf2) flow2, _ = self._add_sentinel_to_flow_name(flow2, sentinel) flow2.publish() - TestBase._mark_entity_for_removal("flow", (flow2.flow_id, flow2.name)) + TestBase._mark_entity_for_removal("flow", flow2.flow_id, flow2.name) TestBase.logger.info("collected from {}: {}".format(__file__.split("/")[-1], flow2.flow_id)) # If one component was published before the other, the components in # the flow should have different upload dates @@ -261,7 +261,7 @@ def test_publish_flow_with_similar_components(self): # Child flow has different parameter. Check for storing the flow # correctly on the server should thus not check the child's parameters! 
         flow3.publish()
-        TestBase._mark_entity_for_removal("flow", (flow3.flow_id, flow3.name))
+        TestBase._mark_entity_for_removal("flow", flow3.flow_id, flow3.name)
         TestBase.logger.info("collected from {}: {}".format(__file__.split("/")[-1], flow3.flow_id))
 
     @pytest.mark.sklearn
@@ -278,7 +278,7 @@ def test_semi_legal_flow(self):
 
         flow, _ = self._add_sentinel_to_flow_name(flow, None)
         flow.publish()
-        TestBase._mark_entity_for_removal("flow", (flow.flow_id, flow.name))
+        TestBase._mark_entity_for_removal("flow", flow.flow_id, flow.name)
         TestBase.logger.info("collected from {}: {}".format(__file__.split("/")[-1], flow.flow_id))
 
     @pytest.mark.sklearn
@@ -308,7 +308,7 @@ def test_publish_error(self, api_call_mock, flow_exists_mock, get_flow_mock):
 
         with self.assertRaises(ValueError) as context_manager:
             flow.publish()
-            TestBase._mark_entity_for_removal("flow", (flow.flow_id, flow.name))
+            TestBase._mark_entity_for_removal("flow", flow.flow_id, flow.name)
             TestBase.logger.info(
                 "collected from {}: {}".format(__file__.split("/")[-1], flow.flow_id)
             )
@@ -391,7 +391,7 @@ def test_existing_flow_exists(self):
         flow, _ = self._add_sentinel_to_flow_name(flow, None)
         # publish the flow
         flow = flow.publish()
-        TestBase._mark_entity_for_removal("flow", (flow.flow_id, flow.name))
+        TestBase._mark_entity_for_removal("flow", flow.flow_id, flow.name)
         TestBase.logger.info(
             "collected from {}: {}".format(__file__.split("/")[-1], flow.flow_id)
         )
@@ -451,7 +451,7 @@ def test_sklearn_to_upload_to_flow(self):
         flow, sentinel = self._add_sentinel_to_flow_name(flow, None)
 
         flow.publish()
-        TestBase._mark_entity_for_removal("flow", (flow.flow_id, flow.name))
+        TestBase._mark_entity_for_removal("flow", flow.flow_id, flow.name)
         TestBase.logger.info("collected from {}: {}".format(__file__.split("/")[-1], flow.flow_id))
         self.assertIsInstance(flow.flow_id, int)
diff --git a/tests/test_flows/test_flow_functions.py b/tests/test_flows/test_flow_functions.py
index f2520cb36..7307ebb28 100644
--- a/tests/test_flows/test_flow_functions.py
+++ b/tests/test_flows/test_flow_functions.py
@@ -293,7 +293,7 @@ def test_sklearn_to_flow_list_of_lists(self):
         # Test flow is accepted by server
         self._add_sentinel_to_flow_name(flow)
         flow.publish()
-        TestBase._mark_entity_for_removal("flow", (flow.flow_id, flow.name))
+        TestBase._mark_entity_for_removal("flow", flow.flow_id, flow.name)
         TestBase.logger.info("collected from {}: {}".format(__file__.split("/")[-1], flow.flow_id))
         # Test deserialization works
         server_flow = openml.flows.get_flow(flow.flow_id, reinstantiate=True)
@@ -313,7 +313,7 @@ def test_get_flow_reinstantiate_model(self):
         extension = openml.extensions.get_extension_by_model(model)
         flow = extension.model_to_flow(model)
         flow.publish(raise_error_if_exists=False)
-        TestBase._mark_entity_for_removal("flow", (flow.flow_id, flow.name))
+        TestBase._mark_entity_for_removal("flow", flow.flow_id, flow.name)
         TestBase.logger.info("collected from {}: {}".format(__file__.split("/")[-1], flow.flow_id))
 
         downloaded_flow = openml.flows.get_flow(flow.flow_id, reinstantiate=True)
@@ -393,7 +393,7 @@ def test_get_flow_id(self):
         with patch("openml.utils._list_all", list_all):
             clf = sklearn.tree.DecisionTreeClassifier()
             flow = openml.extensions.get_extension_by_model(clf).model_to_flow(clf).publish()
-            TestBase._mark_entity_for_removal("flow", (flow.flow_id, flow.name))
+            TestBase._mark_entity_for_removal("flow", flow.flow_id, flow.name)
             TestBase.logger.info(
                 "collected from {}: {}".format(__file__.split("/")[-1], flow.flow_id)
             )
diff --git a/tests/test_runs/test_run_functions.py b/tests/test_runs/test_run_functions.py
index 1f8d1df70..fa4fdcf92 100644
--- a/tests/test_runs/test_run_functions.py
+++ b/tests/test_runs/test_run_functions.py
@@ -262,7 +262,7 @@ def _remove_random_state(flow):
         flow, _ = self._add_sentinel_to_flow_name(flow, sentinel)
         if not openml.flows.flow_exists(flow.name, flow.external_version):
             flow.publish()
-            TestBase._mark_entity_for_removal("flow", (flow.flow_id, flow.name))
+            TestBase._mark_entity_for_removal("flow", flow.flow_id, flow.name)
             TestBase.logger.info("collected from test_run_functions: {}".format(flow.flow_id))
 
         task = openml.tasks.get_task(task_id)
@@ -1221,7 +1221,7 @@ def test_run_with_illegal_flow_id_1(self):
         flow_orig = self.extension.model_to_flow(clf)
         try:
             flow_orig.publish()  # ensures flow exist on server
-            TestBase._mark_entity_for_removal("flow", (flow_orig.flow_id, flow_orig.name))
+            TestBase._mark_entity_for_removal("flow", flow_orig.flow_id, flow_orig.name)
             TestBase.logger.info("collected from test_run_functions: {}".format(flow_orig.flow_id))
         except openml.exceptions.OpenMLServerException:
             # flow already exists
@@ -1246,7 +1246,7 @@ def test_run_with_illegal_flow_id_1_after_load(self):
         flow_orig = self.extension.model_to_flow(clf)
         try:
             flow_orig.publish()  # ensures flow exist on server
-            TestBase._mark_entity_for_removal("flow", (flow_orig.flow_id, flow_orig.name))
+            TestBase._mark_entity_for_removal("flow", flow_orig.flow_id, flow_orig.name)
             TestBase.logger.info("collected from test_run_functions: {}".format(flow_orig.flow_id))
         except openml.exceptions.OpenMLServerException:
             # flow already exists
@@ -1582,7 +1582,7 @@ def test_run_flow_on_task_downloaded_flow(self):
         model = sklearn.ensemble.RandomForestClassifier(n_estimators=33)
         flow = self.extension.model_to_flow(model)
         flow.publish(raise_error_if_exists=False)
-        TestBase._mark_entity_for_removal("flow", (flow.flow_id, flow.name))
+        TestBase._mark_entity_for_removal("flow", flow.flow_id, flow.name)
         TestBase.logger.info("collected from test_run_functions: {}".format(flow.flow_id))
 
         downloaded_flow = openml.flows.get_flow(flow.flow_id)
diff --git a/tests/test_setups/test_setup_functions.py b/tests/test_setups/test_setup_functions.py
index 33b2a5551..195eef605 100644
--- a/tests/test_setups/test_setup_functions.py
+++ b/tests/test_setups/test_setup_functions.py
@@ -44,7 +44,7 @@ def test_nonexisting_setup_exists(self):
         flow = self.extension.model_to_flow(dectree)
         flow.name = "TEST%s%s" % (sentinel, flow.name)
         flow.publish()
-        TestBase._mark_entity_for_removal("flow", (flow.flow_id, flow.name))
+        TestBase._mark_entity_for_removal("flow", flow.flow_id, flow.name)
         TestBase.logger.info("collected from {}: {}".format(__file__.split("/")[-1], flow.flow_id))
 
         # although the flow exists (created as of previous statement),
@@ -57,7 +57,7 @@ def _existing_setup_exists(self, classif):
         flow = self.extension.model_to_flow(classif)
         flow.name = "TEST%s%s" % (get_sentinel(), flow.name)
         flow.publish()
-        TestBase._mark_entity_for_removal("flow", (flow.flow_id, flow.name))
+        TestBase._mark_entity_for_removal("flow", flow.flow_id, flow.name)
         TestBase.logger.info("collected from {}: {}".format(__file__.split("/")[-1], flow.flow_id))
 
         # although the flow exists, we can be sure there are no