diff --git a/rai_test_utils/rai_test_utils/utilities/__init__.py b/rai_test_utils/rai_test_utils/utilities/__init__.py index e6321bad30..d259212e72 100644 --- a/rai_test_utils/rai_test_utils/utilities/__init__.py +++ b/rai_test_utils/rai_test_utils/utilities/__init__.py @@ -3,9 +3,10 @@ """Namespace for utility functions used in tests.""" -from .utils import is_valid_uuid, retrieve_dataset +from .utils import DOWNLOADED_DATASET_DIR, is_valid_uuid, retrieve_dataset __all__ = [ "is_valid_uuid", - "retrieve_dataset" + "retrieve_dataset", + "DOWNLOADED_DATASET_DIR" ] diff --git a/rai_test_utils/rai_test_utils/utilities/utils.py b/rai_test_utils/rai_test_utils/utilities/utils.py index 4c2081150d..c3be9f4d35 100644 --- a/rai_test_utils/rai_test_utils/utilities/utils.py +++ b/rai_test_utils/rai_test_utils/utilities/utils.py @@ -2,8 +2,11 @@ # Licensed under the MIT License. import os +import shutil import uuid +DOWNLOADED_DATASET_DIR = 'datasets.4.27.2021' + def is_valid_uuid(id: str): """Check if the given id is a valid uuid. @@ -29,7 +32,7 @@ def retrieve_dataset(dataset, **kwargs): :rtype: object """ # if data not extracted, download zip and extract - outdirname = 'datasets.4.27.2021' + outdirname = DOWNLOADED_DATASET_DIR if not os.path.exists(outdirname): try: from urllib import urlretrieve @@ -48,17 +51,21 @@ def retrieve_dataset(dataset, **kwargs): if extension == '.npz': # sparse format file from scipy.sparse import load_npz - return load_npz(filepath) + in_memory_dataset = load_npz(filepath) elif extension == '.svmlight': from sklearn import datasets - return datasets.load_svmlight_file(filepath) + in_memory_dataset = datasets.load_svmlight_file(filepath) elif extension == '.json': import json with open(filepath, encoding='utf-8') as f: - dataset = json.load(f) - return dataset + in_memory_dataset = json.load(f) elif extension == '.csv': import pandas as pd - return pd.read_csv(filepath, **kwargs) + in_memory_dataset = pd.read_csv(filepath, **kwargs) else: raise Exception('Unrecognized file extension: ' + extension) + + shutil.rmtree(outdirname) + os.remove(zipfilename) + + return in_memory_dataset diff --git a/rai_test_utils/tests/test_utils.py b/rai_test_utils/tests/test_utils.py index 7eb450cd3e..c0555fde98 100644 --- a/rai_test_utils/tests/test_utils.py +++ b/rai_test_utils/tests/test_utils.py @@ -1,7 +1,10 @@ # Copyright (c) Microsoft Corporation # Licensed under the MIT License. -from rai_test_utils.utilities import is_valid_uuid +import os + +from rai_test_utils.utilities import (DOWNLOADED_DATASET_DIR, is_valid_uuid, + retrieve_dataset) class TestUtils: @@ -11,3 +14,9 @@ def test_is_valid_uuid(self): assert not is_valid_uuid("123e4567-e89b-12d3-a456-42661417400g") assert not is_valid_uuid("123e4567-e89b-12d3-a456-42661417400-") assert not is_valid_uuid("123e4567-e89b-12d3-a456-42661417400-143") + + def test_retrieve_dataset(self): + energy_data = retrieve_dataset('energyefficiency2012_data.train.csv') + assert energy_data is not None + assert not os.path.exists(DOWNLOADED_DATASET_DIR) + assert not os.path.exists(DOWNLOADED_DATASET_DIR + '.zip')