Merge pull request #51 from rkansal47/checksum
feat: download option and verify integrity before loading
rkansal47 authored Oct 26, 2023
2 parents 436091b + 4ac9b5c commit 494a7ce
Showing 7 changed files with 125 additions and 28 deletions.
10 changes: 10 additions & 0 deletions jetnet/datasets/jetnet.py
@@ -49,6 +49,9 @@ class JetNet(JetDataset):
testing data respectively. Defaults to [0.7, 0.15, 0.15].
seed (int, optional): PyTorch manual seed - important to use the same seed for all
dataset splittings. Defaults to 42.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in the ``data_dir`` directory. If dataset is already downloaded, it is not
downloaded again. Defaults to False.
"""

_zenodo_record_ids = {"30": 6975118, "150": 6975117}
@@ -81,6 +84,7 @@ def __init__(
split: str = "train",
split_fraction: List[float] = [0.7, 0.15, 0.15],
seed: int = 42,
download: bool = False,
):
self.particle_data, self.jet_data = self.getData(
jet_type,
@@ -91,6 +95,7 @@
split,
split_fraction,
seed,
download,
)

super().__init__(
@@ -119,6 +124,7 @@ def getData(
split: str = "all",
split_fraction: List[float] = [0.7, 0.15, 0.15],
seed: int = 42,
download: bool = False,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
"""
Downloads, if needed, and loads and returns JetNet data.
@@ -142,6 +148,9 @@ def getData(
testing data respectively. Defaults to [0.7, 0.15, 0.15].
seed (int, optional): PyTorch manual seed - important to use the same seed for all
dataset splittings. Defaults to 42.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in the ``data_dir`` directory. If dataset is already downloaded, it is not
downloaded again. Defaults to False.
Returns:
Tuple[Optional[np.ndarray], Optional[np.ndarray]]: particle data, jet data
@@ -170,6 +179,7 @@ def getData(
dataset_name=dname,
record_id=cls._zenodo_record_ids["150" if use_150 else "30"],
key=f"{dname}.hdf5",
download=download,
)

with h5py.File(hdf5_file, "r") as f:
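For context, a minimal usage sketch of the new flag on the JetNet class (the gluon jet type, ./datasets path, and import path below are illustrative assumptions, not part of this diff):

```python
from jetnet.datasets import JetNet  # assumed import path

# Download the 30-particle gluon dataset if it is not already present,
# verify its MD5 checksum, and load the training split.
particle_data, jet_data = JetNet.getData(
    jet_type="g",           # placeholder jet type
    data_dir="./datasets",  # placeholder directory
    num_particles=30,
    split="train",
    download=True,          # new in this PR; defaults to False
)
print(particle_data.shape, jet_data.shape)
```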
10 changes: 10 additions & 0 deletions jetnet/datasets/qgjets.py
@@ -51,6 +51,9 @@ class QuarkGluon(JetDataset):
dataset splittings. Defaults to 42.
file_list (List[str], optional): list of files to load, if full dataset is not required.
Defaults to None (will load all files).
download (bool, optional): If True, downloads the dataset from the internet and
puts it in the ``data_dir`` directory. If dataset is already downloaded, it is not
downloaded again. Defaults to False.
"""

_zenodo_record_id = 3164691
@@ -127,6 +130,7 @@ def __init__(
split_fraction: List[float] = [0.7, 0.15, 0.15],
seed: int = 42,
file_list: List[str] = None,
download: bool = False,
):
self.particle_data, self.jet_data = self.getData(
jet_type,
@@ -139,6 +143,7 @@
split_fraction,
seed,
file_list,
download,
)

super().__init__(
@@ -169,6 +174,7 @@ def getData(
split_fraction: List[float] = [0.7, 0.15, 0.15],
seed: int = 42,
file_list: List[str] = None,
download: bool = False,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
"""
Downloads, if needed, and loads and returns Quark Gluon data.
@@ -194,6 +200,9 @@
dataset splittings. Defaults to 42.
file_list (List[str], optional): list of files to load, if full dataset is not required.
Defaults to None (will load all files).
download (bool, optional): If True, downloads the dataset from the internet and
puts it in the ``data_dir`` directory. If dataset is already downloaded, it is not
downloaded again. Defaults to False.
Returns:
Tuple[Optional[np.ndarray], Optional[np.ndarray]]: particle data, jet data
@@ -221,6 +230,7 @@ def getData(
dataset_name=file_name,
record_id=cls._zenodo_record_id,
key=file_name,
download=download,
)

print(f"Loading {file_name}")
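As a sketch of the behaviour this change introduces for QuarkGluon (and the other dataset classes): with download=False, the default, a missing or corrupted local file now raises a RuntimeError instead of being fetched silently. The directory below is an illustrative placeholder:

```python
from jetnet.datasets import QuarkGluon  # assumed import path

try:
    # No download is attempted with download=False (the default),
    # so a missing local copy raises instead.
    QuarkGluon.getData("q", data_dir="./empty_dir", download=False)
except RuntimeError as err:
    print(err)  # message suggests re-running with download=True
```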
11 changes: 10 additions & 1 deletion jetnet/datasets/toptagging.py
@@ -40,6 +40,9 @@ class TopTagging(JetDataset):
to 200.
split (str, optional): dataset split, out of {"train", "valid", "test", "all"}. Defaults
to "train".
download (bool, optional): If True, downloads the dataset from the internet and
puts it in the ``data_dir`` directory. If dataset is already downloaded, it is not
downloaded again. Defaults to False.
"""

_zenodo_record_id = 2603256
@@ -63,9 +66,10 @@ def __init__(
jet_transform: Optional[Callable] = None,
num_particles: int = max_num_particles,
split: str = "train",
download: bool = False,
):
self.particle_data, self.jet_data = self.getData(
jet_type, data_dir, particle_features, jet_features, num_particles, split
jet_type, data_dir, particle_features, jet_features, num_particles, split, download
)

super().__init__(
@@ -91,6 +95,7 @@ def getData(
jet_features: List[str] = all_jet_features,
num_particles: int = max_num_particles,
split: str = "all",
download: bool = False,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
"""
Downloads, if needed, and loads and returns Top Quark Tagging data.
@@ -107,6 +112,9 @@
Defaults to 200.
split (str, optional): dataset split, out of {"train", "valid", "test", "all"}. Defaults
to "all".
download (bool, optional): If True, downloads the dataset from the internet and
puts it in the ``data_dir`` directory. If dataset is already downloaded, it is not
downloaded again. Defaults to False.
Returns:
(Tuple[Optional[np.ndarray], Optional[np.ndarray]]): particle data, jet data
@@ -134,6 +142,7 @@ def getData(
dataset_name=cls._split_key_mapping[s],
record_id=cls._zenodo_record_id,
key=f"{cls._split_key_mapping[s]}.h5",
download=download,
)

data = np.array(pd.read_hdf(hdf5_file, key="table"))
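The same flag is exposed through the dataset constructors; a minimal sketch for TopTagging, assuming JetDataset follows the standard PyTorch Dataset interface (the ./datasets path, import path, and batch size are placeholders):

```python
from torch.utils.data import DataLoader

from jetnet.datasets import TopTagging  # assumed import path

# Download and checksum-verify the validation split once, then iterate over it.
dataset = TopTagging(
    jet_type="top",
    data_dir="./datasets",  # placeholder directory
    split="valid",
    download=True,          # new in this PR
)
loader = DataLoader(dataset, batch_size=256, shuffle=False)
first_batch = next(iter(loader))
```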
60 changes: 40 additions & 20 deletions jetnet/datasets/utils.py
@@ -6,7 +6,7 @@
import hashlib
import os
import sys
from os.path import exists
from os.path import isfile
from typing import Any, List, Set, Tuple, Union

import numpy as np
@@ -85,31 +85,51 @@ def _getZenodoFileURL(record_id: int, file_name: str) -> str:
return file_url, md5


def checkDownloadZenodoDataset(data_dir: str, dataset_name: str, record_id: int, key: str) -> str:
def checkDownloadZenodoDataset(
data_dir: str, dataset_name: str, record_id: int, key: str, download: bool
) -> str:
"""
Checks if dataset exists and md5 hash matches;
if not, downloads it from Zenodo, and returns the file path.
if not and download = True, downloads it from Zenodo, and returns the file path.
or if not and download = False, raises an error.
"""
file_path = f"{data_dir}/{key}"
file_url, md5 = _getZenodoFileURL(record_id, key)

if exists(file_path):
match_md5, fmd5 = _check_md5(file_path, md5)
if not match_md5:
print(
f"MD5 hash of {file_path} does not match "
f"(expected md5:{md5}, got md5:{fmd5}), "
"removing existing file and re-downloading. "
"Please open an issue at https://github.com/jet-net/JetNet/issues/new "
"if you believe the matching is failing incorrectly."
)
os.remove(file_path)

if not exists(file_path):
os.makedirs(data_dir, exist_ok=True)

print(f"Downloading {dataset_name} dataset to {file_path}")
download_progress_bar(file_url, file_path)
if download:
if isfile(file_path):
match_md5, fmd5 = _check_md5(file_path, md5)
if not match_md5:
print(
f"File corrupted - MD5 hash of {file_path} does not match: "
f"(expected md5:{md5}, got md5:{fmd5}), "
"removing existing file and re-downloading."
"\nPlease open an issue at https://github.com/jet-net/JetNet/issues/new "
"if you believe this is an error."
)
os.remove(file_path)

if not isfile(file_path):
os.makedirs(data_dir, exist_ok=True)

print(f"Downloading {dataset_name} dataset to {file_path}")
download_progress_bar(file_url, file_path)

if not isfile(file_path):
raise RuntimeError(
f"Dataset {dataset_name} not found at {file_path}, "
"you can use download=True to download it."
)

match_md5, fmd5 = _check_md5(file_path, md5)
if not match_md5:
raise RuntimeError(
f"File corrupted - MD5 hash of {file_path} does not match: "
f"(expected md5:{md5}, got md5:{fmd5}), "
"you can use download=True to re-download it."
"\nPlease open an issue at https://github.com/jet-net/JetNet/issues/new "
"if you believe this is an error."
)

return file_path

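The new checks lean on the existing `_check_md5` helper, whose body is not part of this diff; a minimal sketch of an equivalent check (hypothetical implementation, chunked reading assumed) is:

```python
import hashlib
from typing import Tuple


def _check_md5_sketch(file_path: str, expected_md5: str, chunk_size: int = 1024 * 1024) -> Tuple[bool, str]:
    """Return (matches, file_md5), comparing the file's MD5 digest to the expected hash."""
    md5 = hashlib.md5()
    with open(file_path, "rb") as f:
        # Hash in chunks so large HDF5/npz files are never fully loaded into memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    fmd5 = md5.hexdigest()
    return fmd5 == expected_md5, fmd5
```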
22 changes: 19 additions & 3 deletions tests/datasets/test_jetnet.py
@@ -1,3 +1,6 @@
import os
from os.path import isfile

import numpy as np
import pytest
from pytest import approx
@@ -24,13 +27,26 @@
)
@pytest.mark.parametrize("num_particles", [30, 75])
def test_getData(jet_types, num_particles, expected_length, class_id):
# test md5 checksum is working for one of the datasets
# test getData errors and md5 checksum for one of the datasets
if jet_types == "q":
file_path = f"{data_dir}/q{'150' if num_particles > 30 else ''}.hdf5"

if isfile(file_path):
os.remove(file_path)

# should raise a RuntimeError since file doesn't exist
with pytest.raises(RuntimeError):
DataClass.getData(jet_types, data_dir, num_particles=num_particles)

# write random data to file
with open(f"{data_dir}/q{'150' if num_particles > 30 else ''}.hdf5", "wb") as f:
with open(file_path, "wb") as f:
f.write(np.random.bytes(100))

pf, jf = DataClass.getData(jet_types, data_dir, num_particles=num_particles)
# should raise a RuntimeError since file exists but is incorrect
with pytest.raises(RuntimeError):
DataClass.getData(jet_types, data_dir, num_particles=num_particles)

pf, jf = DataClass.getData(jet_types, data_dir, num_particles=num_particles, download=True)
assert pf.shape == (expected_length, num_particles, 4)
assert jf.shape == (expected_length, 5)
if class_id is not None:
20 changes: 18 additions & 2 deletions tests/datasets/test_qgjets.py
@@ -1,3 +1,6 @@
import os
from os.path import isfile

import numpy as np
import pytest
from pytest import approx
@@ -36,11 +39,24 @@
def test_getData(jet_types, split, expected_length, class_id, file_list):
# test md5 checksum is working for one of the datasets
if jet_types == "q" and file_list == test_file_list_withoutbc:
file_path = data_dir + "/" + file_list[-1]

if isfile(file_path):
os.remove(file_path)

# should raise a RuntimeError since file doesn't exist
with pytest.raises(RuntimeError):
DataClass.getData(jet_types, data_dir, file_list=file_list, split=split)

# write random data to file
with open(data_dir + "/" + file_list[-1], "wb") as f:
with open(file_path, "wb") as f:
f.write(np.random.bytes(100))

pf, jf = DataClass.getData(jet_types, data_dir, file_list=file_list, split=split)
# should raise a RuntimeError since file exists but is incorrect
with pytest.raises(RuntimeError):
DataClass.getData(jet_types, data_dir, file_list=file_list, split=split)

pf, jf = DataClass.getData(jet_types, data_dir, file_list=file_list, split=split, download=True)
assert pf.shape == (expected_length, num_particles, 4)
assert jf.shape == (expected_length, 1)
if class_id is not None:
20 changes: 18 additions & 2 deletions tests/datasets/test_toptagging.py
@@ -1,3 +1,6 @@
import os
from os.path import isfile

import numpy as np
import pytest
from pytest import approx
@@ -34,11 +37,24 @@
def test_getData(jet_types, split, expected_length, class_id):
# test md5 checksum is working for one of the datasets
if jet_types == "top" and split == "valid":
file_path = f"{data_dir}/val.h5"

if isfile(file_path):
os.remove(file_path)

# should raise a RuntimeError since file doesn't exist
with pytest.raises(RuntimeError):
DataClass.getData(jet_types, data_dir, split=split)

# write random data to file
with open(f"{data_dir}/val.h5", "wb") as f:
with open(file_path, "wb") as f:
f.write(np.random.bytes(100))

pf, jf = DataClass.getData(jet_types, data_dir, split=split)
# should raise a RuntimeError since file exists but is incorrect
with pytest.raises(RuntimeError):
DataClass.getData(jet_types, data_dir, split=split)

pf, jf = DataClass.getData(jet_types, data_dir, split=split, download=True)
assert pf.shape == (expected_length, num_particles, 4)
assert jf.shape == (expected_length, 5)
if class_id is not None:
