Skip to content

Commit

Permalink
Hotfix/arff (#1388)
Browse files Browse the repository at this point in the history
* Allow skipping parquet download through environment variable

* Allow skipping the parquet file; fix a bug when no parquet file is returned

* Declare the environment variable in config.py
  • Loading branch information
PGijsbers authored Jan 25, 2025
1 parent a4fb848 commit cc28b1d
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
1 change: 1 addition & 0 deletions openml/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
file_handler: logging.handlers.RotatingFileHandler | None = None

OPENML_CACHE_DIR_ENV_VAR = "OPENML_CACHE_DIR"
OPENML_SKIP_PARQUET_ENV_VAR = "OPENML_SKIP_PARQUET"


class _Config(TypedDict):
Expand Down
8 changes: 6 additions & 2 deletions openml/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import gzip
import logging
import os
import pickle
import re
import warnings
Expand All @@ -17,6 +18,7 @@
import xmltodict

from openml.base import OpenMLBase
from openml.config import OPENML_SKIP_PARQUET_ENV_VAR
from openml.exceptions import PyOpenMLError

from .data_feature import OpenMLDataFeature
Expand Down Expand Up @@ -358,8 +360,10 @@ def _download_data(self) -> None:
# import required here to avoid circular import.
from .functions import _get_dataset_arff, _get_dataset_parquet

if self._parquet_url is not None:
self.parquet_file = str(_get_dataset_parquet(self))
skip_parquet = os.environ.get(OPENML_SKIP_PARQUET_ENV_VAR, "false").casefold() == "true"
if self._parquet_url is not None and not skip_parquet:
parquet_file = _get_dataset_parquet(self)
self.parquet_file = None if parquet_file is None else str(parquet_file)
if self.parquet_file is None:
self.data_file = str(_get_dataset_arff(self))

Expand Down
12 changes: 8 additions & 4 deletions openml/datasets/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import logging
import os
import warnings
from collections import OrderedDict
from pathlib import Path
Expand All @@ -20,6 +21,7 @@

import openml._api_calls
import openml.utils
from openml.config import OPENML_SKIP_PARQUET_ENV_VAR
from openml.exceptions import (
OpenMLHashException,
OpenMLPrivateDatasetError,
Expand Down Expand Up @@ -560,20 +562,22 @@ def get_dataset( # noqa: C901, PLR0912
if download_qualities:
qualities_file = _get_dataset_qualities_file(did_cache_dir, dataset_id)

if "oml:parquet_url" in description and download_data:
parquet_file = None
skip_parquet = os.environ.get(OPENML_SKIP_PARQUET_ENV_VAR, "false").casefold() == "true"
download_parquet = "oml:parquet_url" in description and not skip_parquet
if download_parquet and (download_data or download_all_files):
try:
parquet_file = _get_dataset_parquet(
description,
download_all_files=download_all_files,
)
except urllib3.exceptions.MaxRetryError:
parquet_file = None
else:
parquet_file = None

arff_file = None
if parquet_file is None and download_data:
logger.warning("Failed to download parquet, fallback on ARFF.")
if download_parquet:
logger.warning("Failed to download parquet, fallback on ARFF.")
arff_file = _get_dataset_arff(description)

remove_dataset_cache = False
Expand Down

0 comments on commit cc28b1d

Please sign in to comment.