Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add datasets module #minor #117

Merged
merged 1 commit into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 3 additions & 17 deletions docs/examples/plot_lcppn_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,11 @@
"""
from sklearn.ensemble import RandomForestClassifier
from hiclass import LocalClassifierPerParentNode, Explainer
import requests
import pandas as pd
import shap
from hiclass.datasets import load_platypus

# Download training data
url = "https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/1152c0b9613c2bda144a38fc4f74b5fe12255f4d/platypus_diseases.csv"
path = "platypus_diseases.csv"
response = requests.get(url)
with open(path, "wb") as file:
file.write(response.content)

# Load training data into pandas dataframe
training_data = pd.read_csv(path).fillna(" ")

# Define data
X_train = training_data.drop(["label"], axis=1)
X_test = X_train[:100] # Use first 100 samples as test set
Y_train = training_data["label"]
Y_train = [eval(my) for my in Y_train]
# Load train and test splits
X_train, X_test, Y_train, Y_test = load_platypus()

# Use random forest classifiers for every node
rfc = RandomForestClassifier()
Expand Down
20 changes: 3 additions & 17 deletions docs/examples/plot_parallel_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,15 @@
"""
import sys
from os import cpu_count

import pandas as pd
import requests
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

from hiclass import LocalClassifierPerParentNode
from hiclass.datasets import load_hierarchical_text_classification


# Download training data
url = "https://zenodo.org/record/6657410/files/train_40k.csv?download=1"
path = "train_40k.csv"
response = requests.get(url)
with open(path, "wb") as file:
file.write(response.content)

# Load training data into pandas dataframe
training_data = pd.read_csv(path).fillna(" ")
# Load train and test splits
X_train, X_test, Y_train, Y_test = load_hierarchical_text_classification()

# We will use logistic regression classifiers for every parent node
lr = LogisticRegression(max_iter=1000)
Expand All @@ -51,10 +41,6 @@
]
)

# Select training data
X_train = training_data["Title"]
Y_train = training_data[["Cat1", "Cat2", "Cat3"]]

# Fixes bug AttributeError: '_LoggingTee' object has no attribute 'fileno'
# This only happens when building the documentation
# Hence, you don't actually need it for your code to work
Expand Down
20 changes: 20 additions & 0 deletions docs/source/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,23 @@ F-score
^^^^^^^

.. autofunction:: metrics.f1

..................................


Datasets
----------

Platypus diseases dataset
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: datasets.load_platypus

..................................

Hierarchical text classification dataset
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: datasets.load_hierarchical_text_classification

..................................
1 change: 1 addition & 0 deletions hiclass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@
"Explainer",
"MultiLabelLocalClassifierPerNode",
"MultiLabelLocalClassifierPerParentNode",
"datasets",
]
138 changes: 138 additions & 0 deletions hiclass/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""Datasets util for downloading and maintaining sample datasets."""

import requests
import pandas as pd
import os
import tempfile
import logging
from sklearn.model_selection import train_test_split

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use temp directory to store cached datasets
CACHE_DIR = tempfile.gettempdir()

# Ensure cache directory exists
os.makedirs(CACHE_DIR, exist_ok=True)

# Dataset urls
PLATYPUS_URL = "https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/1152c0b9613c2bda144a38fc4f74b5fe12255f4d/platypus_diseases.csv"
HIERARCHICAL_TEXT_CLASSIFICATION_URL = (
"https://zenodo.org/record/6657410/files/train_40k.csv?download=1"
)


def _download_file(url, destination):
"""Download file from given URL to specified destination."""
try:
response = requests.get(url)
# Raise HTTPError if response code is not OK
response.raise_for_status()
with open(destination, "wb") as f:
f.write(response.content)
except requests.RequestException as e:
raise RuntimeError(f"Failed to download file from {url}: {str(e)}")


def load_platypus(test_size=0.3, random_state=42):
"""
Load platypus diseases dataset.

Parameters
----------
test_size : float, default=0.3
The proportion of the dataset to include in the test split.
random_state : int or None, default=42
Controls the randomness of the dataset. Pass an int for reproducible output across multiple function calls.

Returns
-------
list
List containing train-test split of inputs.

Raises
------
RuntimeError
If failed to access or process the dataset.
Examples
--------
>>> from hiclass.datasets import load_platypus
>>> X_train, X_test, Y_train, Y_test = load_platypus()
>>> X_train[:3]
fever diarrhea stomach pain skin rash cough sniffles short breath headache size
220 37.8 0 3 5 1 1 0 2 27.6
539 37.2 0 6 1 1 1 0 3 28.4
326 39.9 0 2 5 1 1 1 2 30.7
>>> X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
(572, 9) (246, 9) (572,) (246,)
"""
dataset_name = "platypus_diseases.csv"
cached_file_path = os.path.join(CACHE_DIR, dataset_name)

# Check if the file exists in the cache
if not os.path.exists(cached_file_path):
try:
logger.info("Downloading platypus diseases dataset..")
_download_file(PLATYPUS_URL, cached_file_path)
except Exception as e:
raise RuntimeError(f"Failed to access or download dataset: {str(e)}")

data = pd.read_csv(cached_file_path).fillna(" ")
X = data.drop(["label"], axis=1)
y = pd.Series([eval(val) for val in data["label"]])

# Return tuple (X_train, X_test, y_train, y_test)
return train_test_split(X, y, test_size=test_size, random_state=random_state)


def load_hierarchical_text_classification(test_size=0.3, random_state=42):
"""
Load hierarchical text classification dataset.

Parameters
----------
test_size : float, default=0.3
The proportion of the dataset to include in the test split.
random_state : int or None, default=42
Controls the randomness of the dataset. Pass an int for reproducible output across multiple function calls.

Returns
-------
list
List containing train-test split of inputs.

Raises
------
RuntimeError
If failed to access or process the dataset.
Examples
--------
>>> from hiclass.datasets import load_hierarchical_text_classification
>>> X_train, X_test, Y_train, Y_test = load_hierarchical_text_classification()
>>> X_train[:3]
38015 Nature's Way Selenium
2281 Music In Motion Developmental Mobile W Remote
36629 Twinings Ceylon Orange Pekoe Tea, Tea Bags, 20...
Name: Title, dtype: object
>>> X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
(28000,) (12000,) (28000, 3) (12000, 3)
"""
dataset_name = "hierarchical_text_classification.csv"
cached_file_path = os.path.join(CACHE_DIR, dataset_name)

# Check if the file exists in the cache
if not os.path.exists(cached_file_path):
try:
logger.info("Downloading hierarchical text classification dataset..")
_download_file(HIERARCHICAL_TEXT_CLASSIFICATION_URL, cached_file_path)
except Exception as e:
raise RuntimeError(f"Failed to access or download dataset: {str(e)}")

data = pd.read_csv(cached_file_path).fillna(" ")
X = data["Title"]
y = data[["Cat1", "Cat2", "Cat3"]]

# Return tuple (X_train, X_test, y_train, y_test)
return train_test_split(X, y, test_size=test_size, random_state=random_state)
121 changes: 121 additions & 0 deletions tests/test_Datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import numpy as np
import pytest

import hiclass.datasets
from hiclass.datasets import load_platypus, load_hierarchical_text_classification
import os
import tempfile


def test_load_platypus_output_shape():
X_train, X_test, y_train, y_test = load_platypus(test_size=0.2, random_state=42)
assert X_train.shape[0] == y_train.shape[0]
assert X_test.shape[0] == y_test.shape[0]


def test_load_platypus_random_state():
X_train_1, X_test_1, y_train_1, y_test_1 = load_platypus(
test_size=0.2, random_state=42
)
X_train_2, X_test_2, y_train_2, y_test_2 = load_platypus(
test_size=0.2, random_state=42
)
assert (X_train_1.values == X_train_2.values).all()
assert (X_test_1.values == X_test_2.values).all()
assert (y_train_1.index == y_train_2.index).all()
assert (y_test_1.index == y_test_2.index).all()


def test_load_hierarchical_text_classification_shape():
X_train, X_test, y_train, y_test = load_hierarchical_text_classification(
test_size=0.2, random_state=42
)
assert X_train.shape[0] == y_train.shape[0]
assert X_test.shape[0] == y_test.shape[0]


def test_load_hierarchical_text_classification_random_state():
X_train_1, X_test_1, y_train_1, y_test_1 = load_hierarchical_text_classification(
test_size=0.2, random_state=42
)
X_train_2, X_test_2, y_train_2, y_test_2 = load_hierarchical_text_classification(
test_size=0.2, random_state=42
)
assert (X_train_1 == X_train_2).all()
assert (X_test_1 == X_test_2).all()
assert (y_train_1.index == y_train_2.index).all()
assert (y_test_1.index == y_test_2.index).all()


def test_load_hierarchical_text_classification_file_exists():
dataset_name = "hierarchical_text_classification.csv"
cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name)

if os.path.exists(cached_file_path):
os.remove(cached_file_path)

if not os.path.exists(cached_file_path):
load_hierarchical_text_classification()
assert os.path.exists(cached_file_path)


def test_load_platypus_file_exists():
dataset_name = "platypus_diseases.csv"
cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name)

if os.path.exists(cached_file_path):
os.remove(cached_file_path)

if not os.path.exists(cached_file_path):
load_platypus()
assert os.path.exists(cached_file_path)


def test_download_dataset():
dataset_name = "platypus_diseases_test.csv"
url = hiclass.datasets.PLATYPUS_URL
cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name)

if os.path.exists(cached_file_path):
os.remove(cached_file_path)

if not os.path.exists(cached_file_path):
hiclass.datasets._download_file(url, cached_file_path)
assert os.path.exists(cached_file_path)


def test_download_error_load_platypus():
dataset_name = "platypus_diseases.csv"
backup_url = hiclass.datasets.PLATYPUS_URL
hiclass.datasets.PLATYPUS_URL = ""
cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name)

if os.path.exists(cached_file_path):
os.remove(cached_file_path)

if not os.path.exists(cached_file_path):
with pytest.raises(RuntimeError):
load_platypus()

hiclass.datasets.PLATYPUS_URL = backup_url


def test_download_error_load_hierarchical_text():
dataset_name = "hierarchical_text_classification.csv"
backup_url = hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL
hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL = ""
cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name)

if os.path.exists(cached_file_path):
os.remove(cached_file_path)

if not os.path.exists(cached_file_path):
with pytest.raises(RuntimeError):
load_hierarchical_text_classification()

hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL = backup_url


def test_url_links():
assert hiclass.datasets.PLATYPUS_URL != ""
assert hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL != ""
Loading