Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delete the downloaded zip and folder in retrieve_dataset #2150

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions rai_test_utils/rai_test_utils/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@

"""Namespace for utility functions used in tests."""

from .utils import is_valid_uuid, retrieve_dataset
from .utils import DOWNLOADED_DATASET_DIR, is_valid_uuid, retrieve_dataset

__all__ = [
"is_valid_uuid",
"retrieve_dataset"
"retrieve_dataset",
"DOWNLOADED_DATASET_DIR"
]
19 changes: 13 additions & 6 deletions rai_test_utils/rai_test_utils/utilities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
# Licensed under the MIT License.

import os
import shutil
import uuid

DOWNLOADED_DATASET_DIR = 'datasets.4.27.2021'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems oddly specific. Isn't it going to change?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we use a specific zip folder with the date it was created for these datasets (that are put on our blob storage for tests)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, so if the dataset changes this need to be updated. How will we remember?



def is_valid_uuid(id: str):
"""Check if the given id is a valid uuid.
Expand All @@ -29,7 +32,7 @@ def retrieve_dataset(dataset, **kwargs):
:rtype: object
"""
# if data not extracted, download zip and extract
outdirname = 'datasets.4.27.2021'
outdirname = DOWNLOADED_DATASET_DIR
if not os.path.exists(outdirname):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if the dataset is downloaded we don't re-download it here... we just re-use the downloaded one

try:
from urllib import urlretrieve
Expand All @@ -48,17 +51,21 @@ def retrieve_dataset(dataset, **kwargs):
if extension == '.npz':
# sparse format file
from scipy.sparse import load_npz
return load_npz(filepath)
in_memory_dataset = load_npz(filepath)
elif extension == '.svmlight':
from sklearn import datasets
return datasets.load_svmlight_file(filepath)
in_memory_dataset = datasets.load_svmlight_file(filepath)
elif extension == '.json':
import json
with open(filepath, encoding='utf-8') as f:
dataset = json.load(f)
return dataset
in_memory_dataset = json.load(f)
elif extension == '.csv':
import pandas as pd
return pd.read_csv(filepath, **kwargs)
in_memory_dataset = pd.read_csv(filepath, **kwargs)
else:
raise Exception('Unrecognized file extension: ' + extension)

shutil.rmtree(outdirname)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

won't this increase network calls and also chance tests fail due to networking issues? also I think this might increase test time a lot? we currently just re-use the downloaded file in all test cases instead of re-downloading it every time

os.remove(zipfilename)

return in_memory_dataset
11 changes: 10 additions & 1 deletion rai_test_utils/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# Copyright (c) Microsoft Corporation
# Licensed under the MIT License.

from rai_test_utils.utilities import is_valid_uuid
import os

from rai_test_utils.utilities import (DOWNLOADED_DATASET_DIR, is_valid_uuid,
retrieve_dataset)


class TestUtils:
Expand All @@ -11,3 +14,9 @@ def test_is_valid_uuid(self):
assert not is_valid_uuid("123e4567-e89b-12d3-a456-42661417400g")
assert not is_valid_uuid("123e4567-e89b-12d3-a456-42661417400-")
assert not is_valid_uuid("123e4567-e89b-12d3-a456-42661417400-143")

def test_retrieve_dataset(self):
energy_data = retrieve_dataset('energyefficiency2012_data.train.csv')
assert energy_data is not None
assert not os.path.exists(DOWNLOADED_DATASET_DIR)
assert not os.path.exists(DOWNLOADED_DATASET_DIR + '.zip')