Skip to content

Commit

Permalink
Merge branch 'main' into test-py312
Browse files Browse the repository at this point in the history
  • Loading branch information
rkansal47 authored Oct 16, 2023
2 parents 2f273d9 + 957690a commit cae2bbe
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 20 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ docs/_build
/dist
*egg-info
*test*.py
*test*.ipynb
/datasets
.vscode

Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ default_language_version:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: check-added-large-files
Expand Down
9 changes: 7 additions & 2 deletions .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
# Required
version: 2

# Set the OS, Python version and other tools you might need
build:
os: ubuntu-22.04
tools:
python: "3.11"

# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs/conf.py
Expand All @@ -13,6 +19,5 @@ sphinx:
formats: all

python:
version: 3.8
install:
- requirements: docs/requirements.txt
- requirements: docs/requirements.txt
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ loss.backward()

## Documentation

The full API reference and tutorials are available at [jetnet.readthedocs.io](https://jetnet.readthedocs.io/en/latest/). Tutorial notebooks are in the [tutorials](tutorials) folder, with more to come.
The full API reference and tutorials are available at [jetnet.readthedocs.io](https://jetnet.readthedocs.io/en/latest/). Tutorial notebooks are in the [tutorials](https://github.com/jet-net/JetNet/tree/main/tutorials) folder, with more to come.

More detailed information about each dataset can (or will) be found at [jet-net.github.io](https://jet-net.github.io/).
<!-- More detailed information about each dataset can (or will) be found at [jet-net.github.io](https://jet-net.github.io/). -->

## Contributing

Expand Down
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ nbsphinx
numpy
readthedocs-sphinx-search
scipy
sphinx
sphinx<7
sphinx_rtd_theme
torch
tqdm
68 changes: 55 additions & 13 deletions jetnet/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
from __future__ import annotations

import hashlib
import os
import sys
from os.path import exists
Expand Down Expand Up @@ -48,28 +49,69 @@ def download_progress_bar(file_url: str, file_dest: str):
sys.stdout.write("\n")


def checkDownloadZenodoDataset(data_dir: str, dataset_name: str, record_id: int, key: str):
"""Checks if dataset exists, if not downloads it from Zenodo, and returns the file path"""
file_path = f"{data_dir}/{key}"
if not exists(file_path):
os.system(f"mkdir -p {data_dir}")
file_url = getZenodoFileURL(record_id, key)
# from TorchVision
# https://github.com/pytorch/vision/blob/48f8473e21b0f3e425aabc60db201b68fedf59b3/torchvision/datasets/utils.py#L51-L66 # noqa: E501
def _calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
# Setting the `usedforsecurity` flag does not change anything about the functionality, but
# indicates that we are not using the MD5 checksum for cryptography. This enables its usage
# in restricted environments like FIPS.
if sys.version_info >= (3, 9):
md5 = hashlib.md5(usedforsecurity=False)
else:
md5 = hashlib.md5()
with open(fpath, "rb") as f:
# switch to simpler assignment operator once we support only Python >=3.8
# while chunk := f.read(chunk_size):
for chunk in iter(lambda: f.read(chunk_size), b""):
md5.update(chunk)
return md5.hexdigest()

print(f"Downloading {dataset_name} dataset to {file_path}")
download_progress_bar(file_url, file_path)

return file_path
def _check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
fmd5 = _calculate_md5(fpath, **kwargs)
return (md5 == fmd5), fmd5


def getZenodoFileURL(record_id: int, file_name: str) -> str:
"""Finds URL for downloading the file ``file_name`` from a Zenodo record."""
def _getZenodoFileURL(record_id: int, file_name: str) -> str:
"""Finds URL and md5 hash for downloading the file ``file_name`` from a Zenodo record."""

import requests

records_url = f"https://zenodo.org/api/records/{record_id}"
r = requests.get(records_url).json()
file_url = next(item for item in r["files"] if item["key"] == file_name)["links"]["self"]
return file_url
file = next(item for item in r["files"] if item["key"] == file_name)
file_url = file["links"]["self"]
md5 = file["checksum"].split("md5:")[1]
return file_url, md5


def checkDownloadZenodoDataset(data_dir: str, dataset_name: str, record_id: int, key: str) -> str:
"""
Checks if dataset exists and md5 hash matches;
if not, downloads it from Zenodo, and returns the file path.
"""
file_path = f"{data_dir}/{key}"
file_url, md5 = _getZenodoFileURL(record_id, key)

if exists(file_path):
match_md5, fmd5 = _check_md5(file_path, md5)
if not match_md5:
print(
f"MD5 hash of {file_path} does not match "
f"(expected md5:{md5}, got md5:{fmd5}), "
"removing existing file and re-downloading. "
"Please open an issue at https://github.com/jet-net/JetNet/issues/new "
"if you believe the matching is failing incorrectly."
)
os.remove(file_path)

if not exists(file_path):
os.makedirs(data_dir, exist_ok=True)

print(f"Downloading {dataset_name} dataset to {file_path}")
download_progress_bar(file_url, file_path)

return file_path


def getOrderedFeatures(
Expand Down
8 changes: 7 additions & 1 deletion tests/datasets/test_jetnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

data_dir = "./datasets/jetnet"
DataClass = JetNet
jet_types = ["g", "q"] # faster testing than using full dataset
jet_types = ["g", "q"] # subset of jet types
gq_length = 177252 + 170679


Expand All @@ -24,6 +24,12 @@
)
@pytest.mark.parametrize("num_particles", [30, 75])
def test_getData(jet_types, num_particles, expected_length, class_id):
# test md5 checksum is working for one of the datasets
if jet_types == "q":
# write random data to file
with open(f"{data_dir}/q{'150' if num_particles > 30 else ''}.hdf5", "wb") as f:
f.write(np.random.bytes(100))

pf, jf = DataClass.getData(jet_types, data_dir, num_particles=num_particles)
assert pf.shape == (expected_length, num_particles, 4)
assert jf.shape == (expected_length, 5)
Expand Down
6 changes: 6 additions & 0 deletions tests/datasets/test_qgjets.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@
)
@pytest.mark.parametrize("file_list", [test_file_list_withbc, test_file_list_withoutbc])
def test_getData(jet_types, split, expected_length, class_id, file_list):
# test md5 checksum is working for one of the datasets
if jet_types == "q" and file_list == test_file_list_withoutbc:
# write random data to file
with open(data_dir + "/" + file_list[-1], "wb") as f:
f.write(np.random.bytes(100))

pf, jf = DataClass.getData(jet_types, data_dir, file_list=file_list, split=split)
assert pf.shape == (expected_length, num_particles, 4)
assert jf.shape == (expected_length, 1)
Expand Down
6 changes: 6 additions & 0 deletions tests/datasets/test_toptagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@
],
)
def test_getData(jet_types, split, expected_length, class_id):
# test md5 checksum is working for one of the datasets
if jet_types == "top" and split == "valid":
# write random data to file
with open(f"{data_dir}/val.h5", "wb") as f:
f.write(np.random.bytes(100))

pf, jf = DataClass.getData(jet_types, data_dir, split=split)
assert pf.shape == (expected_length, num_particles, 4)
assert jf.shape == (expected_length, 5)
Expand Down

0 comments on commit cae2bbe

Please sign in to comment.