From cc28b1dd2c47045702c40853d374e7e0c09928bb Mon Sep 17 00:00:00 2001
From: Pieter Gijsbers <p.gijsbers@tue.nl>
Date: Sat, 25 Jan 2025 11:38:49 +0100
Subject: [PATCH] Hotfix/arff (#1388)

* Allow skipping parquet download through environment variable

* Allow skip of parquet file, fix bug if no pq file is returned

* Declare the environment file in config.py
---
 openml/config.py             |  1 +
 openml/datasets/dataset.py   |  8 ++++++--
 openml/datasets/functions.py | 12 ++++++++----
 3 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/openml/config.py b/openml/config.py
index a244a317e..d838b070a 100644
--- a/openml/config.py
+++ b/openml/config.py
@@ -23,6 +23,7 @@
 file_handler: logging.handlers.RotatingFileHandler | None = None
 
 OPENML_CACHE_DIR_ENV_VAR = "OPENML_CACHE_DIR"
+OPENML_SKIP_PARQUET_ENV_VAR = "OPENML_SKIP_PARQUET"
 
 
 class _Config(TypedDict):
diff --git a/openml/datasets/dataset.py b/openml/datasets/dataset.py
index b00c458e3..5190ac522 100644
--- a/openml/datasets/dataset.py
+++ b/openml/datasets/dataset.py
@@ -3,6 +3,7 @@
 
 import gzip
 import logging
+import os
 import pickle
 import re
 import warnings
@@ -17,6 +18,7 @@
 import xmltodict
 
 from openml.base import OpenMLBase
+from openml.config import OPENML_SKIP_PARQUET_ENV_VAR
 from openml.exceptions import PyOpenMLError
 
 from .data_feature import OpenMLDataFeature
@@ -358,8 +360,10 @@ def _download_data(self) -> None:
         # import required here to avoid circular import.
         from .functions import _get_dataset_arff, _get_dataset_parquet
 
-        if self._parquet_url is not None:
-            self.parquet_file = str(_get_dataset_parquet(self))
+        skip_parquet = os.environ.get(OPENML_SKIP_PARQUET_ENV_VAR, "false").casefold() == "true"
+        if self._parquet_url is not None and not skip_parquet:
+            parquet_file = _get_dataset_parquet(self)
+            self.parquet_file = None if parquet_file is None else str(parquet_file)
 
         if self.parquet_file is None:
             self.data_file = str(_get_dataset_arff(self))
diff --git a/openml/datasets/functions.py b/openml/datasets/functions.py
index 61577d9a2..3f3c709f9 100644
--- a/openml/datasets/functions.py
+++ b/openml/datasets/functions.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import logging
+import os
 import warnings
 from collections import OrderedDict
 from pathlib import Path
@@ -20,6 +21,7 @@
 
 import openml._api_calls
 import openml.utils
+from openml.config import OPENML_SKIP_PARQUET_ENV_VAR
 from openml.exceptions import (
     OpenMLHashException,
     OpenMLPrivateDatasetError,
@@ -560,7 +562,10 @@ def get_dataset(  # noqa: C901, PLR0912
         if download_qualities:
             qualities_file = _get_dataset_qualities_file(did_cache_dir, dataset_id)
 
-        if "oml:parquet_url" in description and download_data:
+        parquet_file = None
+        skip_parquet = os.environ.get(OPENML_SKIP_PARQUET_ENV_VAR, "false").casefold() == "true"
+        download_parquet = "oml:parquet_url" in description and not skip_parquet
+        if download_parquet and (download_data or download_all_files):
             try:
                 parquet_file = _get_dataset_parquet(
                     description,
@@ -568,12 +573,11 @@ def get_dataset(  # noqa: C901, PLR0912
                 )
             except urllib3.exceptions.MaxRetryError:
                 parquet_file = None
-            else:
-                parquet_file = None
 
         arff_file = None
         if parquet_file is None and download_data:
-            logger.warning("Failed to download parquet, fallback on ARFF.")
+            if download_parquet:
+                logger.warning("Failed to download parquet, fallback on ARFF.")
             arff_file = _get_dataset_arff(description)
 
         remove_dataset_cache = False
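
Usage sketch (not part of the patch itself): with this change, setting OPENML_SKIP_PARQUET=true before a download should make openml-python skip the parquet file and fall back to ARFF. The snippet below only assumes the public get_dataset API touched by the diff; dataset id 61 is purely an illustrative example.

    import os

    import openml

    # Any casing of "true" enables the skip: the patch compares with str.casefold().
    os.environ["OPENML_SKIP_PARQUET"] = "true"

    # Parquet is skipped, so get_dataset() falls back to the ARFF file.
    dataset = openml.datasets.get_dataset(61, download_data=True)
    print(dataset.parquet_file)  # None, because the parquet download was skipped
    print(dataset.data_file)     # path to the cached ARFF file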