Skip to content

Commit

Permalink
Feat/progress (#1335)
Browse files Browse the repository at this point in the history
* Add progress bar to downloading minio files

* Do not redownload cached files

There is now a way to force a cache clear, so always redownloading
is not useful anymore.

* Set typed values on dictionary to avoid TypeError from Config

* Add regression test for parsing booleans
  • Loading branch information
PGijsbers authored Sep 16, 2024
1 parent b4d038f commit 1d707e6
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 15 deletions.
3 changes: 3 additions & 0 deletions doc/progress.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ Changelog
next
~~~~~~

* ADD #1335: Improve MinIO support.
* Add progress bar for downloading MinIO files. Enable it with setting `show_progress` to true on either `openml.config` or the configuration file.
* When using `download_all_files`, files are only downloaded if they do not yet exist in the cache.
* MAINT #1340: Add Numpy 2.0 support. Update tests to work with scikit-learn <= 1.5.
* ADD #1342: Add HTTP header to requests to indicate they are from openml-python.

Expand Down
9 changes: 9 additions & 0 deletions examples/20_basic/simple_datasets_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@
X, y, categorical_indicator, attribute_names = dataset.get_data(
dataset_format="dataframe", target=dataset.default_target_attribute
)

############################################################################
# Tip: you can get a progress bar for dataset downloads, simply set it in
# the configuration. Either in code or in the configuration file
# (see also the introduction tutorial)

openml.config.show_progress = True


############################################################################
# Visualize the dataset
# =====================
Expand Down
15 changes: 9 additions & 6 deletions openml/_api_calls.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# License: BSD 3-Clause
from __future__ import annotations

import contextlib
import hashlib
import logging
import math
Expand All @@ -26,6 +27,7 @@
OpenMLServerException,
OpenMLServerNoResult,
)
from .utils import ProgressBar

_HEADERS = {"user-agent": f"openml-python/{__version__}"}

Expand Down Expand Up @@ -161,12 +163,12 @@ def _download_minio_file(
proxy_client = ProxyManager(proxy) if proxy else None

client = minio.Minio(endpoint=parsed_url.netloc, secure=False, http_client=proxy_client)

try:
client.fget_object(
bucket_name=bucket,
object_name=object_name,
file_path=str(destination),
progress=ProgressBar() if config.show_progress else None,
request_headers=_HEADERS,
)
if destination.is_file() and destination.suffix == ".zip":
Expand Down Expand Up @@ -206,11 +208,12 @@ def _download_minio_bucket(source: str, destination: str | Path) -> None:
if file_object.object_name is None:
raise ValueError("Object name is None.")

_download_minio_file(
source=source.rsplit("/", 1)[0] + "/" + file_object.object_name.rsplit("/", 1)[1],
destination=Path(destination, file_object.object_name.rsplit("/", 1)[1]),
exists_ok=True,
)
with contextlib.suppress(FileExistsError): # Simply use cached version instead
_download_minio_file(
source=source.rsplit("/", 1)[0] + "/" + file_object.object_name.rsplit("/", 1)[1],
destination=Path(destination, file_object.object_name.rsplit("/", 1)[1]),
exists_ok=False,
)


def _download_text_file(
Expand Down
16 changes: 11 additions & 5 deletions openml/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class _Config(TypedDict):
avoid_duplicate_runs: bool
retry_policy: Literal["human", "robot"]
connection_n_retries: int
show_progress: bool


def _create_log_handlers(create_file_handler: bool = True) -> None: # noqa: FBT001, FBT002
Expand Down Expand Up @@ -111,6 +112,7 @@ def set_file_log_level(file_output_level: int) -> None:
"avoid_duplicate_runs": True,
"retry_policy": "human",
"connection_n_retries": 5,
"show_progress": False,
}

# Default values are actually added here in the _setup() function which is
Expand All @@ -131,6 +133,7 @@ def get_server_base_url() -> str:


apikey: str = _defaults["apikey"]
show_progress: bool = _defaults["show_progress"]
# The current cache directory (without the server name)
_root_cache_directory = Path(_defaults["cachedir"])
avoid_duplicate_runs = _defaults["avoid_duplicate_runs"]
Expand Down Expand Up @@ -238,6 +241,7 @@ def _setup(config: _Config | None = None) -> None:
global server # noqa: PLW0603
global _root_cache_directory # noqa: PLW0603
global avoid_duplicate_runs # noqa: PLW0603
global show_progress # noqa: PLW0603

config_file = determine_config_file_path()
config_dir = config_file.parent
Expand All @@ -255,6 +259,7 @@ def _setup(config: _Config | None = None) -> None:
avoid_duplicate_runs = config["avoid_duplicate_runs"]
apikey = config["apikey"]
server = config["server"]
show_progress = config["show_progress"]
short_cache_dir = Path(config["cachedir"])
n_retries = int(config["connection_n_retries"])

Expand Down Expand Up @@ -328,11 +333,11 @@ def _parse_config(config_file: str | Path) -> _Config:
logger.info("Error opening file %s: %s", config_file, e.args[0])
config_file_.seek(0)
config.read_file(config_file_)
if isinstance(config["FAKE_SECTION"]["avoid_duplicate_runs"], str):
config["FAKE_SECTION"]["avoid_duplicate_runs"] = config["FAKE_SECTION"].getboolean(
"avoid_duplicate_runs"
) # type: ignore
return dict(config.items("FAKE_SECTION")) # type: ignore
configuration = dict(config.items("FAKE_SECTION"))
for boolean_field in ["avoid_duplicate_runs", "show_progress"]:
if isinstance(config["FAKE_SECTION"][boolean_field], str):
configuration[boolean_field] = config["FAKE_SECTION"].getboolean(boolean_field) # type: ignore
return configuration # type: ignore


def get_config_as_dict() -> _Config:
Expand All @@ -343,6 +348,7 @@ def get_config_as_dict() -> _Config:
"avoid_duplicate_runs": avoid_duplicate_runs,
"connection_n_retries": connection_n_retries,
"retry_policy": retry_policy,
"show_progress": show_progress,
}


