diff --git a/mltb2/data.py b/mltb2/data.py
index 92c078e..d25b65d 100644
--- a/mltb2/data.py
+++ b/mltb2/data.py
@@ -12,6 +12,7 @@
 
 import os
 from hashlib import sha256
+from io import StringIO
 from typing import Tuple
 
 import joblib
@@ -113,3 +114,48 @@ def load_colon() -> Tuple[pd.Series, pd.DataFrame]:
         result = joblib.load(full_path)
 
     return result
+
+
+def load_prostate_data() -> Tuple[pd.Series, pd.DataFrame]:
+    """Load prostate data.
+
+    The data is loaded and parsed from `prostate data
+    <https://web.stanford.edu/~hastie/CASI_files/DATA/prostmat.csv>`_.
+
+    Returns:
+        Tuple containing labels and data.
+
+    Raises:
+        AssertionError: If the checksum of the downloaded file does not match
+            or a column name contains neither "control" nor "cancer".
+    """
+    # download data file
+    url = "https://web.stanford.edu/~hastie/CASI_files/DATA/prostmat.csv"
+    page = requests.get(url, timeout=10)
+    page_str = page.text
+
+    # check checksum of data file
+    # raise explicitly instead of using an ``assert`` statement, which is stripped under ``python -O``
+    page_hash = sha256(page_str.encode("utf-8")).hexdigest()
+    if page_hash != "f1ccfd3c9a837c002ec5d6489ab139c231739c3611189be14d15ca5541b92036":
+        raise AssertionError(page_hash)
+
+    # transpose: the former column names (used as labels below) become the index
+    data_df = pd.read_csv(StringIO(page_str))
+    data_df = data_df.T
+
+    labels = []
+    for label in data_df.index:
+        if "control" in label:
+            labels.append(0)
+        elif "cancer" in label:
+            labels.append(1)
+        else:
+            # explicit raise so an unknown label is never silently skipped under ``python -O``
+            raise AssertionError("This must not happen!")
+
+    data_df = data_df.reset_index(drop=True)  # reset the index to default integer index
+    label_series = pd.Series(labels)
+    result = (label_series, data_df)
+
+    return result
diff --git a/mltb2/fasttext.py b/mltb2/fasttext.py
index 85e096b..7b5cbbf 100644
--- a/mltb2/fasttext.py
+++ b/mltb2/fasttext.py
@@ -43,7 +43,7 @@ def get_model_path_and_download() -> str:
         url=model_url,
         sha256_checksum=sha256_checksum,
     )
-    assert fetch_remote_file_path == model_full_path  # noqa: S101
+    assert fetch_remote_file_path == model_full_path
 
     return model_full_path
 
diff --git a/pyproject.toml b/pyproject.toml
index 7eba875..7fdf9ee 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -136,7 +136,8 @@ ignore = [
   "S106", # Possible hardcoded password assigned to argument: "{}"
   "COM812", # Trailing comma missing
   "S101", # Use of `assert` detected
-  "PLR2004", # Magic value used in comparison,
+  "PLR2004", # Magic value used in comparison
+  "B011", # Do not `assert False`
 ]
 
 [tool.ruff.per-file-ignores]
diff --git a/tests/test_data.py b/tests/test_data.py
index 45a6acf..c737681 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -4,7 +4,7 @@
 
 import pandas as pd
 
-from mltb2.data import _load_colon_data, _load_colon_label, load_colon
+from mltb2.data import _load_colon_data, _load_colon_label, load_colon, load_prostate_data
 
 
 def test_load_colon_data():
@@ -30,3 +30,14 @@ def test_load_colon():
     assert isinstance(result[1], pd.DataFrame)
     assert result[0].shape == (62,)
     assert result[1].shape == (62, 2000)
+
+
+def test_load_prostate_data():
+    result = load_prostate_data()
+    assert result is not None
+    assert isinstance(result, tuple)
+    assert len(result) == 2
+    assert isinstance(result[0], pd.Series)
+    assert isinstance(result[1], pd.DataFrame)
+    assert result[0].shape == (102,)
+    assert result[1].shape == (102, 6033)