Expand Down
7 changes: 3 additions & 4 deletions openml/datasets/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,10 +1262,9 @@ def _get_dataset_parquet(
if old_file_path.is_file():
old_file_path.rename(output_file_path)

# For this release, we want to be able to force a new download even if the
# parquet file is already present when ``download_all_files`` is set.
# For now, it would be the only way for the user to fetch the additional
# files in the bucket (no function exists on an OpenMLDataset to do this).
# The call below skips files already on disk, so avoids downloading the parquet file twice.
# To force the old behavior of always downloading everything, use `force_refresh_cache`
# of `get_dataset`
if download_all_files:
openml._api_calls._download_minio_bucket(source=url, destination=cache_directory)

Expand Down
38 changes: 38 additions & 0 deletions openml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import numpy as np
import pandas as pd
import xmltodict
from minio.helpers import ProgressType
from tqdm import tqdm

import openml
import openml._api_calls
Expand Down Expand Up @@ -471,3 +473,39 @@ def _create_lockfiles_dir() -> Path:
with contextlib.suppress(OSError):
path.mkdir(exist_ok=True, parents=True)
return path


class ProgressBar(ProgressType):
"""Progressbar for MinIO function's `progress` parameter."""

def __init__(self) -> None:
self._object_name = ""
self._progress_bar: tqdm | None = None

def set_meta(self, object_name: str, total_length: int) -> None:
"""Initializes the progress bar.
Parameters
----------
object_name: str
Not used.
total_length: int
File size of the object in bytes.
"""
self._object_name = object_name
self._progress_bar = tqdm(total=total_length, unit_scale=True, unit="B")

def update(self, length: int) -> None:
"""Updates the progress bar.
Parameters
----------
length: int
Number of bytes downloaded since last `update` call.
"""
if not self._progress_bar:
raise RuntimeError("Call `set_meta` before calling `update`.")
self._progress_bar.update(length)
if self._progress_bar.total <= self._progress_bar.n:
self._progress_bar.close()
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies = [
"numpy>=1.6.2",
"minio",
"pyarrow",
"tqdm", # For MinIO download progress bars
"packaging",
]
requires-python = ">=3.8"
Expand Down
41 changes: 41 additions & 0 deletions tests/test_openml/test_api_calls.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from __future__ import annotations

import unittest.mock
from pathlib import Path
from typing import NamedTuple, Iterable, Iterator
from unittest import mock

import minio
import pytest

import openml
import openml.testing
from openml._api_calls import _download_minio_bucket


class TestConfig(openml.testing.TestBase):
Expand All @@ -30,3 +35,39 @@ def test_retry_on_database_error(self, Session_class_mock, _):
openml._api_calls._send_request("get", "/abc", {})

assert Session_class_mock.return_value.__enter__.return_value.get.call_count == 20

class FakeObject(NamedTuple):
object_name: str

class FakeMinio:
def __init__(self, objects: Iterable[FakeObject] | None = None):
self._objects = objects or []

def list_objects(self, *args, **kwargs) -> Iterator[FakeObject]:
yield from self._objects

def fget_object(self, object_name: str, file_path: str, *args, **kwargs) -> None:
if object_name in [obj.object_name for obj in self._objects]:
Path(file_path).write_text("foo")
return
raise FileNotFoundError


@mock.patch.object(minio, "Minio")
def test_download_all_files_observes_cache(mock_minio, tmp_path: Path) -> None:
some_prefix, some_filename = "some/prefix", "dataset.arff"
some_object_path = f"{some_prefix}/{some_filename}"
some_url = f"https://not.real.com/bucket/{some_object_path}"
mock_minio.return_value = FakeMinio(
objects=[
FakeObject(some_object_path),
],
)

_download_minio_bucket(source=some_url, destination=tmp_path)
time_created = (tmp_path / "dataset.arff").stat().st_ctime

_download_minio_bucket(source=some_url, destination=tmp_path)
time_modified = (tmp_path / some_filename).stat().st_mtime

assert time_created == time_modified
10 changes: 10 additions & 0 deletions tests/test_openml/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,13 @@ def test_configuration_file_not_overwritten_on_load():

assert config_file_content == new_file_content
assert "abcd" == read_config["apikey"]

def test_configuration_loads_booleans(tmp_path):
config_file_content = "avoid_duplicate_runs=true\nshow_progress=false"
with (tmp_path/"config").open("w") as config_file:
config_file.write(config_file_content)
read_config = openml.config._parse_config(tmp_path)

# Explicit test to avoid truthy/falsy modes of other types
assert True == read_config["avoid_duplicate_runs"]
assert False == read_config["show_progress"]

0 comments on commit 1d707e6

Please sign in to comment.